From 67c43a70256c8ac6b7f44a7c706521214653f0ff Mon Sep 17 00:00:00 2001
From: Lei Zhang
Date: Wed, 17 May 2023 10:07:22 -0700
Subject: [PATCH] Support output cross qk in masked decoder multihead attention kernel

Support cross qk in beam search for the whisper model and related features.

Make the whisper exporting tools support cross qk and some related features:
* extra_decoding_ids
* no_speech_prob

Implement the DTW kernel and the UnfoldTensor kernel, with unit tests and shape inference.

Several fixes related to running multiple sessions in parallel, such as:
* guard creation of fused_fp16_runner_ in attention / multihead_attention
* make some memory allocations stream-aware
* add the use_ep_level_unified_stream option

Make the timestamp Logits Processor run on GPU now that beam scoring has moved to GPU, and add a GPU test for the whisper timestamp Logits Processor.
---
 .../providers/cuda/cuda_provider_options.h | 1 +
 .../cpu/bert/multihead_attention_helper.h | 3 +-
 .../cpu/transformers/beam_search.cc | 14 +-
 .../cpu/transformers/beam_search.h | 10 +-
 .../cpu/transformers/beam_search_impl_base.h | 73 +++---
 .../cpu/transformers/beam_search_impl_gpt.h | 6 +-
 .../cpu/transformers/beam_search_impl_t5.h | 6 +-
 .../transformers/beam_search_impl_whisper.h | 118 ++++++++-
 .../transformers/beam_search_parameters.cc | 18 ++
 .../cpu/transformers/generate_impl_base.h | 31 ++-
 .../transformers/generation_device_helper.cc | 34 +++
 .../transformers/generation_device_helper.h | 57 +++++
 .../cpu/transformers/generation_shared.h | 6 +
 .../transformers/greedy_search_impl_base.h | 89 +++----
 .../cpu/transformers/greedy_search_impl_gpt.h | 6 +-
 .../cpu/transformers/logits_processor.cc | 136 -----------
 .../cpu/transformers/logits_processor.h | 132 +++++++++-
 .../cpu/transformers/subgraph_base.h | 1 +
 .../cpu/transformers/subgraph_t5_decoder.cc | 2 +-
 .../cpu/transformers/subgraph_t5_decoder.h | 10 +-
 .../transformers/subgraph_whisper_decoder.cc | 23 +-
 .../contrib_ops/cuda/bert/attention.cc | 20 +-
 onnxruntime/contrib_ops/cuda/bert/attention.h | 2 +
 .../decoder_masked_multihead_attention.cc | 12 +-
 .../bert/decoder_masked_multihead_attention.h | 1 +
 ...decoder_masked_multihead_attention_impl.cu | 9 +
 .../decoder_masked_multihead_attention_impl.h | 1 +
 .../cuda/bert/multihead_attention.cc | 6 +-
 .../cuda/bert/multihead_attention.h | 1 +
 .../contrib_ops/cuda/bert/packed_attention.h | 3 +-
 .../contrib_ops/cuda/cuda_contrib_kernels.cc | 4 +
 .../cuda/tensor/dynamic_time_warping.cc | 56 +++++
 .../cuda/tensor/dynamic_time_warping.h | 26 ++
 .../cuda/tensor/dynamic_time_warping_impl.cu | 141 +++++++++++
 .../cuda/tensor/dynamic_time_warping_impl.h | 25 ++
 onnxruntime/contrib_ops/cuda/tensor/unfold.cc | 55 +++++
 onnxruntime/contrib_ops/cuda/tensor/unfold.h | 39 +++
 .../contrib_ops/cuda/tensor/unfold_impl.cu | 107 ++++++++
 .../contrib_ops/cuda/tensor/unfold_impl.h | 25 ++
 .../cuda/transformers/beam_search.cc | 4 +-
 .../cuda/transformers/generation_cuda_impl.cu | 229 ++++++++++++++++++
 .../cuda/transformers/generation_cuda_impl.h | 49 ++++
 .../transformers/generation_device_helper.cc | 159 +++++++++++-
 .../transformers/generation_device_helper.h | 28 +++
 .../core/graph/contrib_ops/bert_defs.cc | 10 +
 .../core/graph/contrib_ops/contrib_defs.cc | 119 ++++++++-
 onnxruntime/core/graph/contrib_ops/ms_opset.h | 4 +
 .../providers/cuda/cuda_execution_provider.cc | 2 +-
 .../cuda/cuda_execution_provider_info.cc | 4 +
 .../cuda/cuda_execution_provider_info.h | 2 +
 .../providers/cuda/cuda_provider_factory.cc | 2 +
 .../core/session/provider_bridge_ort.cc | 1 +
 .../tools/transformers/convert_generation.py | 41 +++-
.../models/whisper/convert_to_onnx.py | 64 +++++ .../models/whisper/whisper_chain.py | 91 ++++++- .../models/whisper/whisper_helper.py | 12 +- .../dynamic_time_warping_op_test.cc | 119 +++++++++ .../test/contrib_ops/tensor_op_test.cc | 71 ++++++ .../test_whisper_timestamp_processor.py | 13 +- 59 files changed, 2054 insertions(+), 279 deletions(-) create mode 100644 onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.cc create mode 100644 onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.h create mode 100644 onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.cu create mode 100644 onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.h create mode 100644 onnxruntime/contrib_ops/cuda/tensor/unfold.cc create mode 100644 onnxruntime/contrib_ops/cuda/tensor/unfold.h create mode 100644 onnxruntime/contrib_ops/cuda/tensor/unfold_impl.cu create mode 100644 onnxruntime/contrib_ops/cuda/tensor/unfold_impl.h create mode 100644 onnxruntime/test/contrib_ops/dynamic_time_warping_op_test.cc diff --git a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h index eaf0e5337b8b6..5f266dd14d36d 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h +++ b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h @@ -32,4 +32,5 @@ struct OrtCUDAProviderOptionsV2 { int tunable_op_max_tuning_duration_ms = 0; // Max tuning duration time limit for TunableOp. int enable_skip_layer_norm_strict_mode = 0; // flag specifying if SkipLayerNorm is in strict mode. If true, use LayerNormalization kernel. // The strict mode has better accuracy but lower performance. + int use_ep_level_unified_stream = 0; // flag specifying if ep level stream is used or not }; diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index 1dc85e6d345d7..73b83057bdbe9 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -103,7 +103,8 @@ Status CheckInputs(const T* query, } if (past_key_dims[2] != past_value_dims[2]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' and 'past_value' shall have same dim 2 (past_sequence_length)"); + "Input 'past_key' and 'past_value' shall have same dim 2 (past_sequence_length). ", + past_key_dims[2], " vs ", past_value_dims[2]); } if (past_key_dims[3] != head_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index c391f47e1927b..ff6ba86cb25ef 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -319,7 +319,12 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { update_decoder_feeds_func_ ? update_decoder_feeds_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds, expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer, expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer, - create_beam_scorer_func_}; + create_beam_scorer_func_, + update_decoder_cross_qk_func_ ? update_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::UpdateDecoderCrossQK, + finalize_decoder_cross_qk_func_ ? 
finalize_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::FinalizeDecoderCrossQK, + cuda_device_prop_, + cuda_device_arch_}; + #ifdef USE_CUDA ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_)); #endif @@ -340,7 +345,12 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { update_decoder_feeds_fp16_func_ ? update_decoder_feeds_fp16_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds, expand_buffer_float_func_, expand_buffer_float16_func_, - create_beam_scorer_func_}; + create_beam_scorer_func_, + update_decoder_cross_qk_func_ ? update_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::UpdateDecoderCrossQK, + finalize_decoder_cross_qk_func_ ? finalize_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::FinalizeDecoderCrossQK, + cuda_device_prop_, + cuda_device_arch_}; + #ifdef USE_CUDA ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_)); #endif diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h index 93b7e08fabf94..f3c5b50dfb84d 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h @@ -88,12 +88,16 @@ class BeamSearch : public IControlFlowKernel { const GenerationDeviceHelper::UpdateDecoderFeedsFunc& update_decoder_feeds_fp16_func, const GenerationDeviceHelper::ExpandBufferFunc& expand_buffer_int32_func, const GenerationDeviceHelper::ExpandBufferFunc& expand_buffer_float_func, - const GenerationDeviceHelper::ExpandBufferFunc& expand_buffer_float16_func) { + const GenerationDeviceHelper::ExpandBufferFunc& expand_buffer_float16_func, + const GenerationDeviceHelper::UpdateDecoderCrossQKFunc& update_decoder_cross_qk_func, + const GenerationDeviceHelper::FinalizeDecoderCrossQKFunc& finalize_decoder_cross_qk_func) { update_decoder_feeds_func_ = update_decoder_feeds_func; update_decoder_feeds_fp16_func_ = update_decoder_feeds_fp16_func; expand_buffer_int32_func_ = expand_buffer_int32_func; expand_buffer_float_func_ = expand_buffer_float_func; expand_buffer_float16_func_ = expand_buffer_float16_func; + update_decoder_cross_qk_func_ = update_decoder_cross_qk_func; + finalize_decoder_cross_qk_func_ = finalize_decoder_cross_qk_func; } #ifdef USE_CUDA @@ -175,6 +179,10 @@ class BeamSearch : public IControlFlowKernel { BeamSearchParameters parameters_; bool has_init_decoder_ = false; + + GenerationDeviceHelper::UpdateDecoderCrossQKFunc update_decoder_cross_qk_func_; + + GenerationDeviceHelper::FinalizeDecoderCrossQKFunc finalize_decoder_cross_qk_func_; }; } // namespace transformers diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h index 8832b4314bad3..29b38fc234de5 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h @@ -17,34 +17,35 @@ struct BeamSearchState : IBeamSearchState { BeamSearchState(const IGenerationParameters& parameters, AllocatorPtr allocator, int has_decoder_masked_attention, - bool use_position) { + bool use_position, + Stream* stream) { size_t batch_beam_size = SafeInt(parameters.batch_size) * parameters.num_beams; size_t next_token_size = SafeInt(batch_beam_size) * parameters.vocab_size; - this->next_token_logits = AllocateBuffer(allocator, next_token_logits_buffer_, next_token_size); 
- this->next_token_scores = AllocateBuffer(allocator, next_token_scores_buffer_, next_token_size); - this->next_tokens = AllocateBuffer(allocator, next_tokens_buffer_, SafeInt(2) * batch_beam_size); - this->next_indices = AllocateBuffer(allocator, next_indices_buffer_, SafeInt(2) * batch_beam_size); - this->next_scores = AllocateBuffer(allocator, next_scores_buffer_, SafeInt(2) * batch_beam_size); + this->next_token_logits = AllocateBuffer(allocator, next_token_logits_buffer_, next_token_size, stream); + this->next_token_scores = AllocateBuffer(allocator, next_token_scores_buffer_, next_token_size, stream); + this->next_tokens = AllocateBuffer(allocator, next_tokens_buffer_, SafeInt(2) * batch_beam_size, stream); + this->next_indices = AllocateBuffer(allocator, next_indices_buffer_, SafeInt(2) * batch_beam_size, stream); + this->next_scores = AllocateBuffer(allocator, next_scores_buffer_, SafeInt(2) * batch_beam_size, stream); constexpr size_t max_parts_of_vocab = 128; size_t topk_buffer_size = SafeInt(batch_beam_size) * (max_parts_of_vocab + 1) * parameters.num_beams * 2 * 2; - this->topk_buffer = AllocateBuffer(allocator, topk_temp_buffer_, topk_buffer_size); + this->topk_buffer = AllocateBuffer(allocator, topk_temp_buffer_, topk_buffer_size, stream); if (allocator->Info().device.Type() == OrtDevice::GPU) { size_t sequences_elements = SafeInt(2) * batch_beam_size * parameters.max_length; - this->sequences_device = AllocateBuffer(allocator, sequences_device_buffer_, sequences_elements); + this->sequences_device = AllocateBuffer(allocator, sequences_device_buffer_, sequences_elements, stream); } if (use_position) { - this->next_positions = AllocateBuffer(allocator, next_positions_buffer_, batch_beam_size); + this->next_positions = AllocateBuffer(allocator, next_positions_buffer_, batch_beam_size, stream); } - this->beam_scores = AllocateBuffer(allocator, beam_scores_buffer_, batch_beam_size); + this->beam_scores = AllocateBuffer(allocator, beam_scores_buffer_, batch_beam_size, stream); if (parameters.output_scores) { size_t elements = SafeInt(parameters.max_length - parameters.sequence_length) * parameters.batch_size * parameters.num_beams * parameters.vocab_size; - this->scores = AllocateBuffer(allocator, scores_buffer_, elements); + this->scores = AllocateBuffer(allocator, scores_buffer_, elements, stream); this->remaining_scores = this->scores; } @@ -68,35 +69,38 @@ struct BeamSearchState : IBeamSearchState { } private: - BufferUniquePtr next_token_logits_buffer_; - BufferUniquePtr next_token_scores_buffer_; - BufferUniquePtr next_tokens_buffer_; - BufferUniquePtr next_indices_buffer_; - BufferUniquePtr next_scores_buffer_; - BufferUniquePtr next_positions_buffer_; - BufferUniquePtr beam_scores_buffer_; - BufferUniquePtr scores_buffer_; - BufferUniquePtr topk_temp_buffer_; - BufferUniquePtr sequences_device_buffer_; + IAllocatorUniquePtr next_token_logits_buffer_; + IAllocatorUniquePtr next_token_scores_buffer_; + IAllocatorUniquePtr next_tokens_buffer_; + IAllocatorUniquePtr next_indices_buffer_; + IAllocatorUniquePtr next_scores_buffer_; + IAllocatorUniquePtr next_positions_buffer_; + IAllocatorUniquePtr beam_scores_buffer_; + IAllocatorUniquePtr scores_buffer_; + IAllocatorUniquePtr topk_temp_buffer_; + IAllocatorUniquePtr sequences_device_buffer_; }; struct BeamSearchCpuState : IBeamSearchCpuState { Sequences sequences; - BeamSearchCpuState(const IGenerationParameters& parameters, AllocatorPtr allocator, bool is_cuda) + BeamSearchCpuState(const IGenerationParameters& parameters, 
AllocatorPtr allocator, bool is_cuda, Stream* stream) : parameters_{parameters} { - sequence_lengths = AllocateBuffer(allocator, sequence_lengths_buffer_, batch_beam_size_); + sequence_lengths = AllocateBuffer(allocator, sequence_lengths_buffer_, batch_beam_size_, stream); size_t sequences_bytes = SafeInt(2) * batch_beam_size_ * parameters.max_length; - sequences_space = AllocateBuffer(allocator, sequences_space_buffer_, sequences_bytes, true /* fill */); + sequences_space = AllocateBuffer(allocator, sequences_space_buffer_, sequences_bytes, stream, true /* fill */); sequences.Init(sequences_space, batch_beam_size_, parameters.sequence_length, parameters.max_length); if (is_cuda) { // buffers used by CUDA operator but not by CPU operator. - topk_scores = AllocateBuffer(allocator, topk_scores_buffer_, 2 * static_cast(batch_beam_size_)); - topk_tokens = AllocateBuffer(allocator, topk_tokens_buffer_, 2 * static_cast(batch_beam_size_)); - topk_indices = AllocateBuffer(allocator, topk_indices_buffer_, 2 * static_cast(batch_beam_size_)); - final_beam_scores = AllocateBuffer(allocator, final_beam_scores_buffer_, batch_beam_size_); + topk_scores = AllocateBuffer(allocator, topk_scores_buffer_, 2 * static_cast(batch_beam_size_), stream); + topk_tokens = AllocateBuffer(allocator, topk_tokens_buffer_, 2 * static_cast(batch_beam_size_), stream); + topk_indices = AllocateBuffer(allocator, topk_indices_buffer_, 2 * static_cast(batch_beam_size_), stream); + final_beam_scores = AllocateBuffer(allocator, final_beam_scores_buffer_, batch_beam_size_, stream); + + size_t next_token_size = SafeInt(batch_beam_size_) * parameters.vocab_size; + next_token_scores = AllocateBuffer(allocator, next_token_scores_buffer_, next_token_size, stream); } } @@ -124,12 +128,13 @@ struct BeamSearchCpuState : IBeamSearchCpuState { const IGenerationParameters& parameters_; const int batch_beam_size_{parameters_.batch_size * parameters_.num_beams}; - BufferUniquePtr final_beam_scores_buffer_; - BufferUniquePtr sequence_lengths_buffer_; - BufferUniquePtr topk_scores_buffer_; - BufferUniquePtr topk_tokens_buffer_; - BufferUniquePtr topk_indices_buffer_; - BufferUniquePtr sequences_space_buffer_; + IAllocatorUniquePtr final_beam_scores_buffer_; + IAllocatorUniquePtr sequence_lengths_buffer_; + IAllocatorUniquePtr topk_scores_buffer_; + IAllocatorUniquePtr topk_tokens_buffer_; + IAllocatorUniquePtr topk_indices_buffer_; + IAllocatorUniquePtr sequences_space_buffer_; + IAllocatorUniquePtr next_token_scores_buffer_; }; // Base class of beam search implementation that is common for GPT-2, T5, and Whisper. 
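The buffer changes in this file (BufferUniquePtr replaced by IAllocatorUniquePtr, with every AllocateBuffer call now taking a Stream*) are what "make some memory allocations stream-aware" in the commit message refers to: temporaries are requested through IAllocator::MakeUniquePtr with the compute stream attached, so a stream-aware arena can order reuse of that memory against work already queued on the stream when several sessions share one CUDA EP. A minimal sketch of the pattern, mirroring the overload added to generate_impl_base.h later in this patch (ORT's SafeInt, gsl, and Stream declarations are assumed to be in scope):

```cpp
#include <algorithm>
#include "core/framework/allocator.h"  // AllocatorPtr, IAllocator, IAllocatorUniquePtr

// Sketch only: stream-aware variant of AllocateBuffer used by the reworked state structs.
template <typename T>
gsl::span<T> AllocateBufferOnStream(AllocatorPtr allocator,
                                    IAllocatorUniquePtr<void>& buffer,  // keeps the allocation alive
                                    size_t elements,
                                    Stream* stream,                     // compute stream this buffer belongs to
                                    bool fill = false,
                                    T fill_value = T{}) {
  const size_t bytes = SafeInt<size_t>(sizeof(T)) * elements;
  // The trailing stream argument is what BufferUniquePtr could not express: it lets the
  // arena tie this allocation to a stream so that reuse is correctly synchronized.
  buffer = IAllocator::MakeUniquePtr<void>(allocator, bytes, /*use_reserve*/ false, stream);
  T* first = reinterpret_cast<T*>(buffer.get());
  if (fill) {  // only meaningful for CPU-accessible memory
    std::fill_n(first, elements, fill_value);
  }
  return gsl::make_span(first, elements);
}
```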
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h index 205d94fae9fab..56d950ca2f41e 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h @@ -215,7 +215,8 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch BeamSearchCpuState cpu_state{*parameters, this->cpu_allocator_, - this->IsCuda()}; + this->IsCuda(), + this->ort_stream_}; // buffer in GPU for input_ids, position_ids and attention_mask IAllocatorUniquePtr buffer; @@ -240,7 +241,8 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch BeamSearchState beam_state{*parameters, this->temp_space_allocator_, gpt_subgraph_.has_decoder_masked_attention_, - true /* use_position */}; + true /* use_position */, + this->ort_stream_}; init_beam_state_func_(&beam_state, cpu_state.sequence_lengths, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h index 14a0db57c45de..94547887d3a90 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h @@ -144,7 +144,8 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches BeamSearchCpuState cpu_state{*parameters, this->cpu_allocator_, - this->IsCuda()}; + this->IsCuda(), + this->ort_stream_}; IAllocatorUniquePtr buffer; @@ -195,7 +196,8 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches BeamSearchState beam_state{*parameters, this->temp_space_allocator_, decoder_subgraph_.has_decoder_masked_attention_, - false /* use_position */}; + false /* use_position */, + this->ort_stream_}; init_beam_state_func_(&beam_state, cpu_state.sequence_lengths, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h index 198dec011c56f..d97f2e5f1a19e 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h @@ -36,7 +36,11 @@ class BeamSearchWhisper : public BeamSearchBase { const GenerationDeviceHelper::UpdateDecoderFeedsFunc& update_decoder_feeds_func, const GenerationDeviceHelper::ExpandBufferFunc& expand_buffer_float_func, const GenerationDeviceHelper::ExpandBufferFunc& expand_buffer_float16_func, - const GenerationDeviceHelper::CreateBeamScorer& create_beam_scorer_func) + const GenerationDeviceHelper::CreateBeamScorer& create_beam_scorer_func, + const GenerationDeviceHelper::UpdateDecoderCrossQKFunc& update_decoder_cross_qk_func, + const GenerationDeviceHelper::FinalizeDecoderCrossQKFunc& finalize_decoder_cross_qk_func, + const void* cuda_device_prop, + int cuda_device_arch) : BeamSearchBase(context, decoder_session_state, thread_pool, ort_stream, cuda_dumper, params, topk_func, process_logits_func, device_copy_func, device_copy_int32_func), @@ -49,7 +53,11 @@ class BeamSearchWhisper : public BeamSearchBase { update_decoder_feeds_func_(update_decoder_feeds_func), expand_buffer_float_func_(expand_buffer_float_func), expand_buffer_float16_func_(expand_buffer_float16_func), - create_beam_scorer_func_(create_beam_scorer_func) {} + create_beam_scorer_func_(create_beam_scorer_func), + update_decoder_cross_qk_func_(update_decoder_cross_qk_func), + 
finalize_decoder_cross_qk_func_(finalize_decoder_cross_qk_func), + cuda_device_prop_(cuda_device_prop), + cuda_device_arch_(cuda_device_arch) {} #ifdef USE_CUDA Status InitializeCuda( @@ -95,6 +103,8 @@ class BeamSearchWhisper : public BeamSearchBase { GenerationDeviceHelper::ExpandBufferFunc expand_buffer_float16_func_; GenerationDeviceHelper::CreateBeamScorer create_beam_scorer_func_; + const GenerationDeviceHelper::UpdateDecoderCrossQKFunc update_decoder_cross_qk_func_; + const GenerationDeviceHelper::FinalizeDecoderCrossQKFunc finalize_decoder_cross_qk_func_; const void* cuda_device_prop_ = nullptr; int cuda_device_arch_ = 0; }; @@ -122,6 +132,15 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe TensorShape scores_shape(&scores_dims[0], sizeof(scores_dims) / sizeof(scores_dims[0])); Tensor* output_scores = this->context_.Output(2, scores_shape); + TensorShape no_speech_probs_shape{parameters->batch_size}; + Tensor* no_speech_probs = this->context_.Output(4, no_speech_probs_shape); + if (no_speech_probs && no_speech_probs->MutableData()) { + ORT_ENFORCE(parameters->no_speech_token >= 0 && parameters->no_speech_token < parameters->vocab_size, + "no_speech_token id out of range, it is ", parameters->no_speech_token, + ", vocab_size is ", parameters->vocab_size); + this->parameters_->no_speech_probs = (void*)no_speech_probs->MutableData(); + } + // Update the flag to indicate whether scores exists in output this->parameters_->output_scores = (output_scores != nullptr); @@ -136,7 +155,8 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe BeamSearchCpuState cpu_state{*parameters, this->cpu_allocator_, - this->IsCuda()}; + this->IsCuda(), + this->ort_stream_}; IAllocatorUniquePtr buffer; @@ -188,7 +208,8 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe BeamSearchState beam_state{*parameters, this->temp_space_allocator_, decoder_subgraph_.has_decoder_masked_attention_, - false /* use_position */}; + false /* use_position */, + this->ort_stream_}; init_beam_state_func_(&beam_state, cpu_state.sequence_lengths, @@ -222,6 +243,16 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe std::vector decoder_feeds; int current_length = parameters->sequence_length; + // for decoder subgraph output cross qk + int64_t frames_of_k = 0LL; + Tensor* cross_qk_output = nullptr; // output tensor + int64_t cross_qk_layer_head_pair_count = 0LL; + OrtValue cross_qk_buffer_value; + float* cross_qk_buffer_data = nullptr; + std::vector cross_qk_all_layer_heads; + const int32_t* cross_qk_layer_head_pairs = nullptr; + IAllocatorUniquePtr qk_layer_pointers; // if needed, device array hold the cross qk data pointers, shape of [num_layers] + std::vector decoder_fetches; if (current_length + 1 < parameters->max_length) { @@ -265,6 +296,41 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe } } + if (decoder_subgraph_.output_cross_qk_) { + ORT_ENFORCE(decoder_subgraph_.has_decoder_masked_attention_, "decoder subgraph: output_cross_qk could only work with has_decoder_masked_attention"); + ORT_ENFORCE(decoder_subgraph_.past_present_share_buffer_, "decoder subgraph: output_cross_qk could only work with past_present_share_buffer"); + + cross_qk_layer_head_pair_count = parameters->num_layers * parameters->num_heads; + const auto* input_tensor_cross_qk_layer_head = this->context_.template Input(12); + ORT_ENFORCE(input_tensor_cross_qk_layer_head != nullptr, "Must specify input 
cross_qk_layer_head"); + cross_qk_layer_head_pair_count = input_tensor_cross_qk_layer_head->Shape()[0]; + cross_qk_layer_head_pairs = input_tensor_cross_qk_layer_head->template Data(); // it is on GPU + + size_t decoder_input_first_cross_key = static_cast(decoder_subgraph_.GetFirstPastInputIndex()) + (2 * decoder_subgraph_.num_layers); + auto first_cross_attention_key = decoder_feeds[decoder_input_first_cross_key].GetMutable(); + frames_of_k = first_cross_attention_key->Shape()[2]; + + TensorShape layer_cross_qk_shape{ + static_cast(parameters->BatchBeamSize()), + static_cast(parameters->num_heads), + 1LL, + static_cast(frames_of_k)}; + for (int layer = 0; layer < decoder_subgraph_.num_layers; layer++) { + OrtValue cross_qk_value; + Tensor::InitOrtValue(DataTypeImpl::GetType(), layer_cross_qk_shape, this->temp_space_allocator_, cross_qk_value); + decoder_fetches.emplace_back(cross_qk_value); + } + + TensorShape cross_qk_shape{ + static_cast(parameters->batch_size), + static_cast(parameters->num_beams), + cross_qk_layer_head_pair_count, + static_cast(parameters->max_length), + frames_of_k}; + Tensor::InitOrtValue(DataTypeImpl::GetType(), cross_qk_shape, this->temp_space_allocator_, cross_qk_buffer_value); + cross_qk_buffer_data = cross_qk_buffer_value.GetMutable()->MutableData(); + } + if (decoder_subgraph_.has_decoder_masked_attention_) { size_t offset = static_cast(decoder_subgraph_.GetFirstPastInputIndex()); // Need to check cross attention's past key tensor size, suppose all layers cross attention key size are same @@ -316,6 +382,21 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe ORT_RETURN_IF_ERROR(status); + if (decoder_subgraph_.output_cross_qk_) { + int decoder_output_first_cross_qk = decoder_subgraph_.GetFirstPresentOutputIndex() + (2 * decoder_subgraph_.num_layers); + ORT_RETURN_IF_ERROR(this->update_decoder_cross_qk_func_( + iteration_counter, + this->ort_stream_, + &decoder_fetches[decoder_output_first_cross_qk], + qk_layer_pointers, + parameters->num_layers, + static_cast(cross_qk_layer_head_pair_count), + cross_qk_layer_head_pairs, + cross_qk_buffer_data, + parameters->max_length, + this->temp_space_allocator_)); + } + #ifdef DEBUG_GENERATION for (int i = 0; i <= decoder_subgraph_.GetFirstPresentOutputIndex(); i++) { dumper->Print("decoder_fetches", i, true); @@ -383,6 +464,35 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe } } + if (decoder_subgraph_.output_cross_qk_) { + TensorShape cross_qk_shape{ + static_cast(parameters->batch_size), + static_cast(parameters->num_return_sequences), + cross_qk_layer_head_pair_count, + static_cast(iteration_counter - 1), + frames_of_k}; + cross_qk_output = this->context_.Output(3, cross_qk_shape); + + size_t cache_indir_input_offset = static_cast(decoder_subgraph_.GetFirstPastInputIndex()) + 4 * static_cast(decoder_subgraph_.num_layers) + 2; + const int* cache_indir_data = decoder_feeds[cache_indir_input_offset].GetMutable()->Data(); + auto beam_indices = this->beam_scorer_->GetNextIndicesGPU(); // currently only support on GPU + ORT_RETURN_IF_ERROR(this->finalize_decoder_cross_qk_func_( + this->ort_stream_, + iteration_counter, + parameters->sequence_length, + parameters->batch_size, + parameters->num_beams, + parameters->max_length, + static_cast(cross_qk_layer_head_pair_count), + cross_qk_layer_head_pairs, + static_cast(frames_of_k), + cross_qk_buffer_data, + cross_qk_output->MutableData(), + parameters->num_return_sequences, + cache_indir_data, + beam_indices)); + } + 
gsl::span final_beam_scores = beam_state.beam_scores; this->beam_scorer_->Finalize(cpu_state.sequences, final_beam_scores, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index 76011a5c89b66..fa8c7d85ff3d1 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -25,6 +25,7 @@ void BeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) { decoder_start_token_id = static_cast(info.GetAttrOrDefault("decoder_start_token_id", -1)); no_repeat_ngram_size = static_cast(info.GetAttrOrDefault("no_repeat_ngram_size", 0)); vocab_size = static_cast(info.GetAttrOrDefault("vocab_size", -1)); + no_speech_token = static_cast(info.GetAttrOrDefault("no_speech_token", -1LL)); } void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { @@ -47,6 +48,23 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { } batch_size = static_cast(dims[0]); + extra_decoding_ids = gsl::span(); + if (this->model_type == IGenerationParameters::kModelTypeWhisper) { + const Tensor* extra_decoder_tensor = context->Input(13); + if (extra_decoder_tensor != nullptr) { + const auto& extra_decoder_tensor_dims = extra_decoder_tensor->Shape().GetDims(); + ORT_ENFORCE(extra_decoder_tensor_dims.size() == 2, + "extra_decoder_tensor shall have 2 dimensions. Got ", + extra_decoder_tensor_dims.size()); + ORT_ENFORCE(extra_decoder_tensor_dims[0] == batch_size, + "extra_decoder_tensor first dim not same as batch_size. Got ", + extra_decoder_tensor_dims[0], ", expecting ", batch_size); + if (extra_decoder_tensor->Shape().Size() > 0) { + extra_decoding_ids = gsl::span(extra_decoder_tensor->Data(), extra_decoder_tensor->Shape().Size()); + } + } + } + if (this->model_type == IGenerationParameters::kModelTypeGpt) { sequence_length = static_cast(dims[1]); } else if (this->model_type == IGenerationParameters::kModelTypeWhisper) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h index e889281abb023..680cb23fd887a 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h @@ -33,24 +33,43 @@ gsl::span AllocateBuffer(AllocatorPtr allocator, return span; } +template +gsl::span AllocateBuffer(AllocatorPtr allocator, + IAllocatorUniquePtr& buffer, + size_t elements, + Stream* stream, + bool fill = false, + T fill_value = T{}) { + size_t bytes = SafeInt(sizeof(T)) * elements; + buffer = IAllocator::MakeUniquePtr(allocator, bytes, false, stream); + T* first = reinterpret_cast(buffer.get()); + auto span = gsl::make_span(first, elements); + + if (fill) { + std::fill_n(first, elements, fill_value); + } + + return span; +} + template inline void AllocateTempBufferForGetGreedySearchTopOne( int32_t batch_size, AllocatorPtr allocator, - BufferUniquePtr& buffer, + IAllocatorUniquePtr& buffer, gsl::span& stage_1_scores, // shape (batch_size, parts_of_vocab) gsl::span& stage_1_tokens, // shape (batch_size, parts_of_vocab) gsl::span& output_scores, // shape (batch_size) - gsl::span& output_tokens // shape (batch_size) -) { + gsl::span& output_tokens, // shape (batch_size) + Stream* stream) { constexpr size_t kMaxPartsPerVocab = 128; const size_t stage_1_element_size = kMaxPartsPerVocab * batch_size; const size_t output_element_size = batch_size; // Note: use 
float to allocate buffer for temporary value buffer to avoid unalignment - void* topk_data = allocator->Alloc((stage_1_element_size + output_element_size) * (sizeof(float) + sizeof(int32_t))); - BufferUniquePtr temp_buffer(topk_data, BufferDeleter(allocator)); - buffer = std::move(temp_buffer); + size_t bytes = (stage_1_element_size + output_element_size) * (sizeof(float) + sizeof(int32_t)); + buffer = IAllocator::MakeUniquePtr(allocator, bytes, false, stream); + void* topk_data = buffer.get(); ElementType* stage_1_scores_data = reinterpret_cast(topk_data); stage_1_scores = gsl::make_span(stage_1_scores_data, stage_1_element_size); diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc index 88348ad88dc27..ea80e01a7adda 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc @@ -1084,6 +1084,40 @@ template Status CreateWhisperEncoderInputs( OrtValue& encoder_input_features, OrtValue& decoder_input_ids); +Status UpdateDecoderCrossQK( + [[maybe_unused]] int iteration_number, + [[maybe_unused]] Stream* tream, + [[maybe_unused]] OrtValue* cross_qks, + [[maybe_unused]] IAllocatorUniquePtr& qk_layer_pointers, + [[maybe_unused]] int num_layers, + [[maybe_unused]] int cross_qk_layer_head_pair_count, + [[maybe_unused]] const int* cross_qk_layer_head_pairs, + [[maybe_unused]] float* cross_qk_buffer_data, + [[maybe_unused]] int max_length, + [[maybe_unused]] AllocatorPtr allocator) { + throw std::runtime_error("CPU beam search current not support output cross QK."); + return Status::OK(); +} + +Status FinalizeDecoderCrossQK( + [[maybe_unused]] Stream* stream, + [[maybe_unused]] int iteration_number, + [[maybe_unused]] int context_decoding_len, + [[maybe_unused]] int batch_size, + [[maybe_unused]] int num_beams, + [[maybe_unused]] int max_length, + [[maybe_unused]] int cross_qk_layer_head_pair_count, + [[maybe_unused]] const int* cross_qk_layer_head_pairs, + [[maybe_unused]] int frames_of_k, + [[maybe_unused]] const float* cross_qk_buffer_data, + [[maybe_unused]] float* cross_qk_output, + [[maybe_unused]] int num_return_sequences, + [[maybe_unused]] const int* cache_indir_data, + [[maybe_unused]] gsl::span beam_indices) { + throw std::runtime_error("CPU beam search current not support output cross QK."); + return Status::OK(); +} + } // namespace GenerationCpuDeviceHelper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h index ba1b0b662f1a5..6dfdc6b027671 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h @@ -204,6 +204,35 @@ using ExpandBufferFunc = std::function; + +using UpdateDecoderCrossQKFunc = std::function& qk_layer_pointers, + int num_layers, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + float* cross_qk_buffer_data, + int max_length, + AllocatorPtr allocator)>; + +using FinalizeDecoderCrossQKFunc = std::function beam_indices)>; + } // namespace GenerationDeviceHelper // These are CPU specific device helper implementations @@ -368,6 +397,34 @@ Status ExpandBuffer( bool only_copy_shape, int max_sequence_length); +Status UpdateDecoderCrossQK( + int iteration_number, + Stream* stream, + OrtValue* cross_qks, + 
IAllocatorUniquePtr& qk_layer_pointers, + int num_layers, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + float* cross_qk_buffer_data, + int max_length, + AllocatorPtr allocator); + +Status FinalizeDecoderCrossQK( + Stream* stream, + int iteration_number, + int context_decoding_len, + int batch_size, + int num_beams, + int max_length, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + int frames_of_k, + const float* cross_qk_buffer_data, + float* cross_qk_output, + int num_return_sequences, + const int* cache_indir_data, + gsl::span beam_indices); + } // namespace GenerationCpuDeviceHelper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 719dd302d274d..2846c1db642d1 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -53,6 +53,7 @@ struct IBeamSearchCpuState { gsl::span topk_tokens; // shape (batch_size, 2*num_beams), tokens of topk candidates. gsl::span topk_indices; // shape (batch_size, 2*num_beams), beam indices of topk candidates. gsl::span final_beam_scores; // shape (batch_size, num_beams) + gsl::span next_token_scores; // shape (batch_size, num_beams * vocab_size) }; template @@ -175,6 +176,11 @@ struct IGenerationParameters { int seed = 0; int min_tokens_to_keep = 1; bool custom_sampling = false; + + bool decoder_output_cross_qk = false; + gsl::span extra_decoding_ids; + int32_t no_speech_token = -1; + void* no_speech_probs = nullptr; }; } // namespace transformers diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h index be974ed2159d9..9f372e5b3a673 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h @@ -20,26 +20,27 @@ struct SamplingState : public ISamplingState { int vocab_size, int max_iter, int seed, - bool is_cuda) { + bool is_cuda, + Stream* stream) { int total_count = batch_size * vocab_size; - this->h_softmaxed_score = AllocateBuffer(cpu_allocator, h_softmaxed_score_buffer_, SafeInt(total_count)); + this->h_softmaxed_score = AllocateBuffer(cpu_allocator, h_softmaxed_score_buffer_, SafeInt(total_count), stream); this->generator = std::default_random_engine{gsl::narrow_cast(seed)}; if (is_cuda) { - this->d_index_in = AllocateBuffer(allocator, d_index_in_buffer_, SafeInt(total_count)); - this->d_index_out = AllocateBuffer(allocator, d_index_out_buffer_, SafeInt(total_count)); - this->d_offset = AllocateBuffer(allocator, d_offset_buffer_, SafeInt(batch_size + 1)); - this->d_sorted_score = AllocateBuffer(allocator, d_sorted_score_buffer_, SafeInt(total_count)); - this->d_sorted_softmaxed_score = AllocateBuffer(allocator, d_sorted_softmaxed_score_buffer_, SafeInt(total_count)); - this->d_softmaxed_score = AllocateBuffer(allocator, d_softmaxed_score_buffer_, SafeInt(total_count)); - this->d_sampled = AllocateBuffer(allocator, d_sampled_buffer_, SafeInt(batch_size)); - this->h_sampled_all = AllocateBuffer(cpu_allocator, h_sampled_all_buffer_, SafeInt(batch_size * max_iter)); - this->d_indices = AllocateBuffer(allocator, d_indices_buffer_, SafeInt(batch_size)); + this->d_index_in = AllocateBuffer(allocator, d_index_in_buffer_, SafeInt(total_count), stream); + this->d_index_out = 
AllocateBuffer(allocator, d_index_out_buffer_, SafeInt(total_count), stream); + this->d_offset = AllocateBuffer(allocator, d_offset_buffer_, SafeInt(batch_size + 1), stream); + this->d_sorted_score = AllocateBuffer(allocator, d_sorted_score_buffer_, SafeInt(total_count), stream); + this->d_sorted_softmaxed_score = AllocateBuffer(allocator, d_sorted_softmaxed_score_buffer_, SafeInt(total_count), stream); + this->d_softmaxed_score = AllocateBuffer(allocator, d_softmaxed_score_buffer_, SafeInt(total_count), stream); + this->d_sampled = AllocateBuffer(allocator, d_sampled_buffer_, SafeInt(batch_size), stream); + this->h_sampled_all = AllocateBuffer(cpu_allocator, h_sampled_all_buffer_, SafeInt(batch_size * max_iter), stream); + this->d_indices = AllocateBuffer(allocator, d_indices_buffer_, SafeInt(batch_size), stream); this->temp_storage_bytes = 0; // TODO: Do not allocate this buffer if there's no presence_mask - this->d_presence_mask = AllocateBuffer(allocator, d_presence_mask_buffer_, SafeInt(total_count)); + this->d_presence_mask = AllocateBuffer(allocator, d_presence_mask_buffer_, SafeInt(total_count), stream); std::uniform_real_distribution distribution(0.0, 1.0); static_cast(distribution(this->generator)); @@ -48,25 +49,25 @@ struct SamplingState : public ISamplingState { } } else { // TODO: Some buffer can be reused for CPU - this->sorted_scores = AllocateBuffer(cpu_allocator, sorted_scores_buffer_, SafeInt(total_count)); - this->cumulative_probs = AllocateBuffer(cpu_allocator, cumulative_probs_buffer_, SafeInt(total_count)); + this->sorted_scores = AllocateBuffer(cpu_allocator, sorted_scores_buffer_, SafeInt(total_count), stream); + this->cumulative_probs = AllocateBuffer(cpu_allocator, cumulative_probs_buffer_, SafeInt(total_count), stream); } } private: - BufferUniquePtr d_index_in_buffer_; - BufferUniquePtr d_index_out_buffer_; - BufferUniquePtr d_offset_buffer_; - BufferUniquePtr d_sorted_score_buffer_; - BufferUniquePtr d_sorted_softmaxed_score_buffer_; - BufferUniquePtr d_softmaxed_score_buffer_; - BufferUniquePtr h_softmaxed_score_buffer_; - BufferUniquePtr d_sampled_buffer_; - BufferUniquePtr h_sampled_all_buffer_; - BufferUniquePtr d_indices_buffer_; - BufferUniquePtr d_presence_mask_buffer_; - BufferUniquePtr sorted_scores_buffer_; - BufferUniquePtr cumulative_probs_buffer_; + IAllocatorUniquePtr d_index_in_buffer_; + IAllocatorUniquePtr d_index_out_buffer_; + IAllocatorUniquePtr d_offset_buffer_; + IAllocatorUniquePtr d_sorted_score_buffer_; + IAllocatorUniquePtr d_sorted_softmaxed_score_buffer_; + IAllocatorUniquePtr d_softmaxed_score_buffer_; + IAllocatorUniquePtr h_softmaxed_score_buffer_; + IAllocatorUniquePtr d_sampled_buffer_; + IAllocatorUniquePtr h_sampled_all_buffer_; + IAllocatorUniquePtr d_indices_buffer_; + IAllocatorUniquePtr d_presence_mask_buffer_; + IAllocatorUniquePtr sorted_scores_buffer_; + IAllocatorUniquePtr cumulative_probs_buffer_; }; template @@ -82,24 +83,25 @@ struct GreedySearchState : public IGreedySearchState { int num_heads, int head_size, bool has_decoder_masked_self_attention, - bool is_cuda) { + bool is_cuda, + Stream* stream) { // below buffers are on cpu this->sequences_space = AllocateBuffer(cpu_allocator, sequences_space_buffer_, - SafeInt(2) * batch_size * max_length); + SafeInt(2) * batch_size * max_length, stream); memset(this->sequences_space.data(), 0, this->sequences_space.size_bytes()); this->sequences.Init(this->sequences_space, static_cast(batch_size), sequence_length, max_length); - this->sequence_lengths = 
AllocateBuffer(cpu_allocator, sequence_lengths_buffer_, batch_size); - this->eos_meet = AllocateBuffer(cpu_allocator, eos_meet_buffer_, batch_size); + this->sequence_lengths = AllocateBuffer(cpu_allocator, sequence_lengths_buffer_, batch_size, stream); + this->eos_meet = AllocateBuffer(cpu_allocator, eos_meet_buffer_, batch_size, stream); memset(this->eos_meet.data(), 0, this->eos_meet.size_bytes()); - this->next_tokens = AllocateBuffer(cpu_allocator, next_tokens_buffer_, SafeInt(batch_size)); + this->next_tokens = AllocateBuffer(cpu_allocator, next_tokens_buffer_, SafeInt(batch_size), stream); // below buffers are on cpu or cuda size_t next_token_size = SafeInt(batch_size) * vocab_size; - this->next_token_scores = AllocateBuffer(allocator, next_token_scores_buffer_, next_token_size); - this->next_positions = AllocateBuffer(allocator, next_positions_buffer_, batch_size); + this->next_token_scores = AllocateBuffer(allocator, next_token_scores_buffer_, next_token_size, stream); + this->next_positions = AllocateBuffer(allocator, next_positions_buffer_, batch_size, stream); if (is_cuda) { AllocateTempBufferForGetGreedySearchTopOne( @@ -109,7 +111,8 @@ struct GreedySearchState : public IGreedySearchState { this->temp_topk_scores_buffer, this->temp_topk_tokens_buffer, this->topk_scores_buffer, - this->topk_tokens_buffer); + this->topk_tokens_buffer, + stream); // If at all we need to, we only need to re-order past state for CUDA as //`DecoderMaskedSelfAttention` is only supported on CUDA @@ -137,14 +140,14 @@ struct GreedySearchState : public IGreedySearchState { } private: - BufferUniquePtr sequences_space_buffer_; - BufferUniquePtr sequence_lengths_buffer_; - BufferUniquePtr next_token_scores_buffer_; - BufferUniquePtr next_tokens_buffer_; - BufferUniquePtr next_positions_buffer_; - BufferUniquePtr eos_meet_buffer_; - BufferUniquePtr temp_topk_buffer_; - BufferUniquePtr staging_for_past_state_reorder_buffer_; + IAllocatorUniquePtr sequences_space_buffer_; + IAllocatorUniquePtr sequence_lengths_buffer_; + IAllocatorUniquePtr next_token_scores_buffer_; + IAllocatorUniquePtr next_tokens_buffer_; + IAllocatorUniquePtr next_positions_buffer_; + IAllocatorUniquePtr eos_meet_buffer_; + IAllocatorUniquePtr temp_topk_buffer_; + IAllocatorUniquePtr staging_for_past_state_reorder_buffer_; }; // Base class of gready search implementation that is common for both GPT-2 and Bart/T5. 
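AllocateTempBufferForGetGreedySearchTopOne, called above when running on CUDA (and reworked earlier in this patch to return an IAllocatorUniquePtr and accept the stream), hands out four views over a single allocation: stage-1 partial scores and tokens (one slot per vocab part per batch) followed by the per-batch winning score and token. Only the start of that carve-up appears in this excerpt; a sketch of the layout implied by the allocation size (float block first so the int32 block stays aligned) could look like the following, though the actual helper may differ:

```cpp
#include <cstdint>
#include <gsl/gsl>

// Sketch: partition one raw block into the four spans used by GetGreedySearchTopOne.
// Sized as (stage1_elems + output_elems) * (sizeof(float) + sizeof(int32_t)), per the diff.
void PartitionTopOneBuffer(void* topk_data, size_t stage1_elems, size_t output_elems,
                           gsl::span<float>& stage1_scores, gsl::span<int32_t>& stage1_tokens,
                           gsl::span<float>& output_scores, gsl::span<int32_t>& output_tokens) {
  float* stage1_scores_ptr = reinterpret_cast<float*>(topk_data);
  stage1_scores = gsl::make_span(stage1_scores_ptr, stage1_elems);

  int32_t* stage1_tokens_ptr = reinterpret_cast<int32_t*>(stage1_scores_ptr + stage1_elems);
  stage1_tokens = gsl::make_span(stage1_tokens_ptr, stage1_elems);

  float* output_scores_ptr = reinterpret_cast<float*>(stage1_tokens_ptr + stage1_elems);
  output_scores = gsl::make_span(output_scores_ptr, output_elems);

  int32_t* output_tokens_ptr = reinterpret_cast<int32_t*>(output_scores_ptr + output_elems);
  output_tokens = gsl::make_span(output_tokens_ptr, output_elems);
}
```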
diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h index 4504b099e32bd..69d25eaabbe02 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h @@ -211,7 +211,8 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager* init_ static_cast(parameters->num_heads), static_cast(parameters->head_size), gpt_subgraph_.has_decoder_masked_attention_, - this->IsCuda()); + this->IsCuda(), + this->ort_stream_); SamplingState sampling_state; if (std::is_same::value) { @@ -221,7 +222,8 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager* init_ static_cast(parameters->vocab_size), static_cast(parameters->max_length - parameters->sequence_length), parameters->seed, - this->IsCuda()); + this->IsCuda(), + this->ort_stream_); } IAllocatorUniquePtr buffer; diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index 9f77c32f0c7cc..f39f090c78b0c 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -17,20 +17,6 @@ namespace onnxruntime { namespace contrib { namespace transformers { -template -gsl::span NextTokenScores::GetScores(int batch_beam_index) { - assert(batch_beam_index >= 0 && batch_beam_index < batch_beam_size); - return scores.subspan(static_cast(batch_beam_index) * vocab_size, vocab_size); -} - -template -void NextTokenScores::SetScore(int token_id, T score) { - assert(token_id >= 0 && token_id < vocab_size); - for (int i = 0; i < batch_beam_size; i++) { - scores[static_cast(i) * vocab_size + token_id] = score; - } -} - #ifdef DEBUG_GENERATION template void DumpScores(const char* name, const NextTokenScores& next_token_scores) { @@ -238,128 +224,6 @@ void PresencePenaltyLogitsProcessor::Process(const ISequences*, #endif } -template -TimestampLogitsProcessor::TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index) - : eos_token_id_(eos_token_id), max_initial_timestamp_index_(max_initial_timestamp_index) {} - -template -void TimestampLogitsProcessor::Process(const ISequences* sequences, - NextTokenScores& next_token_scores) { - const int beg_token_id_ = eos_token_id_ + 107; - const int not_token_id_ = eos_token_id_ + 106; - const int solm_token_id_ = eos_token_id_ + 105; - const int sot_token_id_ = eos_token_id_ + 1; - constexpr int translate_token_id_ = 50358; - constexpr int transcribe_token_id_ = 50359; - - const int batch_beam_size = next_token_scores.batch_beam_size; - const int vocab_size = next_token_scores.vocab_size; - for (int i = 0; i < batch_beam_size; i++) { - gsl::span beam_token_scores = next_token_scores.GetScores(i); - gsl::span sequence = sequences->GetSequence(i); - const size_t seq_length = sequence.size(); - - // Find first timestamp - size_t sample_begin = 0; - for (size_t j = 0; j < seq_length; j++) { - sample_begin++; - if (sequence[j] >= beg_token_id_) { - break; - } - } - - // Suppress tokens - for (int j = 0; j < vocab_size; j++) { - // Suppress notimestamps and solm tokens - if (j == not_token_id_ || j == solm_token_id_) { - beam_token_scores[j] = std::numeric_limits::lowest(); - } - - // Suppress sot, translate and transcribe tokens - if (seq_length > sample_begin) { - if (j == sot_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) { - beam_token_scores[j] = 
std::numeric_limits::lowest(); - } - } - } - - // Timestamps should be in pair except the first one - const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beg_token_id_; - const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beg_token_id_; - if (last_was_timestamp) { - if (penultimate_was_timestamp) { - // If timestamps show up in pair, or it's the first timestamp, no more timestamp is generated - for (int j = beg_token_id_; j < vocab_size; j++) { - beam_token_scores[j] = std::numeric_limits::lowest(); - } - } else { - // If timestamp doesn't show up in pair, generate timestamp - for (int j = 0; j < eos_token_id_; j++) { - beam_token_scores[j] = std::numeric_limits::lowest(); - } - } - } - - // Find timestamp tokens - std::vector timestamps; - for (const auto& word_id : sequence) { - if (word_id >= beg_token_id_) { - timestamps.push_back(word_id); - } - } - - // Timestamps will not decrease - const size_t timestamps_len = timestamps.size(); - if (timestamps_len > 0) { - int timestamp_last = 0; - if (last_was_timestamp && !penultimate_was_timestamp) { - // For single timestamp at the end, next timestamp must not be smaller - timestamp_last = timestamps.back(); - } else { - // For paired timestamp at the end, next timestamp must be greater - timestamp_last = timestamps.back() + 1; - } - - for (int j = beg_token_id_; j < timestamp_last; j++) { - beam_token_scores[j] = std::numeric_limits::lowest(); - } - } - - if (seq_length == sample_begin) { - const int last_allowed = beg_token_id_ + max_initial_timestamp_index_; - for (int j = last_allowed + 1; j < vocab_size; j++) { - beam_token_scores[j] = std::numeric_limits::lowest(); - } - } - - // Caculate logsumexp on timestamps - float timestamp_logprob = std::numeric_limits::lowest(); - { - float logsumexp = 0.0f; - const float logprob_max = *std::max_element(beam_token_scores.begin() + beg_token_id_, beam_token_scores.end()); - for (int j = beg_token_id_; j < vocab_size; ++j) { - if (beam_token_scores[j] > std::numeric_limits::lowest()) { - logsumexp += expf(beam_token_scores[j] - logprob_max); - } - } - if (logsumexp > 0.0f) { - timestamp_logprob = logf(logsumexp) + logprob_max; - } - } - - const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beg_token_id_); - if (timestamp_logprob > max_text_token_logprob) { - for (int j = 0; j < beg_token_id_; ++j) { - beam_token_scores[j] = std::numeric_limits::lowest(); - } - } - } - -#ifdef DEBUG_GENERATION - DumpScores("TimestampLogitsProcessor", next_token_scores); -#endif -} - void LogitsProcessorList::Init(const BeamSearchParameters& parameters) { LogitsProcessorInitImpl(parameters); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 664c497a106d4..c0ef83675399e 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -6,6 +6,7 @@ #include "core/common/inlined_containers.h" #include "contrib_ops/cpu/transformers/sequences.h" #include "contrib_ops/cpu/transformers/beam_search_parameters.h" +#include "contrib_ops/cpu/transformers/dump_tensor.h" #include "contrib_ops/cpu/transformers/greedy_search_parameters.h" #include "contrib_ops/cpu/transformers/sampling_parameters.h" #include "contrib_ops/cpu/transformers/generation_shared.h" @@ -20,9 +21,17 @@ struct NextTokenScores { int batch_beam_size; int vocab_size; - 
gsl::span GetScores(int batch_beam_index); + gsl::span GetScores(int batch_beam_index) { + assert(batch_beam_index >= 0 && batch_beam_index < batch_beam_size); + return scores.subspan(static_cast(batch_beam_index) * vocab_size, vocab_size); + } - void SetScore(int token_id, T score); + void SetScore(int token_id, T score) { + assert(token_id >= 0 && token_id < vocab_size); + for (int i = 0; i < batch_beam_size; i++) { + scores[static_cast(i) * vocab_size + token_id] = score; + } + } }; // Interface for all scorers for beam search or beam sample. @@ -141,10 +150,125 @@ class PresencePenaltyLogitsProcessor : public ILogitsProcessor { template class TimestampLogitsProcessor : public ILogitsProcessor { public: - TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index); + TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index) + : eos_token_id_(eos_token_id), max_initial_timestamp_index_(max_initial_timestamp_index) {} void Process(const ISequences* sequences, - NextTokenScores& next_token_scores) override; + NextTokenScores& next_token_scores) override { + const int beg_token_id_ = eos_token_id_ + 107; + const int not_token_id_ = eos_token_id_ + 106; + const int solm_token_id_ = eos_token_id_ + 105; + const int sot_token_id_ = eos_token_id_ + 1; + constexpr int translate_token_id_ = 50358; + constexpr int transcribe_token_id_ = 50359; + + const int batch_beam_size = next_token_scores.batch_beam_size; + const int vocab_size = next_token_scores.vocab_size; + for (int i = 0; i < batch_beam_size; i++) { + gsl::span beam_token_scores = next_token_scores.GetScores(i); + gsl::span sequence = sequences->GetSequence(i); + const size_t seq_length = sequence.size(); + + // Find first timestamp + size_t sample_begin = 0; + for (size_t j = 0; j < seq_length; j++) { + sample_begin++; + if (sequence[j] >= beg_token_id_) { + break; + } + } + + // Suppress tokens + for (int j = 0; j < vocab_size; j++) { + // Suppress notimestamps and solm tokens + if (j == not_token_id_ || j == solm_token_id_) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + + // Suppress sot, translate and transcribe tokens + if (seq_length > sample_begin) { + if (j == sot_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + } + } + + // Timestamps should be in pair except the first one + const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beg_token_id_; + const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beg_token_id_; + if (last_was_timestamp) { + if (penultimate_was_timestamp) { + // If timestamps show up in pair, or it's the first timestamp, no more timestamp is generated + for (int j = beg_token_id_; j < vocab_size; j++) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + } else { + // If timestamp doesn't show up in pair, generate timestamp + for (int j = 0; j < eos_token_id_; j++) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + } + } + + // Find timestamp tokens + std::vector timestamps; + for (const auto& word_id : sequence) { + if (word_id >= beg_token_id_) { + timestamps.push_back(word_id); + } + } + + // Timestamps will not decrease + const size_t timestamps_len = timestamps.size(); + if (timestamps_len > 0) { + int timestamp_last = 0; + if (last_was_timestamp && !penultimate_was_timestamp) { + // For single timestamp at the end, next timestamp must not be smaller + timestamp_last = timestamps.back(); + } else 
{ + // For paired timestamp at the end, next timestamp must be greater + timestamp_last = timestamps.back() + 1; + } + + for (int j = beg_token_id_; j < timestamp_last; j++) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + } + + if (seq_length == sample_begin) { + const int last_allowed = beg_token_id_ + max_initial_timestamp_index_; + for (int j = last_allowed + 1; j < vocab_size; j++) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + } + + // Caculate logsumexp on timestamps + float timestamp_logprob = std::numeric_limits::lowest(); + { + float logsumexp = 0.0f; + const float logprob_max = *std::max_element(beam_token_scores.begin() + beg_token_id_, beam_token_scores.end()); + for (int j = beg_token_id_; j < vocab_size; ++j) { + if (beam_token_scores[j] > std::numeric_limits::lowest()) { + logsumexp += expf(beam_token_scores[j] - logprob_max); + } + } + if (logsumexp > 0.0f) { + timestamp_logprob = logf(logsumexp) + logprob_max; + } + } + + const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beg_token_id_); + if (timestamp_logprob > max_text_token_logprob) { + for (int j = 0; j < beg_token_id_; ++j) { + beam_token_scores[j] = std::numeric_limits::lowest(); + } + } + } + +#ifdef DEBUG_GENERATION + DumpScores("TimestampLogitsProcessor", next_token_scores); +#endif + } private: int eos_token_id_; diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h index 3c11d2d324a85..487a35c55a85f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h @@ -45,6 +45,7 @@ class Subgraph { int num_layers; bool past_present_share_buffer_; bool has_decoder_masked_attention_; + bool output_cross_qk_ = false; // Setup execution Status Setup(const SessionState& session_state, diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc index 28acd81ae95fd..4d61ce71c69be 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc @@ -172,7 +172,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds( int32_t* input_ids_data = input_ids.GetMutable()->MutableData(); AllocatorPtr buffer_allocator = std::make_shared(); size_t total_size = static_cast(static_cast(cur_len) * batch_beam_size * sizeof(int)); - auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size); + auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size, false, stream); int* seq_copy_ptr = seq_copy.get(); if (!use_sequence_as_input_ids_) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h index 085d8f3903976..83dae49c7dcbd 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h @@ -5,6 +5,7 @@ #include "contrib_ops/cpu/transformers/subgraph_base.h" #include "contrib_ops/cpu/transformers/sequences.h" +#include "core/framework/op_kernel.h" namespace onnxruntime { namespace contrib { @@ -20,6 +21,13 @@ class T5DecoderSubgraph : public Subgraph { has_hidden_state_(false), use_sequence_as_input_ids_(true) { first_present_output_index_ = 1; + + // Currently just using parent node's attribute. Maybe better to find it purely in subgraph. 
+ const auto& attributes = node_in.GetAttributes(); + if (attributes.find("decoder_output_cross_qk") != attributes.end()) { + auto& attr = attributes.at("decoder_output_cross_qk"); + output_cross_qk_ = (attr.i() != 0LL); + } } // Create inputs for first inference of decoder subgraph. @@ -62,7 +70,7 @@ class T5DecoderSubgraph : public Subgraph { return first_present_output_index_; } - bool UseSequenceAsInputIds() const { + inline bool UseSequenceAsInputIds() const { return use_sequence_as_input_ids_; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc index 887a6a8984b83..7d0c62b618ee2 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_whisper_decoder.cc @@ -70,8 +70,15 @@ Status WhisperDecoderSubgraph::Validate(const std::vector& subgr "number of inputs expected to be kFirstPastInputIndex + 4 * layers + 1, got:", num_subgraph_inputs); } - ORT_RETURN_IF(num_subgraph_outputs < 3 || (num_subgraph_outputs - first_present_output_index_) % 2 != 0, - "number of outputs expected to be 1 + 2 * layers, got:", num_subgraph_outputs); + if (!output_cross_qk_) { + ORT_RETURN_IF(num_subgraph_outputs < 3 || (num_subgraph_outputs - first_present_output_index_) % 2 != 0, + "number of outputs expected to be first_present_output_index_", + first_present_output_index_, " + 2 * layers, got:", num_subgraph_outputs); + } else { + ORT_RETURN_IF(num_subgraph_outputs < 4 || (num_subgraph_outputs - first_present_output_index_) % 3 != 0, + "When outputing cross qk, number of outputs expected to be first_present_output_index_", + first_present_output_index_, " + 3 * layers, got:", num_subgraph_outputs); + } ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids", "decoder subgraph input 0 shall be named as input_ids, got: ", subgraph_inputs[0]->Name()); @@ -90,7 +97,8 @@ Status WhisperDecoderSubgraph::Validate(const std::vector& subgr // Save parameters related to the subgraph. ORT_RETURN_IF_ERROR(GetParameters(past_shape, logits_shape, false)); - num_layers = (static_cast(subgraph_outputs.size()) - first_present_output_index_) / 2; + + num_layers = (static_cast(subgraph_outputs.size()) - first_present_output_index_) / (output_cross_qk_ ? 3 : 2); // If input_ids's shape is ['batch_size', 1] then use next token as input_ids. // Otherwise in the case of shape ['batch_size', 'sequence'], use sequence as input_ids. 
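Note on the output layout assumed above: without cross QK the decoder subgraph emits logits followed by {present_key, present_value} per layer, and with decoder_output_cross_qk set it emits {present_key, present_value, cross_qk} per layer, which is why the layer count divides the remaining outputs by 2 or 3. A minimal standalone sketch of that bookkeeping (CountDecoderLayers is an illustrative name, not part of this patch):

#include <cassert>
#include <cstdio>

// Illustrative helper mirroring the layer-count computation in the subgraph validation above.
static int CountDecoderLayers(int num_subgraph_outputs,
                              int first_present_output_index,
                              bool output_cross_qk) {
  // Per layer the subgraph emits {present_key, present_value} or,
  // when cross QK is requested, {present_key, present_value, cross_qk}.
  const int per_layer = output_cross_qk ? 3 : 2;
  const int tail = num_subgraph_outputs - first_present_output_index;
  assert(tail > 0 && tail % per_layer == 0);
  return tail / per_layer;
}

int main() {
  // e.g. a 4-layer whisper decoder: logits + 4 * 3 outputs = 13 when cross QK is emitted
  std::printf("%d\n", CountDecoderLayers(13, 1, true));   // prints 4
  std::printf("%d\n", CountDecoderLayers(9, 1, false));   // prints 4
  return 0;
}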
@@ -112,12 +120,7 @@ Status WhisperDecoderSubgraph::Validate(const std::vector& subgr for (int i = first_past_input_index_; i < first_past_input_index_ + 4 * num_layers; i++) { ORT_RETURN_IF(subgraph_inputs[i]->TypeAsProto()->tensor_type().elem_type() != float_type, - "decoder subgraph past inputs shall have same data type as that of encoder_hidden_states"); - } - - for (int i = 0; i < num_subgraph_outputs; i++) { - ORT_RETURN_IF(subgraph_outputs[i]->TypeAsProto()->tensor_type().elem_type() != float_type, - "decoder subgraph output shall have same data type as that of encoder_hidden_states"); + "decoder subgraph past inputs shall have same data type as that of encoder_hidden_states."); } is_output_float16_ = (subgraph_outputs[0]->TypeAsProto()->tensor_type().elem_type() == float16_type); @@ -166,7 +169,7 @@ Status WhisperDecoderSubgraph::CreateInitialFeeds( AllocatorPtr buffer_allocator = std::make_shared(); size_t total_size = static_cast(static_cast(cur_len) * batch_beam_size * sizeof(int)); - auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size); + auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size, false, stream); int* seq_copy_ptr = seq_copy.get(); if (!use_sequence_as_input_ids_) { diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index f0385ea5abdfb..5e33e1307fcc4 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -155,8 +155,10 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { if (use_causal_fused_runner) { // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. if (nullptr == fused_fp16_runner_.get()) { - fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, - enable_trt_flash_attention_, parameters.scale); + std::call_once(fused_fp16_runner_created_, [&]() { + fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, + enable_trt_flash_attention_, parameters.scale); + }); } // Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check. @@ -175,8 +177,10 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { if (use_fused_runner) { // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. if (nullptr == fused_fp16_runner_.get()) { - fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, - enable_trt_flash_attention_, parameters.scale); + std::call_once(fused_fp16_runner_created_, [&]() { + fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, + enable_trt_flash_attention_, parameters.scale); + }); } // In case some kernel not loaded due to shared memory limit, we need to double check here. 
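The std::call_once changes above guard the lazy creation of fused_fp16_runner_ when several sessions or streams run the same Attention node in parallel. A minimal sketch of the pattern, with FusedRunner standing in for FusedMHARunnerFP16v2 and the real attention parameters omitted:

#include <memory>
#include <mutex>

// FusedRunner stands in for FusedMHARunnerFP16v2; the real factory takes the node's attention parameters.
struct FusedRunner {};

class AttentionNode {
 public:
  FusedRunner* GetOrCreateRunner() const {
    // std::call_once runs the factory exactly once even when several threads
    // observe a null runner_ concurrently.
    std::call_once(runner_created_, [&]() { runner_ = std::make_unique<FusedRunner>(); });
    return runner_.get();
  }

 private:
  mutable std::unique_ptr<FusedRunner> runner_;
  mutable std::once_flag runner_created_;
};

int main() {
  AttentionNode node;
  return node.GetOrCreateRunner() != nullptr ? 0 : 1;
}

The same once-flag pattern is applied to MultiHeadAttention and the packed attention kernels in this change.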
@@ -213,11 +217,12 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; - IAllocatorUniquePtr gemm_buffer; + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); int m = batch_size * sequence_length; int n = (parameters.hidden_size + parameters.hidden_size + parameters.v_hidden_size); int k = parameters.input_hidden_size; - gemm_buffer = GetScratchBuffer(static_cast(m) * n, context->GetComputeStream()); + IAllocatorUniquePtr gemm_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(m * n) * sizeof(T), false, context->GetComputeStream()); CudaT one = ToCudaType::FromFloat(1.0f); CudaT zero = ToCudaType::FromFloat(0.0f); @@ -244,7 +249,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { use_flash_attention, use_fused_cross_attention, use_memory_efficient_attention); - auto work_space = GetScratchBuffer(workSpaceSize, context->GetComputeStream()); + IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, workSpaceSize, false, context->GetComputeStream()); + ; typedef typename ToCudaType::MappedType CudaT; AttentionData data; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.h b/onnxruntime/contrib_ops/cuda/bert/attention.h index 455e55ba05a66..acafb379d713f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include "core/providers/cuda/cuda_kernel.h" #include "contrib_ops/cpu/bert/attention_base.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" @@ -28,6 +29,7 @@ class Attention final : public CudaKernel, public AttentionBase { bool disable_memory_efficient_attention_; int min_seq_len_for_flash_attention_packed_qkv_; mutable std::unique_ptr fused_fp16_runner_; + mutable std::once_flag fused_fp16_runner_created_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc index 4bdc6db30b036..54aad9cbaf387 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc @@ -22,6 +22,7 @@ static constexpr int kBeamWidthInputIndex = 8; static constexpr int kCacheIndirectionInputIndex = 9; static constexpr int kPastInputIndex = 5; static constexpr int kPresentOutputIndex = 1; +static constexpr int kQKOutputIndex = 3; static constexpr int kBiasIndex = 10; #define REGISTER_KERNEL_TYPED(T1, T2) \ @@ -50,6 +51,7 @@ DecoderMaskedMultiHeadAttention::DecoderMaskedMultiHeadAttention(const O mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); scale_ = info.GetAttrOrDefault("scale", 0.0f); past_present_share_buffer_ = info.GetAttrOrDefault("past_present_share_buffer", 0LL); + output_qk_ = info.GetAttrOrDefault("output_qk", 0LL); } template @@ -98,7 +100,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* // This kernel is for decoding only (i.e.) sequence length has to be 1 if (sequence_length != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input sequence length should be 1 to use DecoderMaskedMultiHeadAttention"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input sequence length should be 1 to use DecoderMaskedMultiHeadAttention. 
Actual length is ", sequence_length); } if (parameters.head_size != parameters.v_head_size) { @@ -125,6 +127,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* TensorShape present_shape(present_dims); Tensor* present_key = context->Output(kPresentOutputIndex, present_shape); Tensor* present_value = context->Output(kPresentOutputIndex + 1, present_shape); + Tensor* cross_qk = nullptr; auto cuda_stream = Stream(context); @@ -191,6 +194,13 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* parameters.v_cache = present_value_data; } + if (output_qk_) { + int64_t qk_dims[] = {parameters.batch_size, parameters.num_heads, 1, parameters.total_sequence_length}; + TensorShape qk_shape(&qk_dims[0], sizeof(qk_dims) / sizeof(qk_dims[0])); + cross_qk = context->Output(kQKOutputIndex, qk_shape); + parameters.out_qk = cross_qk->MutableData(); + } + parameters.out = output->MutableDataRaw(); // Scale diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.h index 8200a66db383f..b5476e6b54c44 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.h @@ -22,6 +22,7 @@ class DecoderMaskedMultiHeadAttention final : public CudaKernel { float mask_filter_value_; float scale_; bool past_present_share_buffer_; + bool output_qk_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index 5827bdfee1ab5..c9a06e2c6798b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -430,6 +430,15 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio // Compute the logits and start the sum. float sum = 0.f; int sum_tlength = params.is_cross_attention ? 
tlength - 1 : tlength; + + if (params.out_qk != nullptr) { + // store cross qk before softmax, out_qk has shape [B(batchxbeam), #Head, 1, total_sequence_length] + float* target = ((float*)params.out_qk) + ((int64_t)bhi * tlength); + for (int ti = tidx; ti <= sum_tlength; ti += THREADS_PER_BLOCK) { + target[ti] = (float)(qk_smem[ti]); + } + } + for (int ti = tidx; ti <= sum_tlength; ti += THREADS_PER_BLOCK) { // This is a deviation from FasterTransformer kernel implementation // but this aligns with ORT's other Attention kernels which strives to diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h index 6d7f368db4dd4..4b408dafa2d81 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h @@ -37,6 +37,7 @@ struct DecoderMaskedMultiHeadAttentionParams : AttentionParameters { void* v_cache = nullptr; void* out = nullptr; + void* out_qk = nullptr; const int32_t* cache_indir = nullptr; const int32_t* mask = nullptr; // [B, total_sequence_length] diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 25f3f59165e43..e3f53ca6a63cb 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -194,8 +194,10 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { // Here we assume that num_heads and head_size does not change for a MultiHeadAttention node. if (nullptr == fused_fp16_runner_.get()) { constexpr bool is_unidirectional = false; - fused_fp16_runner_ = FusedMHARunnerFP16v2::Create( - num_heads_, parameters.head_size, sm, is_unidirectional, enable_trt_flash_attention_, parameters.scale); + std::call_once(fused_fp16_runner_created_, [&]() { + fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional, + enable_trt_flash_attention_, parameters.scale); + }); } // In case some kernel not loaded due to shared memory limit, we need to double check here. 
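For reference, the qk tensor captured above holds the scaled attention logits of the single decoded token before softmax, laid out as [batch*beam, num_heads, 1, total_sequence_length]. A small host-side sketch of how a consumer could index that buffer (QkAt is an illustrative helper, not part of this patch):

#include <cstddef>
#include <vector>

// Index into a qk buffer laid out as [batch*beam, num_heads, 1, total_sequence_length].
static float QkAt(const std::vector<float>& qk, int num_heads,
                  int total_sequence_length, int batch_beam, int head, int pos) {
  const size_t offset =
      (static_cast<size_t>(batch_beam) * num_heads + head) *
          static_cast<size_t>(total_sequence_length) +
      pos;
  return qk[offset];
}

int main() {
  const int batch_beam = 2, num_heads = 6, total_len = 8;
  std::vector<float> qk(static_cast<size_t>(batch_beam) * num_heads * total_len, 0.0f);
  return QkAt(qk, num_heads, total_len, 1, 3, 7) == 0.0f ? 0 : 1;
}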
diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index 33fa3d50e4564..c162f7133cc1c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -32,6 +32,7 @@ class MultiHeadAttention final : public CudaKernel { bool disable_memory_efficient_attention_; int min_seq_len_for_flash_attention_packed_qkv_; mutable std::unique_ptr fused_fp16_runner_; + mutable std::once_flag fused_fp16_runner_created_; mutable const FusedMultiHeadCrossAttentionKernel* fused_fp16_cross_attention_kernel_; mutable CumulatedSequenceLengthCache cumulated_sequence_length_q_cache_; mutable CumulatedSequenceLengthCache cumulated_sequence_length_kv_cache_; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.h b/onnxruntime/contrib_ops/cuda/bert/packed_attention.h index 0cdd8021de4a1..f00c112fc73d2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.h @@ -24,10 +24,11 @@ class TrtFusedAttention { protected: MHARunner* GetFusedRunner(const cudaDeviceProp& device_prop, const PackedAttentionParameters& parameters) const; - private: + protected: bool disable_fused_runner_; bool enable_trt_flash_attention_; mutable std::unique_ptr fused_fp16_runner_; + mutable std::once_flag fused_fp16_runner_created_; }; template diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 86c1cb93e8b6f..58102d5eb496c 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -113,6 +113,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Trilu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, UnfoldTensor); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DynamicTimeWarping); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, DequantizeLinear); @@ -270,6 +272,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, // TransposedMatMul is still here for backward compatibility diff --git a/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.cc b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.cc new file mode 100644 index 0000000000000..381316f605fc9 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.cc @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
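The DynamicTimeWarping kernel added below fills a (rows+1) x (cols+1) cost table along anti-diagonals and then backtracks from the far corner to recover the cheapest path. A single-threaded reference of the same recurrence, useful for sanity-checking the CUDA path (a sketch only; DtwReference is an illustrative name):

#include <algorithm>
#include <cstdint>
#include <limits>
#include <utility>
#include <vector>

// Reference DTW:
//   cost[r][c] = input[r-1][c-1] + min(cost[r-1][c-1], cost[r-1][c], cost[r][c-1])
// followed by a backtrace from (rows, cols) to (1, 1).
static std::vector<std::pair<int, int>> DtwReference(const std::vector<float>& input,
                                                     int rows, int cols) {
  std::vector<float> cost((rows + 1) * (cols + 1), std::numeric_limits<float>::max());
  std::vector<int8_t> trace((rows + 1) * (cols + 1), 0);
  auto at = [cols](int r, int c) { return r * (cols + 1) + c; };
  cost[at(0, 0)] = 0.0f;

  for (int r = 1; r <= rows; ++r) {
    for (int c = 1; c <= cols; ++c) {
      const float c0 = cost[at(r - 1, c - 1)];  // diagonal move
      const float c1 = cost[at(r - 1, c)];      // move down
      const float c2 = cost[at(r, c - 1)];      // move right
      float best = c0;
      int8_t t = 0;
      if (c1 < best) { best = c1; t = 1; }
      if (c2 < best) { best = c2; t = 2; }
      cost[at(r, c)] = best + input[(r - 1) * cols + (c - 1)];
      trace[at(r, c)] = t;
    }
  }

  std::vector<std::pair<int, int>> path;
  for (int r = rows - 1, c = cols - 1; r >= 0 && c >= 0;) {
    path.emplace_back(r, c);
    switch (trace[at(r + 1, c + 1)]) {
      case 0: --r; --c; break;
      case 1: --r; break;
      default: --c; break;
    }
  }
  std::reverse(path.begin(), path.end());
  return path;  // same (row, col) pairs the kernel writes to its two output rows
}

int main() {
  const std::vector<float> cost = {0.f, 1.f, 1.f, 0.f};  // 2 x 2 cost matrix
  return DtwReference(cost, 2, 2).size() >= 2 ? 0 : 1;
}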
+ +#include "contrib_ops/cuda/tensor/dynamic_time_warping.h" +#include "contrib_ops/cuda/tensor/dynamic_time_warping_impl.h" +#include "core/providers/cpu/tensor/utils.h" + +#include +#include + +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +ONNX_OPERATOR_KERNEL_EX( + DynamicTimeWarping, + kMSDomain, + 1, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("F", DataTypeImpl::GetTensorType()) + .TypeConstraint("I", DataTypeImpl::GetTensorType()), + DynamicTimeWarping); + +Status DynamicTimeWarping::ComputeInternal(OpKernelContext* ctx) const { + const Tensor& input_tensor = *ctx->Input(0); + const auto& input_dims = input_tensor.Shape().GetDims(); + int rank = SafeInt(input_dims.size()); + ORT_ENFORCE(rank == 2 || (rank == 3 && input_dims[0] == 1), "Currently input rank must be 2, or (3 with first dim equal to 1), but got:", rank); + + const size_t rows = SafeInt(input_dims[rank == 3 ? 1 : 0]); + const size_t cols = SafeInt(input_dims[rank == 3 ? 2 : 1]); + size_t max_index_len = 0; + + size_t buffer_size_in_bytes = GetDynamicTimeWarpingBufferSize(1, rows, cols, max_index_len); + IAllocatorUniquePtr buffer = GetScratchBuffer(buffer_size_in_bytes, ctx->GetComputeStream()); + + size_t result_len = 0; + ORT_RETURN_IF_ERROR(LaunchDynamicTimeWarping( + this->Stream(ctx), this->GetDeviceProp(), 1, rows, cols, + input_tensor.Data(), buffer.get(), result_len)); + + Tensor* output_tensor = ctx->Output(0, TensorShape{2LL, SafeInt(result_len)}); + + return CUDA_CALL(cudaMemcpy2DAsync( + output_tensor->MutableData(), result_len * sizeof(int32_t), + buffer.get() + ((max_index_len - result_len) * sizeof(int32_t)), max_index_len * sizeof(int32_t), + result_len * sizeof(int32_t), 2, + cudaMemcpyDeviceToDevice, this->Stream(ctx))); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.h b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.h new file mode 100644 index 0000000000000..3083e19aff6f2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/cuda_kernel.h" +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using onnxruntime::OpKernelContext; +using onnxruntime::OpKernelInfo; +using onnxruntime::cuda::CudaKernel; +class DynamicTimeWarping final : public CudaKernel { + public: + DynamicTimeWarping(const OpKernelInfo& info) : CudaKernel(info) {} + + ~DynamicTimeWarping() = default; + + Status ComputeInternal(OpKernelContext* context) const override; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.cu b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.cu new file mode 100644 index 0000000000000..2072f001bc3a1 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.cu @@ -0,0 +1,141 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "contrib_ops/cuda/tensor/dynamic_time_warping_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/common/common.h" +#include +#include + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +__global__ void DynamicTimeWarpingInitCost(float* cost_buffer, int8_t* trace_buffer, size_t cols_plus_1) { + int r = blockIdx.x; + cost_buffer += cols_plus_1 * r; + for (size_t i = threadIdx.x; i < cols_plus_1; i += blockDim.x) { + cost_buffer[i] = FLT_MAX; + } + if (r == 0) { + for (size_t i = threadIdx.x; i < cols_plus_1; i += blockDim.x) { + trace_buffer[i] = 2; + } + } + if (threadIdx.x == 0) trace_buffer[cols_plus_1 * r] = 1; + if (threadIdx.x == 0 && r == 0) *cost_buffer = 0.0f; +} + +__global__ void DynamicTimeWarpingKernel( + size_t rows, + size_t cols, + size_t max_index_len, + const float* input, + float* cost_buffer, + int8_t* trace_buffer, + int32_t* result_buffer, + size_t* result_len_device +) { + const int diag_max = static_cast(rows + cols); + for (int d = 1; d <= diag_max; d++) { + for (int c = threadIdx.x + 1; c <= cols; c += blockDim.x) { + int r = d - c; + if (r >= 1 && r <= rows) { + int cost_idx = ((r - 1) * (cols + 1) + (c - 1)); //[r - 1, c - 1] + const float c0 = cost_buffer[cost_idx]; + const float c1 = cost_buffer[cost_idx + 1]; // [r - 1, c] + const float c2 = cost_buffer[cost_idx + cols + 1]; // [r, c - 1] + + float cost; + int8_t t; + if (c0 < c1 && c0 < c2) { + cost = c0; + t = 0; + } else if (c1 < c0 && c1 < c2) { + cost = c1; + t = 1; + } else { + cost = c2; + t = 2; + } + cost_idx += ((cols + 1) + 1); + cost_buffer[cost_idx] = cost + input[(r - 1) * cols + (c - 1)]; + trace_buffer[cost_idx] = t; + } + } + __syncthreads(); + } + + //back tracing, reverse append to result buffer + if (threadIdx.x == 0) { + int r = rows - 1; + int c = cols - 1; + int pos = static_cast(max_index_len); // reverse put + while (r >= 0 && c >= 0) { + --pos; + result_buffer[pos] = r; + result_buffer[max_index_len + pos] = c; + const int trace_index = (r + 1) * (cols + 1) + (c + 1); + int8_t t = trace_buffer[trace_index]; + switch (t) { + case 0: r -= 1; c -= 1; break; + case 1: r -= 1; break; + default: c -= 1; break; + } + } + *result_len_device = max_index_len - static_cast(pos); + } +} + +size_t GetDynamicTimeWarpingBufferSize(size_t batch, size_t rows, size_t cols, size_t& max_index_len) { + max_index_len = rows + cols + 1; + size_t cost_buffer_size = ((rows + 1) * (cols + 1)); + return batch * max_index_len * 2 * sizeof(int32_t) + // two index arrays + sizeof(int64_t) + // final index array length + batch* cost_buffer_size * sizeof(float) + // cost buffer + batch* cost_buffer_size * sizeof(int8_t); // trace buffer +} + +Status LaunchDynamicTimeWarping( + cudaStream_t stream, + const cudaDeviceProp& device_prop, + size_t batch, + size_t rows, + size_t cols, + const float* input, + void* buffer, + size_t& result_len +) { + ORT_ENFORCE(batch == 1); + size_t max_index_len = rows + cols + 1; + int32_t* result_buffer = (int32_t*)buffer; + size_t* result_len_device_buf = (size_t*)(result_buffer + (batch * max_index_len * 2)); + float* cost_buffer = (float*)(result_len_device_buf + 1); + int8_t* trace_buffer = (int8_t*)(cost_buffer + ((rows + 1) * (cols + 1))); + + dim3 block(device_prop.maxThreadsPerBlock); + dim3 grid_init((unsigned)SafeInt(rows + 1), (unsigned)SafeInt(batch)); + DynamicTimeWarpingInitCost<<>>(cost_buffer, trace_buffer, cols+1); + ORT_RETURN_IF_ERROR(CUDA_CALL(cudaGetLastError())); + + dim3 
grid(1, (unsigned)SafeInt(batch)); + DynamicTimeWarpingKernel<<>>( + rows, + cols, + max_index_len, + input, + cost_buffer, + trace_buffer, + result_buffer, + result_len_device_buf); + ORT_RETURN_IF_ERROR(CUDA_CALL(cudaGetLastError())); + + ORT_RETURN_IF_ERROR(CUDA_CALL(cudaMemcpyAsync(&result_len, result_len_device_buf, sizeof(size_t), cudaMemcpyDeviceToHost, stream))); + return CUDA_CALL(cudaGetLastError()); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.h b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.h new file mode 100644 index 0000000000000..cb4a0dfb16807 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +size_t GetDynamicTimeWarpingBufferSize(size_t batch, size_t rows, size_t cols, size_t& max_index_len); + +Status LaunchDynamicTimeWarping( + cudaStream_t stream, + const cudaDeviceProp& device_prop, + size_t batch, + size_t rows, + size_t cols, + const float* input, + void* buffer, + size_t& result_len); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/unfold.cc b/onnxruntime/contrib_ops/cuda/tensor/unfold.cc new file mode 100644 index 0000000000000..c38c8c5317f0a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/unfold.cc @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cuda/tensor/unfold.h" +#include "contrib_ops/cuda/tensor/unfold_impl.h" +#include "core/providers/cpu/tensor/utils.h" + +#include +#include + +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +ONNX_OPERATOR_KERNEL_EX( + UnfoldTensor, + kMSDomain, + 1, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllTensorTypes()), + UnfoldTensor); + +Status UnfoldTensor::ComputeInternal(OpKernelContext* ctx) const { + const Tensor& input_tensor = *ctx->Input(0); + const auto& input_dims = input_tensor.Shape().GetDims(); + int rank = SafeInt(input_dims.size()); + + int dim = SafeInt(HandleNegativeAxis(dim_, rank)); + ORT_ENFORCE(dim < rank, "input rank:", rank, " is not bigger than attribut specified dim: ", dim); + ORT_ENFORCE(input_dims[dim] >= size_, "dimsize:", input_dims[dim], " is less than unfold size:", size_); + + int64_t leading_dims = std::accumulate(input_dims.begin(), input_dims.begin() + dim, 1LL, std::multiplies()); + int64_t tailing_dims = std::accumulate(input_dims.begin() + (dim + 1), input_dims.end(), 1LL, std::multiplies()); + + std::vector output_dims(rank + 1, 0); + std::copy(input_dims.begin(), input_dims.end(), output_dims.begin()); + output_dims[dim] = (input_dims[dim] - size_) / step_ + 1; + output_dims.back() = size_; + TensorShape output_shape(output_dims); + Tensor* output_tensor = ctx->Output(0, output_shape); + + cudaStream_t stream = this->Stream(ctx); + const cudaDeviceProp& device_prop = this->GetDeviceProp(); + size_t element_size = input_tensor.DataType()->Size(); + return LaunchUnfoldTensor( + stream, device_prop, element_size, input_tensor.DataRaw(), output_tensor->MutableDataRaw(), + leading_dims, input_dims[dim], 
tailing_dims, size_, step_); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/unfold.h b/onnxruntime/contrib_ops/cuda/tensor/unfold.h new file mode 100644 index 0000000000000..1717687593470 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/unfold.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/cuda_kernel.h" +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using onnxruntime::OpKernelContext; +using onnxruntime::OpKernelInfo; +using onnxruntime::cuda::CudaKernel; +class UnfoldTensor final : public CudaKernel { + public: + UnfoldTensor(const OpKernelInfo& info) : CudaKernel(info) { + dim_ = SafeInt(info.GetAttrOrDefault("dim", -1LL)); + step_ = SafeInt(info.GetAttrOrDefault("step", 1LL)); + ORT_ENFORCE(step_ > 0, "step must greater than zero!"); + + int64_t temp_size; + ORT_ENFORCE(info.GetAttr("size", &temp_size).IsOK()); + size_ = SafeInt(temp_size); + } + + ~UnfoldTensor() = default; + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int dim_; + int size_; + int step_; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.cu b/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.cu new file mode 100644 index 0000000000000..996f340b483a3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.cu @@ -0,0 +1,107 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cuda/tensor/unfold_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/common/common.h" +#include + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__global__ void UnfoldTensorKenel( + const T* input, + T* output, + int64_t N, + int64_t unfold_size, // stride_tailing_dim_dst + int64_t tailing_dims_size, // stride_fold_dim_dst = tailing_dims_size * unfold_size, stride_append_dim_src = tailing_dims_size + int64_t stride_leading_dst, + int64_t stride_fold_dim_src, + int64_t stride_leading_src +) { + int64_t idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= N) return; + + const int64_t idx_leading = idx / stride_leading_dst; + int64_t n = idx % stride_leading_dst; + const int64_t stride_fold_dim_dst = tailing_dims_size * unfold_size; + const int64_t idx_fold = n / stride_fold_dim_dst; + n %= stride_fold_dim_dst; + const int64_t idx_tailing = n / unfold_size; + const int64_t idx_append = n % unfold_size; + + // stride_tailing_dim_src = 1 + int64_t idx_src = idx_leading * stride_leading_src + idx_fold * stride_fold_dim_src + idx_tailing + idx_append * tailing_dims_size; + output[idx] = input[idx_src]; +} + + +Status LaunchUnfoldTensor( + cudaStream_t stream, + const cudaDeviceProp& device_prop, + size_t element_size, + const void* input, + void* output, + int64_t leading_dims_size, + int64_t unfold_dim_size, + int64_t tailing_dims_size, + int64_t unfold_size, + int64_t step_size +) { + int64_t TPB = device_prop.maxThreadsPerBlock; + int64_t unfold_dim_size_dst = (unfold_dim_size - unfold_size) / step_size + 1; + int64_t N = leading_dims_size * unfold_dim_size_dst * tailing_dims_size * unfold_size; + int64_t num_blocks = (N + TPB - 1) / TPB; + + // int64_t stride_append_dim_dst = 1; + // int64_t 
stride_tailing_dim_dst = unfold_size; + // int64_t stride_fold_dim_dst = unfold_size * tailing_dims_size; + int64_t stride_leading_dst = unfold_size * tailing_dims_size * unfold_dim_size_dst; + + // int64_t stride_append_dim_src = tailing_dims_size; + // int64_t stride_tailing_dim_src = 1; + int64_t stride_fold_dim_src = tailing_dims_size * step_size; + int64_t stride_leading_src = tailing_dims_size * unfold_dim_size; + + dim3 block((unsigned)SafeInt(TPB)); + dim3 grid((unsigned)SafeInt(num_blocks)); + switch (element_size) { + case 1: + UnfoldTensorKenel<<>>( + (const int8_t*)input, (int8_t*)output, N, unfold_size, + tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); + break; + case 2: + UnfoldTensorKenel<<>>( + (const int16_t*)input, (int16_t*)output, N, unfold_size, + tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); + break; + case 4: + UnfoldTensorKenel<<>>( + (const int32_t*)input, (int32_t*)output, N, unfold_size, + tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); + break; + case 8: + UnfoldTensorKenel<<>>( + (const int64_t*)input, (int64_t*)output, N, unfold_size, + tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); + break; + case 16: + UnfoldTensorKenel<<>>( + (const float4*)input, (float4*)output, N, unfold_size, + tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src); + break; + default: + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Unsupported element_size"); + } + + return CUDA_CALL(cudaGetLastError()); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.h b/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.h new file mode 100644 index 0000000000000..9e82dccdec23c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
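The UnfoldTensor kernel above follows torch.Tensor.unfold semantics: the unfolded dimension shrinks to (sizedim - size) / step + 1 and a new trailing dimension of length size is appended. A tiny sketch of that shape arithmetic (UnfoldShape and the sample shapes are illustrative only):

#include <cstdint>
#include <cstdio>
#include <vector>

// Output-shape rule implemented above:
//   dims[dim]  -> (dims[dim] - size) / step + 1
//   new last dim = size
static std::vector<int64_t> UnfoldShape(std::vector<int64_t> dims, int dim,
                                        int64_t size, int64_t step) {
  dims[dim] = (dims[dim] - size) / step + 1;
  dims.push_back(size);
  return dims;
}

int main() {
  // Illustrative input only: [1, 6, 448, 3000] unfolded along the last axis
  // with size=2, step=1 -> [1, 6, 448, 2999, 2]
  auto out = UnfoldShape({1, 6, 448, 3000}, /*dim=*/3, /*size=*/2, /*step=*/1);
  for (int64_t d : out) std::printf("%lld ", static_cast<long long>(d));
  std::printf("\n");
  return 0;
}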
+ +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +Status LaunchUnfoldTensor( + cudaStream_t stream, + const cudaDeviceProp& device_prop, + size_t element_size, + const void* input, + void* output, + int64_t leading_dims_size, + int64_t tailing_dims_size, + int64_t dim_size, + int64_t unfold_size, + int64_t step_size); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc index d18460e016444..5e5211b2f88bd 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc @@ -58,7 +58,9 @@ BeamSearch::BeamSearch(const OpKernelInfo& info) GenerationCudaDeviceHelper::UpdateDecoderFeeds, GenerationCudaDeviceHelper::ExpandBuffer, GenerationCudaDeviceHelper::ExpandBuffer, - GenerationCudaDeviceHelper::ExpandBuffer); + GenerationCudaDeviceHelper::ExpandBuffer, + GenerationCudaDeviceHelper::UpdateDecoderCrossQK, + GenerationCudaDeviceHelper::FinalizeDecoderCrossQK); SetConsoleDumper(&g_cuda_dumper); diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index 07a8896210d2c..d3b783df2083b 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -1315,6 +1315,235 @@ template void BufferExpansionKernelLauncher(const int32_t* input, int chunk_size, cudaStream_t stream); +template +__global__ void CopyCrossQKSingleDecodeStepKernel( + T* target, // shape [batchxbeam, layer_head_pair_count, max_length, frame] + T** qk_layer_pointers, + int token_index, + int num_layers, + int num_heads, + const int* cross_qk_layer_head_pairs, + int frames, + int max_length +) { + const int pair = blockIdx.x; + const int layer_head_pair_count = gridDim.x; + const int bbm = blockIdx.y; + cross_qk_layer_head_pairs += (pair * 2); + const int layer = *cross_qk_layer_head_pairs; + const int head = *(cross_qk_layer_head_pairs + 1); + + target += ((int64_t)bbm * layer_head_pair_count + pair) * max_length * frames + ((int64_t)token_index * frames); + T* src = qk_layer_pointers[layer] + ((int64_t)bbm * num_heads + head) * frames; + + for (int tid = threadIdx.x; tid < frames; tid += blockDim.x) { + target[tid] = src[tid]; // use vectorized read write in future if needed + } +} + +void LaunchCopyCrossQKSingleDecodeStep( + cudaStream_t stream, + float* cross_qk_buffer_data, + float** qk_layer_pointers, + int token_index, + int batchxbeam, + int num_layers, + int num_heads, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + int frames, + int max_length +) { + dim3 block(512); + dim3 grid(cross_qk_layer_head_pair_count, batchxbeam); + typedef typename ToCudaType::MappedType CudaT; + + CopyCrossQKSingleDecodeStepKernel<<>>( + (CudaT*)cross_qk_buffer_data, + (CudaT**)qk_layer_pointers, + token_index, + num_layers, + num_heads, + cross_qk_layer_head_pairs, + frames, + max_length + ); +} + + +template +__global__ void CopyDecoderCrossQKAllStepsKernel( + int context_decoding_len, + int num_beams, + int num_return_sequences, + int max_length, + int frames_of_k, + const T* cross_qk_buffer_data, // [batch, num_beams, layer_head_pair_count, max_length, frames] + T* cross_qk_output, // [batch, num_return_sequences, 
layer_head_pair_count, total_decoding_length, frames] + const int* cache_indir_data, // [batch, num_beams, max_length] + const int32_t* beam_indices +) { + const int pair = blockIdx.y; + const int layer_head_pair_count = gridDim.y; + const int total_decoding_length = gridDim.x; + const int token_decoding_index = blockIdx.x; + const int br = blockIdx.z; + const int batch = br / num_return_sequences; + const int ret_seq_id = br % num_return_sequences; + + // get the real beam index, as the cache_indir_data did not updated in last token + const int src_beam = beam_indices[batch * num_beams + ret_seq_id] % num_beams; + + const int64_t offset_in_cache = ((int64_t)batch * num_beams + src_beam) * max_length + token_decoding_index + context_decoding_len; + int bm_mapped = ((num_beams <= 1) ? 0: ((token_decoding_index == total_decoding_length - 1) ? ret_seq_id : cache_indir_data[offset_in_cache])); + int bi_src = batch * num_beams + bm_mapped; + + T* target = cross_qk_output + + (((int64_t)br * layer_head_pair_count + (int64_t)pair) * total_decoding_length + token_decoding_index) * frames_of_k; + const T* src = cross_qk_buffer_data + + ((int64_t)bi_src * layer_head_pair_count * max_length + (int64_t)pair * max_length + token_decoding_index) * frames_of_k; + for (int tid = threadIdx.x; tid < frames_of_k; tid += blockDim.x) { + target[tid] = src[tid]; // use vectorized read write in future if needed + } +} + +void LaunchFinalizeCrossQK( + cudaStream_t stream, + int iteration_number, + int context_decoding_len, + int batch_size, + int num_beams, + int max_length, + int cross_qk_layer_head_pair_count, + [[maybe_unused]] const int* cross_qk_layer_head_pairs, + int frames_of_k, + const float* cross_qk_buffer_data, + float* cross_qk_output, + int num_return_sequences, + const int* cache_indir_data, + const int32_t* beam_indices +) { + int64_t br = (int64_t)batch_size * num_return_sequences; + ORT_ENFORCE(br < 65536L && cross_qk_layer_head_pair_count < 65536); + const int total_decoding_length = iteration_number - 1; + dim3 block(512); + dim3 grid(total_decoding_length, cross_qk_layer_head_pair_count, (unsigned)br); + typedef typename ToCudaType::MappedType CudaT; + + CopyDecoderCrossQKAllStepsKernel<<>>( + context_decoding_len, + num_beams, + num_return_sequences, + max_length, + frames_of_k, + (const CudaT*)cross_qk_buffer_data, + (CudaT*)cross_qk_output, + cache_indir_data, + beam_indices); +} + +template +__global__ void ForceDecodingIdsKernel( + float* beam_scores, + const int vocab_size, + const int32_t* force_ids, + int id_len, + int step +) { + const int num_beams = gridDim.y; + const int beam = blockIdx.y; + const int batch = blockIdx.z; + beam_scores += (((int64_t)batch * num_beams + beam)* vocab_size); // move to (batch, beam) + const int32_t id_wanted = force_ids[((int64_t)batch * id_len) + step]; + if (id_wanted < 0 || id_wanted >= vocab_size) return; + + const int32_t elements_per_block = (int32_t)blockDim.x * ElementsPerThreads; + const int32_t block_start_id = blockIdx.x * elements_per_block; + + int32_t token_id = block_start_id + (int)threadIdx.x; + #pragma unroll + for (int elem = 0; elem < ElementsPerThreads; elem++) { + if (token_id < vocab_size) { + beam_scores[token_id] = ((token_id == id_wanted) ? 
0.0f : cub::FpLimits::Lowest()); + } + token_id += (int)blockDim.x; + } +} + + +void LaunchForceDecodingIds( + float* beam_scores, + const int batch_size, + const int num_beams, + const int vocab_size, + const int32_t* force_ids, + int id_len, + int step, + cudaStream_t stream +) { + dim3 blocks(512); + constexpr int ElementsPerThreads = 4; + unsigned gridx = static_cast((vocab_size + 512 * ElementsPerThreads - 1) / (512 * ElementsPerThreads)); + dim3 grids(gridx, num_beams, batch_size); + ForceDecodingIdsKernel<<>>( + beam_scores, vocab_size, force_ids, id_len, step + ); +} + +template +__global__ void SaveNoSpeechProbsKernel( + T* result_no_speech_probs, + const float* probs, + const int batch_size, + const int num_beams, + const int vocab_size, + const int no_speech_token_id +) { + int b = blockIdx.x * blockDim.x + threadIdx.x; + if (b < batch_size) { + int64_t src_offset = b * num_beams * vocab_size + no_speech_token_id; + result_no_speech_probs[b] = (T)(probs[src_offset]); + } +} + +template +void LaunchSaveNoSpeechProbs( + T* result_no_speech_probs, /* [batch]*/ + const float* probs, /* [batch, num_beams, vocab_size]*/ + const int batch_size, + const int num_beams, + const int vocab_size, + const int no_speech_token_id, + cudaStream_t stream +) { + int tpb = 256; + int bpg = (batch_size + 255) / 256; + + typedef typename ToCudaType::MappedType CudaT; + SaveNoSpeechProbsKernel<<>>( + (CudaT*)result_no_speech_probs, probs, batch_size, num_beams, vocab_size, no_speech_token_id); +} + +template void LaunchSaveNoSpeechProbs( + float* result_no_speech_probs, + const float* probs, + const int batch_size, + const int num_beams, + const int vocab_size, + const int no_speech_token_id, + cudaStream_t stream +); + +template void LaunchSaveNoSpeechProbs( + MLFloat16* result_no_speech_probs, + const float* probs, + const int batch_size, + const int num_beams, + const int vocab_size, + const int no_speech_token_id, + cudaStream_t stream +); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h index 8c52f6fd52385..07ccbeced456e 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h @@ -213,6 +213,55 @@ void BufferExpansionKernelLauncher(const T* input, int chunk_size, cudaStream_t stream); +void LaunchCopyCrossQKSingleDecodeStep( + cudaStream_t stream, + float* cross_qk_buffer_data, + float** qk_layer_pointers, + int token_index, + int batchxbeam, + int num_layers, + int num_heads, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + int frames, + int max_length); + +void LaunchFinalizeCrossQK( + cudaStream_t stream, + int iteration_number, + int context_decoding_len, + int batch_size, + int num_beams, + int max_length, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + int frames_of_k, + const float* cross_qk_buffer_data, + float* cross_qk_output, + int num_return_sequences, + const int* cache_indir_data, + const int32_t* beam_indices); + +void LaunchForceDecodingIds( + float* beam_scores, + const int batch_size, + const int num_beams, + const int vocab_size, + const int32_t* force_ids, + int id_len, + int step, + cudaStream_t stream); + +template +void LaunchSaveNoSpeechProbs( + T* result_no_speech_probs, /* [batch]*/ + const float* probs, /* [batch, num_beams, vocab_size]*/ + const int 
batch_size, + const int num_beams, + const int vocab_size, + const int no_speech_token_id, + cudaStream_t stream); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index e4de33499c6ca..32cf8655f3b04 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -13,6 +13,8 @@ #include #include "contrib_ops/cuda/transformers/generation_cuda_impl.h" #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" +#include "contrib_ops/cpu/transformers/logits_processor.h" +#include "contrib_ops/cpu/transformers/generation_shared.h" #include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h" #include "contrib_ops/cpu/transformers/subgraph_gpt.h" #include "contrib_ops/cuda/transformers/beam_search_topk.h" @@ -210,7 +212,7 @@ Status AddToFeeds(Stream* ort_stream, ORT_ENFORCE(total_bytes > 0); cudaStream_t stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; - auto pinned_buffer = IAllocator::MakeUniquePtr(host_allocator, total_bytes); + auto pinned_buffer = IAllocator::MakeUniquePtr(host_allocator, total_bytes, false, ort_stream); char* pinned_data = static_cast(pinned_buffer.get()); // Copy tensors to one pinned memory buffer (so that we only need copy to GPU once) char* destination = pinned_data; @@ -426,11 +428,21 @@ Status ProcessLogits(const OrtValue& logits, // dumper->Print("next_token_scores after softmax", next_token_scores.data(), batch_size, num_beams, vocab_size); #endif + const bool is_whisper_model = (parameters->model_type == onnxruntime::contrib::transformers::IGenerationParameters::kModelTypeWhisper); + if (step == 1 && is_whisper_model && parameters->no_speech_probs) { + cuda::LaunchSaveNoSpeechProbs( + (T*)parameters->no_speech_probs, Y_data, batch_size, num_beams, vocab_size, parameters->no_speech_token, cuda_stream); + } + + // NOTE: currently we treat extra decoding ids are same + int extra_decoding_len = static_cast(parameters->extra_decoding_ids.size() / parameters->batch_size); + const bool need_handle_extra_decoding_ids = is_whisper_model && (!parameters->extra_decoding_ids.empty()) && (extra_decoding_len >= step); + cuda::LaunchLogitsProcessKernel( next_token_scores.data(), parameters->vocab_mask.data(), - step > 1 ? nullptr : parameters->prefix_vocab_mask.data(), // prefix vocab mask is applied to first step only. - nullptr, // parameters->presence_mask.data(), + (step > extra_decoding_len + 1) ? nullptr : parameters->prefix_vocab_mask.data(), // prefix vocab mask is applied to first step only. + nullptr, // parameters->presence_mask.data(), parameters->presence_penalty, parameters->temperature, parameters->batch_size, @@ -445,6 +457,50 @@ Status ProcessLogits(const OrtValue& logits, // parameters->no_repeat_ngram_size, cuda_stream); + // Whisper time stamp generation. 
+ // TODO: implement it on GPU + bool gen_timestamp = is_whisper_model && + (parameters->logits_processor == onnxruntime::contrib::transformers::IGenerationParameters::kLogitsProcessorTypeWhisper); + if (gen_timestamp) { + // Copy next token scores to cpu memory, copy Sequences to cpu + std::vector cpu_next_token_scores(next_token_scores.size()); + gsl::span cpu_next_token_scores_span(cpu_next_token_scores.data(), cpu_next_token_scores.size()); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_next_token_scores.data(), + next_token_scores.data(), + next_token_scores.size_bytes(), + cudaMemcpyDeviceToHost, + cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(const_cast(sequences->GetSequence(0).data()), + sequences->GetCurrentDeviceSequences().data(), + sequences->GetSequence(0).size_bytes() * batch_beam_size, + cudaMemcpyDeviceToHost, + cuda_stream)); + constexpr int max_initial_timestamp_index = 50; + onnxruntime::contrib::transformers::TimestampLogitsProcessor time_logit_processor(parameters->eos_token_id, max_initial_timestamp_index); + onnxruntime::contrib::transformers::NextTokenScores next_token_scores_timestamp({cpu_next_token_scores_span, batch_beam_size, vocab_size}); + + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); + time_logit_processor.Process(sequences, next_token_scores_timestamp); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(next_token_scores.data(), + cpu_next_token_scores.data(), + next_token_scores.size_bytes(), + cudaMemcpyHostToDevice, + cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); + } + + if (need_handle_extra_decoding_ids && !parameters->extra_decoding_ids.empty()) { + cuda::LaunchForceDecodingIds( + next_token_scores.data(), + parameters->batch_size, + parameters->num_beams, + parameters->vocab_size, + parameters->extra_decoding_ids.data(), + parameters->extra_decoding_ids.size() / parameters->batch_size, + step - 1, + cuda_stream); + } + #ifdef DEBUG_GENERATION dumper->Print("next_token_scores after logits process", next_token_scores.data(), batch_size, num_beams, vocab_size); #endif @@ -807,13 +863,11 @@ Status GreedySearchProcessLogits( // Sequences generated by beam scorer is currently stored in CPU. 
// Copy sequences to device only when repetition penalty or no repeat ngram is used in kernel - BufferUniquePtr sequences_buffer; + IAllocatorUniquePtr sequences_buffer; int current_sequence_length = sequences->GetSequenceLength(); if (parameters->repetition_penalty != 1.0f) { size_t bytes = SafeInt(sizeof(int32_t)) * batch_beam_size * parameters->max_length; - void* data = allocator->Alloc(bytes); - BufferUniquePtr temp_buffer(data, BufferDeleter(allocator)); - sequences_buffer = std::move(temp_buffer); + sequences_buffer = IAllocator::MakeUniquePtr(allocator, bytes, false, stream); CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(sequences_buffer.get(), sequences->GetSequence(0).data(), bytes, cudaMemcpyHostToDevice, cuda_stream)); } @@ -1196,14 +1250,14 @@ Status UpdateDecoderFeeds( if (past_present_share_buffer) { // Update past sequence length input - const ptrdiff_t past_sequence_length_idx = 2 * (static_cast(last_outputs.size()) - t5_decoder_first_present_output_idx) + t5_decoder_first_past_input_idx; + const ptrdiff_t past_sequence_length_idx = 2 * num_present_tensors + t5_decoder_first_past_input_idx; *(next_inputs[past_sequence_length_idx].GetMutable()->MutableData()) = current_length - 1; // Update beam search specific input for DecoderMaskedSelfAttention (cache indirection) if present // If the last input is not `past_sequence_length`, then the beam search specific inputs // for `DecoderMaskedSelfAttention` is present - if (need_cache_indir) { + if (need_cache_indir && num_beams > 1) { ORT_ENFORCE(!beam_indices_gpu.empty(), "Beam indices must be present on CUDA while using DecoderMaskedMultiHeadAttention with BeamSearch"); // The cache indirection feed comes 2 feeds after the `past_sequence_length` feed @@ -1528,6 +1582,93 @@ template Status ExpandBuffer( OrtValue& expanded, bool only_copy_shape, int max_sequence_length); + +Status UpdateDecoderCrossQK( + int iteration_number, + Stream* stream, + OrtValue* cross_qks, + IAllocatorUniquePtr& qk_layer_pointers, + int num_layers, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + float* cross_qk_buffer_data, + int max_length, + AllocatorPtr allocator) { + cudaStream_t cuda_stream = stream ? 
static_cast(stream->GetHandle()) : nullptr; + + if (qk_layer_pointers.get() == nullptr) { + // Put all the qk pointers into gpu, as they did not change in following decoding steps + // also this help to use single kernel to process each step + qk_layer_pointers = IAllocator::MakeUniquePtr(allocator, static_cast(num_layers), false, stream); + std::vector qk_layer_data(num_layers, nullptr); + for (int layer = 0; layer < num_layers; layer++) { + qk_layer_data[layer] = cross_qks[layer].GetMutable()->MutableData(); + } + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync((void*)qk_layer_pointers.get(), qk_layer_data.data(), sizeof(qk_layer_data[0]) * num_layers, + cudaMemcpyHostToDevice, cuda_stream)); + } + + auto cross_qk_layer_shape = cross_qks[0].GetMutable()->Shape(); + int64_t batchxbeam = cross_qk_layer_shape[0]; + int64_t num_heads = cross_qk_layer_shape[1]; + int64_t frames = cross_qk_layer_shape[3]; + + cuda::LaunchCopyCrossQKSingleDecodeStep( + cuda_stream, + cross_qk_buffer_data, + qk_layer_pointers.get(), + iteration_number - 2, + batchxbeam, + num_layers, + num_heads, + cross_qk_layer_head_pair_count, + cross_qk_layer_head_pairs, + frames, + max_length); + + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + + return Status::OK(); +} + +Status FinalizeDecoderCrossQK( + Stream* stream, + int iteration_number, + int context_decoding_len, + int batch_size, + int num_beams, + int max_length, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + int frames_of_k, + const float* cross_qk_buffer_data, + float* cross_qk_output, + int num_return_sequences, + const int* cache_indir_data, + gsl::span beam_indices_gpu) { + cudaStream_t cuda_stream = stream ? static_cast(stream->GetHandle()) : nullptr; + + cuda::LaunchFinalizeCrossQK( + cuda_stream, + iteration_number, + context_decoding_len, + batch_size, + num_beams, + max_length, + cross_qk_layer_head_pair_count, + cross_qk_layer_head_pairs, + frames_of_k, + cross_qk_buffer_data, + cross_qk_output, + num_return_sequences, + cache_indir_data, + beam_indices_gpu.data()); + + CUDA_RETURN_IF_ERROR(cudaGetLastError()); + + return Status::OK(); +} + } // namespace GenerationCudaDeviceHelper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h index f5f062d7a101b..7a718eb9f66c1 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h @@ -150,6 +150,34 @@ Status ExpandBuffer( bool only_copy_shape, int max_sequence_length = 0); +Status UpdateDecoderCrossQK( + int iteration_number, + Stream* stream, + OrtValue* cross_qks, + IAllocatorUniquePtr& qk_layer_pointers, + int num_layers, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + float* cross_qk_buffer_data, + int max_length, + AllocatorPtr allocator); + +Status FinalizeDecoderCrossQK( + Stream* stream, + int iteration_number, + int context_decoding_len, + int batch_size, + int num_beams, + int max_length, + int cross_qk_layer_head_pair_count, + const int* cross_qk_layer_head_pairs, + int frames_of_k, + const float* cross_qk_buffer_data, + float* cross_qk_output, + int num_return_sequences, + const int* cache_indir_data, + gsl::span beam_indices); + } // namespace GenerationCudaDeviceHelper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc 
b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index e5956a575d73d..880b2df3543e6 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -746,6 +746,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "The value to be filled in the attention mask. Default value is -10000.0f", AttributeProto::FLOAT, OPTIONAL_VALUE) + .Attr("output_qk", + "Whether to output the cross attention MatMul(Q, K). Default 0 (do not output).", + AttributeProto::INT, + OPTIONAL_VALUE) .Input(0, "query", "Query with shape (batch_size, 1, hidden_size) or packed QKV with shape " @@ -837,6 +841,12 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "while effective_seq_length = (past_sequence_length + kv_sequence_length).", "T", OpSchema::Optional) + .Output(3, + "qk", + "normalized Q * K, of shape (batch_size, num_heads, 1, total_sequence_length). ", + "V", + OpSchema::Optional) + .TypeConstraint("V", {"tensor(float)"}, "Constrain qk output types to float32 tensors.") .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index a79203a94a3a7..00740ada2677d 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -407,6 +407,9 @@ void BeamSearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 5, 1); if (ctx.getNumOutputs() > 2) { ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 5, 2); + if (ctx.getNumOutputs() > 3) { + ONNX_NAMESPACE::updateOutputElemType(ctx, 3, ONNX_NAMESPACE::TensorProto::FLOAT); + } } } @@ -415,6 +418,7 @@ void BeamSearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { // output 0 (sequences) shape: (batch_size, num_return_sequences, max_length) // output 1 (sequences_scores) shape: (batch_size, num_return_sequences) // output 2 (scores) shape: (max_length - sequence_length, batch_size, num_beams, vocab_size) + // output 3 (cross_attention) shape: (batch_size, num_return_sequences, Layers, Heads, max_length, Frames) if (!hasInputShape(ctx, 0)) { return; } @@ -490,6 +494,22 @@ void BeamSearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { } updateOutputShape(ctx, 2, scores_shape); } + + if (ctx.getNumOutputs() > 3) { + ONNX_NAMESPACE::TensorShapeProto cross_attn_shape; + cross_attn_shape.add_dim()->set_dim_value(batch_size); + cross_attn_shape.add_dim()->set_dim_value(num_return_sequences_value); + cross_attn_shape.add_dim(); // num of layer is unknown, no need to calc it from subgraph here + cross_attn_shape.add_dim(); // num of head is unknown, no need to calc it from subgraph here + cross_attn_shape.add_dim()->set_dim_value(max_length_value); + cross_attn_shape.add_dim()->set_dim_value(sequence_length); + updateOutputShape(ctx, 3, cross_attn_shape); + } + if (ctx.getNumOutputs() > 4) { + ONNX_NAMESPACE::TensorShapeProto non_speech_probs_shape; + non_speech_probs_shape.add_dim()->set_dim_value(batch_size); + updateOutputShape(ctx, 4, non_speech_probs_shape); + } } } @@ -1060,6 +1080,78 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GridSample, 1, updateOutputShape(ctx, 0, {N, C, H_out, W_out}); })); +ONNX_MS_OPERATOR_SET_SCHEMA( + UnfoldTensor, 1, + OpSchema() + .SetDoc("Returns a tensor which contains all slices of size size from input tensor in the dimension dim. " + "Step between two slices is given by step. 
" + "If sizedim is the size of dimension dim for input tensor, the size of dimension dim in " + "the returned tensor will be (sizedim - size) / step + 1. " + "An additional dimension of size size is appended in the returned tensor.") + .Attr("dim", "specify the dimension to unfold", AttributeProto::INT, static_cast(-1)) + .Attr("size", "specify the size", AttributeProto::INT) + .Attr("step", "specify the step.", AttributeProto::INT, static_cast(1)) + .Input(0, "input", "input tensor", "T") + .Output(0, "output", "Output tensor.", "T") + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Allow inputs and outputs to be any kind of tensor.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + + if (!hasInputShape(ctx, 0)) return; + auto& input_shape = getInputShape(ctx, 0); + const int rank = input_shape.dim_size(); + int64_t dim = getAttribute(ctx, "dim", -1); + dim = HandleNegativeAxis(dim, rank); + if (!input_shape.dim(dim).has_dim_value()) { + return; + } + int64_t dim_size = input_shape.dim(dim).dim_value(); + + const int64_t step = getAttribute(ctx, "step", -1); + if (step <= 0) { + fail_shape_inference("size attribute in UnfoldTensor must greater than 0.") + } + int64_t size = -1; + auto size_proto = ctx.getAttribute("size"); + if (!(size_proto)) { + fail_shape_inference("size attribute in UnfoldTensor not specified!") + } + size = size_proto->i(); + if (size > dim_size || size <= 0) { + fail_shape_inference("size attribute in UnfoldTensor not positive and less than the dim size!") + } + + ONNX_NAMESPACE::TensorShapeProto output_shape; + for (int d = 0; d < rank; d++) { + if (d == dim) { + output_shape.add_dim()->set_dim_value((dim_size - size) / step + 1); + } else { + *output_shape.add_dim() = input_shape.dim(d); + } + } + output_shape.add_dim()->set_dim_value(size); + updateOutputShape(ctx, 0, output_shape); + })); + +ONNX_MS_OPERATOR_SET_SCHEMA( + DynamicTimeWarping, 1, + OpSchema() + .SetDoc("Input is cost matrix where each value in input[r][c] is the cost for pass the point (r, c). From current point" + "(r, c), points (r+1, c), (r+1, c+1) or (r, c+1) could be arrived in next move. Given such cost matrix, return " + "dynamic time wrapping of shape [2, x], where the path made by all points (output[0][t], output[1][t])" + "have the lowest cost among all paths from (0, 0) to (M-1, N-1).") + .Input(0, "input", "Input cost tensor, it must be 2D tensor of shape M x N, or 1 x M x N", "F") + .Output(0, "output", "Output tensor. shape is [2, x], where max(M, N) <= x < M + N", "I") + .TypeConstraint("F", {"tensor(float)"}, "Constrain to float tensors.") + .TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT32); + ONNX_NAMESPACE::TensorShapeProto resultShape; + resultShape.add_dim()->set_dim_value(2); + resultShape.add_dim(); + updateOutputShape(ctx, 0, resultShape); + })); + ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, OpSchema() .SetDoc("Beam Search for text generation. 
Supports GPT-2 decoder.") @@ -1068,7 +1160,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts.", AttributeProto::INT, static_cast(-1)) .Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast(0)) .Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast(0)) - .Attr("model_type", "model type: 0 for GPT-2; 1 for encoder decoder like T5", AttributeProto::INT, static_cast(0)) + .Attr("model_type", "model type: 0 for GPT-2; 1 for encoder decoder like T5; 2 for whisper", AttributeProto::INT, static_cast(0)) .Attr("encoder", "The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.", AttributeProto::GRAPH, OPTIONAL_VALUE) .Attr("init_decoder", "The subgraph for the first decoding run. It will be called once before `decoder` subgraph. " @@ -1079,6 +1171,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, "Size of the vocabulary. " "If not provided, it will be inferred from the decoder subgraph's output shape", AttributeProto::INT, static_cast(-1)) + .Attr("decoder_output_cross_qk", "If nozero, decoder subgraph contains output Q*K from cross attentions. Default 0.", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("no_speech_token", + "The token in whisper model that mark all sequence empty. With this model, whisper could output no_speech_prob after Default -1.", + AttributeProto::INT, OPTIONAL_VALUE) .Input(0, "input_ids", "The sequence used as a prompt for the generation in the encoder subgraph. Shape is (batch_size, sequence_length)", "F") .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) @@ -1095,6 +1191,15 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, .Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(10, "decoder_input_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional) .Input(11, "logits_processor", "Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)", "I", OpSchema::Optional) + .Input(12, "cross_qk_layer_head", + "Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all" + "its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]", + "I", OpSchema::Optional) + .Input(13, "extra_decoding_ids", + "Part of the decoder_input_ids that we need cross qk for it. it is of shape (batch_size, extra_decoding_ids_len)." + "In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) " + "are treated as stop of the extra_decoding_ids for corresponding batch.", + "I", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I") .Output(1, "sequences_scores", "Final beam score of the generated sequences. 
Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional) .Output(2, "scores", @@ -1102,6 +1207,18 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam." "Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)", "T", OpSchema::Optional) + .Output(3, "cross_qk", + "Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, " + "F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers," + "B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]." + "If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]", + "V", OpSchema::Optional) + .Output(4, "non_speech_probs", + "For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token." + "Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph." + "The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]", + "T", OpSchema::Optional) + .TypeConstraint("V", {"tensor(float)"}, "Constrain to float32 tensors.") .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain to float tensors.") .TypeConstraint("F", {"tensor(float)", "tensor(int32)", "tensor(float16)"}, "Constrain input type to float or int tensors.") .TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types") diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index 3c31997286254..2a241dba922cb 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -102,6 +102,8 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Tokenizer); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, TorchEmbedding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, TransposeMatMul); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Trilu); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, UnfoldTensor); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DynamicTimeWarping); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Unique); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, WordConvEmbedding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GemmFastGelu); @@ -203,6 +205,8 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index ad892eab3b843..ef9c81440927f 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -250,7 +250,7 @@ CUDAExecutionProvider::CUDAExecutionProvider(const CUDAExecutionProviderInfo& in if (info.external_allocator_info.UseExternalAllocator()) { use_ep_level_unified_stream_ = true; stream_ = nullptr; - } else if (info.enable_cuda_graph) { + } else if (info.enable_cuda_graph || info.use_ep_level_unified_stream) { // current cuda graph implementation only works with single stream // use EP level unified stream for all the reqeust 
CUDA_CALL_THROW(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc index ca88b3474b758..966448051264d 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc @@ -29,6 +29,7 @@ constexpr const char* kTunableOpEnable = "tunable_op_enable"; constexpr const char* kTunableOpTuningEnable = "tunable_op_tuning_enable"; constexpr const char* kTunableOpMaxTuningDurationMs = "tunable_op_max_tuning_duration_ms"; constexpr const char* kEnableSkipLayerNormStrictMode = "enable_skip_layer_norm_strict_mode"; +constexpr const char* KUseEPLevelUnifiedStream = "use_ep_level_unified_stream"; } // namespace provider_option_names } // namespace cuda @@ -99,6 +100,7 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P .AddAssignmentToReference(cuda::provider_option_names::kEnableCudaGraph, info.enable_cuda_graph) .AddAssignmentToReference(cuda::provider_option_names::kCudnnConv1dPadToNc1d, info.cudnn_conv1d_pad_to_nc1d) .AddAssignmentToReference(cuda::provider_option_names::kEnableSkipLayerNormStrictMode, info.enable_skip_layer_norm_strict_mode) + .AddAssignmentToReference(cuda::provider_option_names::KUseEPLevelUnifiedStream, info.use_ep_level_unified_stream) .AddValueParser( cuda::provider_option_names::kTunableOpEnable, [&info](const std::string& value_str) -> Status { @@ -144,6 +146,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution {cuda::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op.tuning_enable)}, {cuda::provider_option_names::kTunableOpMaxTuningDurationMs, MakeStringWithClassicLocale(info.tunable_op.max_tuning_duration_ms)}, {cuda::provider_option_names::kEnableSkipLayerNormStrictMode, MakeStringWithClassicLocale(info.enable_skip_layer_norm_strict_mode)}, + {cuda::provider_option_names::KUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, }; return options; @@ -162,6 +165,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProvid {cuda::provider_option_names::kTunableOpEnable, MakeStringWithClassicLocale(info.tunable_op_enable)}, {cuda::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op_tuning_enable)}, {cuda::provider_option_names::kTunableOpMaxTuningDurationMs, MakeStringWithClassicLocale(info.tunable_op_max_tuning_duration_ms)}, + {cuda::provider_option_names::KUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, }; return options; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h index 789b02b0e1d8c..89b266f362e8d 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h @@ -72,6 +72,8 @@ struct CUDAExecutionProviderInfo { bool enable_skip_layer_norm_strict_mode{false}; + bool use_ep_level_unified_stream{false}; + static CUDAExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); static ProviderOptions ToProviderOptions(const CUDAExecutionProviderInfo& info); static ProviderOptions ToProviderOptions(const OrtCUDAProviderOptionsV2& info); diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc 
b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 5a11f2529f38e..53746ff1e1fe6 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -222,6 +222,7 @@ struct CUDA_Provider : Provider { info.tunable_op.tuning_enable = params->tunable_op_tuning_enable; info.tunable_op.max_tuning_duration_ms = params->tunable_op_max_tuning_duration_ms; info.enable_skip_layer_norm_strict_mode = params->enable_skip_layer_norm_strict_mode != 0; + info.use_ep_level_unified_stream = params->use_ep_level_unified_stream != 0; return std::make_shared(info); } @@ -253,6 +254,7 @@ struct CUDA_Provider : Provider { cuda_options.enable_cuda_graph = internal_options.enable_cuda_graph; cuda_options.cudnn_conv1d_pad_to_nc1d = internal_options.cudnn_conv1d_pad_to_nc1d; cuda_options.enable_skip_layer_norm_strict_mode = internal_options.enable_skip_layer_norm_strict_mode; + cuda_options.use_ep_level_unified_stream = internal_options.use_ep_level_unified_stream; } ProviderOptions GetProviderOptions(const void* provider_options) override { diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index bf7a3bbd9d380..29ca5bd7b0055 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1332,6 +1332,7 @@ OrtCUDAProviderOptionsV2 OrtCUDAProviderOptionsToOrtCUDAProviderOptionsV2(const cuda_options_converted.enable_cuda_graph = 0; cuda_options_converted.cudnn_conv1d_pad_to_nc1d = 0; cuda_options_converted.enable_skip_layer_norm_strict_mode = 0; + cuda_options_converted.use_ep_level_unified_stream = 0; return cuda_options_converted; } diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index c0cabbb5e9759..c1c709d6d759b 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -1272,7 +1272,46 @@ def find_past_seq_len_usage(subg: GraphProto): return tensor_names_to_rename, nodes_to_remove -def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: GraphProto): +def update_decoder_subgraph_output_cross_attention(subg: GraphProto): + input_self_past_0 = 1 + # w/wo attention mask, w/wo hidden_state + graph_input_names = [gi.name for gi in subg.input] + while input_self_past_0 < 3 and not graph_input_names[input_self_past_0].startswith("past"): + input_self_past_0 += 1 + output_self_present_0 = 1 + + num_layers = (len(subg.output) - output_self_present_0) // 2 + input_cross_past_0 = 2 * num_layers + input_self_past_0 + past_key_cross_inputs = {subg.input[layer * 2 + input_cross_past_0].name: layer for layer in range(num_layers)} + print(f" --past_key_cross_inputs={past_key_cross_inputs}") + + input_past_key_cross_0_shape = shape_of(subg.input[input_cross_past_0]) + print(f"past_key_cross_0_shape is {input_past_key_cross_0_shape}") + batch_size_dim = input_past_key_cross_0_shape[0] + num_heads_dim = input_past_key_cross_0_shape[1] + cross_seq_len_dim = input_past_key_cross_0_shape[2] + + num_layer_output_qk = 0 + for node in subg.node: + if (node.op_type == "DecoderMaskedMultiHeadAttention") and (node.input[1] in past_key_cross_inputs): + print(f" -- add cross QK output from: node: {node.name} with output: {node.output}") + num_layer_output_qk += 1 + layer = past_key_cross_inputs[node.input[1]] + cross_attention_out_name = f"output_cross_qk_{layer}" + 
appended_names = [""] * (3 - len(node.output)) + appended_names.append(cross_attention_out_name) + node.output.extend(appended_names) + node.attribute.extend([onnx.helper.make_attribute("output_qk", 1)]) + + cross_attention = onnx.helper.make_tensor_value_info( + cross_attention_out_name, TensorProto.FLOAT, [batch_size_dim, num_heads_dim, 1, cross_seq_len_dim] + ) + subg.output.extend([cross_attention]) + if num_layer_output_qk != num_layers: + raise ValueError(f"Did not add cross QK for all layers{num_layers} vs {num_layer_output_qk}") + + +def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: ModelProto): input_self_past_0 = 1 # w/wo attention mask, w/wo hidden_state graph_input_names = [gi.name for gi in subg.input] diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index 3562df1660ea9..c5f43ce782331 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -169,6 +169,69 @@ def parse_arguments(argv=None): ) parser.set_defaults(chain_model=True) + parser.add_argument( + "--extra_decoding_ids", + required=False, + action="store_true", + help="Need extra starting decoding ids for some feature like cross qk. Default if false.", + ) + parser.set_defaults(extra_decoding_ids=False) + + parser.add_argument( + "--collect_cross_qk", + required=False, + action="store_true", + help="Beam search model collect stacked cross QK.", + ) + parser.set_defaults(collect_cross_qk=False) + + parser.add_argument( + "--output_cross_qk", + required=False, + action="store_true", + help="Beam search model output collected qk as output. Also hint collect_cross_qk", + ) + parser.set_defaults(output_cross_qk=False) + + parser.add_argument( + "--no_speech_token_id", + default=50362, + type=int, + help="specify no_speech_token_id. Default is 1000. 
If >= 0, it will be added as a beam search attribute.", + ) + + parser.add_argument( + "--output_no_speech_probs", + required=False, + action="store_true", + help="Beam search model outputs no speech probs, which are computed from the encoder/context-decoder graph.", + ) + parser.set_defaults(output_no_speech_probs=False) + + parser.add_argument( + "--output_scores", + required=False, + action="store_true", + help="Beam search model outputs scores over the vocab per generated token.", + ) + parser.set_defaults(output_scores=False) + + parser.add_argument( + "--output_sequence_scores", + required=False, + action="store_true", + help="Beam search model outputs a score for each generated sequence.", + ) + parser.set_defaults(output_sequence_scores=False) + + parser.add_argument( + "--cross_qk_onnx_model", + required=False, + type=str, + default=None, + help="The model which consumes cross_qk.", + ) + parser.add_argument( "--beam_output_model", type=str, @@ -220,6 +283,7 @@ def parse_arguments(argv=None): ) args = parser.parse_args(argv) + args.collect_cross_qk = args.collect_cross_qk or args.output_cross_qk return args diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index 3b1e656136547..a1ed0c7ed5ca2 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -1,12 +1,19 @@ import logging import os +import sys import onnx -from benchmark_helper import Precision -from convert_generation import get_shared_initializers, update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha from onnx import TensorProto, helper from transformers import WhisperConfig +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) +from benchmark_helper import Precision # noqa: E402 +from convert_generation import ( # noqa: E402 + get_shared_initializers, + update_decoder_subgraph_output_cross_attention, + update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha, +) + logger = logging.getLogger(__name__) @@ -42,8 +49,24 @@ def chain_model(args): "", # attention mask "decoder_input_ids" if args.use_forced_decoder_ids else "", "logits_processor" if args.use_logits_processor else "", + "cross_qk_layer_head" if args.collect_cross_qk else "", + "extra_decoding_ids" if args.extra_decoding_ids else "", ] + beam_outputs = ["sequences"] + if args.output_sequence_scores: + beam_outputs.append("sequence_scores") + if args.output_scores: + beam_outputs.append("scores") + + if args.collect_cross_qk: + while len(beam_outputs) < 3: + beam_outputs.extend([""]) + beam_outputs.extend(["cross_qk"]) + if args.output_no_speech_probs: + while len(beam_outputs) < 4: + beam_outputs.extend([""]) + beam_outputs.extend(["no_speech_probs_beam"]) input_features_cast_node, len_pen_cast_node, rep_pen_cast_node = None, None, None if args.precision == Precision.FLOAT16: @@ -81,6 +104,10 @@ def chain_model(args): helper.make_attribute("model_type", 2), ] ) + if args.collect_cross_qk: + node.attribute.extend([helper.make_attribute("decoder_output_cross_qk", 1)]) + if args.no_speech_token_id >= 0: + node.attribute.extend([helper.make_attribute("no_speech_token", args.no_speech_token_id)]) input_features = helper.make_tensor_value_info( "input_features", TensorProto.FLOAT, ["batch_size", "feature_size", "sequence_length"] @@ -121,17 +148,50 @@ def chain_model(args): logits_processor = helper.make_tensor_value_info("logits_processor",
TensorProto.INT32, [1]) graph_inputs.append(logits_processor) + if args.collect_cross_qk: + cross_qk_layer_head = helper.make_tensor_value_info( + "cross_qk_layer_head", TensorProto.INT32, ["num_layer_head", 2] + ) + graph_inputs.append(cross_qk_layer_head) + + if args.extra_decoding_ids: + extra_decoding_ids = helper.make_tensor_value_info( + "extra_decoding_ids", TensorProto.INT32, ["batch_size", "extra_decoding_ids_len"] + ) + graph_inputs.append(extra_decoding_ids) + # graph outputs sequences = helper.make_tensor_value_info( "sequences", TensorProto.INT32, ["batch_size", "num_return_sequences", "max_length"] ) graph_outputs = [sequences] + if args.output_cross_qk or (not args.cross_qk_onnx_model and args.collect_cross_qk): + cross_qk = helper.make_tensor_value_info( + "cross_qk", + TensorProto.FLOAT, + ["batch_size", "num_return_sequences", "num_layer_head_cross_qk", "max_length", "frames"], + ) + graph_outputs.extend([cross_qk]) + + if args.output_no_speech_probs: + no_speech_probs = helper.make_tensor_value_info("no_speech_probs", TensorProto.FLOAT, ["batch_size"]) + graph_outputs.extend([no_speech_probs]) + + if args.output_sequence_scores: + sequence_scores = helper.make_tensor_value_info("sequence_scores", TensorProto.FLOAT, ["batch_size"]) + graph_outputs.extend([sequence_scores]) + + if args.output_scores: + scores = helper.make_tensor_value_info("scores", TensorProto.FLOAT, ["batch_size"]) + graph_outputs.extend([scores]) if hasattr(args, "use_gpu") and args.use_gpu: if update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(decoder_model.graph): logger.info("Updated whisper decoder subgraph to use DecoderMaskedMultiHeadAttention successfully!") else: logger.warning("DecoderMaskedMultiHeadAttention could not be applied to whisper decoder subgraph") + if hasattr(args, "collect_cross_qk") and args.collect_cross_qk: + update_decoder_subgraph_output_cross_attention(decoder_model.graph) # Initializers/opsets # Delete shared data between decoder/encoder and move to larger graph initializers @@ -150,7 +210,34 @@ def chain_model(args): if args.precision == Precision.FLOAT16 else [node] ) + if args.output_no_speech_probs: + prob_cast_node = helper.make_node( + "Cast", + inputs=["no_speech_probs_beam"], + outputs=["no_speech_probs"], + name="no_speech_probs_cast_to_fp32", + to=TensorProto.FLOAT, + ) + graph_nodes.extend([prob_cast_node]) + beam_graph = helper.make_graph(graph_nodes, "beam-search-test", graph_inputs, graph_outputs, initializers) + beam_graph_input_names = [gi.name for gi in graph_inputs] + beam_graph_output_names = [go.name for go in graph_outputs] + + if args.cross_qk_onnx_model: + post_qk_model = onnx.load_model(args.cross_qk_onnx_model, load_external_data=True) + post_qk_graph = post_qk_model.graph + beam_graph.initializer.extend(post_qk_graph.initializer) + beam_graph.node.extend(post_qk_graph.node) + # TODO: Treat same name same input, user need check their shapes, etc + for pgi in post_qk_graph.input: + if ( + (pgi.name not in beam_graph_input_names) + and (pgi.name not in beam_graph_output_names) + and (pgi.name != "cross_qk") + ): + beam_graph.input.extend([pgi]) + beam_graph.output.extend(post_qk_graph.output) # Verify graph's inputs match beam search's inputs verify_inputs(beam_inputs, graph_inputs) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index 3a81700a7fd04..8c22cd5e745b3 100644 --- 
a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -312,6 +312,7 @@ def verify_onnx( "tensor(uint8)": np.uint8, } + use_extra_decoding_ids = "extra_decoding_ids" in ort_names for name, dtype in zip(ort_names, ort_dtypes): if name == "input_features": inputs[name] = inputs[name].detach().cpu().numpy() @@ -320,9 +321,18 @@ def verify_onnx( elif name == "prefix_vocab_mask": inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype]) elif name == "decoder_input_ids": - inputs[name] = np.array([[config.decoder_start_token_id, 50259, 50359, 50363]], dtype=ort_to_np[dtype]) + raw_input_ids = ( + [[config.decoder_start_token_id]] + if use_extra_decoding_ids + else [[config.decoder_start_token_id, 50259, 50359, 50363]] + ) + inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype]) elif name == "logits_processor": inputs[name] = np.array([1], dtype=ort_to_np[dtype]) + elif name == "cross_qk_layer_head": + inputs[name] = np.array([[0, 0]], dtype=ort_to_np[dtype]) + elif name == "extra_decoding_ids": + inputs[name] = np.repeat(np.array([[50259, 50359, 50363]], dtype=ort_to_np[dtype]), batch_size, 0) else: inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype]) ort_outputs = ort_session.run(None, inputs)[0][0] diff --git a/onnxruntime/test/contrib_ops/dynamic_time_warping_op_test.cc b/onnxruntime/test/contrib_ops/dynamic_time_warping_op_test.cc new file mode 100644 index 0000000000000..28f15b58eb467 --- /dev/null +++ b/onnxruntime/test/contrib_ops/dynamic_time_warping_op_test.cc @@ -0,0 +1,119 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" +#include "test/util/include/default_providers.h" +#include "test/common/cuda_op_test_utils.h" + +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace test { + +TEST(DynamicTimeWarp, simple) { + if (NeedSkipIfCudaArchLowerThan(530)) { + return; + } + + std::vector X = { + 3.0f, + 8.0f, + 5.0f, + 1.0f, + 9.0f, + 8.0f, + 5.0f, + 7.0f, + 4.0f, + 4.0f, + 9.0f, + 6.0f, + 2.0f, + 9.0f, + 7.0f, + 2.0f, + 5.0f, + 6.0f, + 1.0f, + 8.0f, + 4.0f, + 6.0f, + 5.0f, + 8.0f, + 4.0f, + 8.0f, + 3.0f, + 6.0f, + 3.0f, + 9.0f, + 1.0f, + 1.0f, + 6.0f, + 8.0f, + 3.0f, + 5.0f, + 5.0f, + 3.0f, + 3.0f, + 8.0f, + 8.0f, + 7.0f, + 1.0f, + 2.0f, + 2.0f, + 1.0f, + 5.0f, + 4.0f, + 5.0f, + 0.0f, + 3.0f, + 6.0f, + 3.0f, + 7.0f, + 4.0f, + 5.0f, + 4.0f, + 5.0f, + 4.0f, + 0.0f, + }; + + std::vector Y = { + 0, + 1, + 2, + 3, + 4, + 4, + 4, + 4, + 5, + 5, + 5, + 5, + 0, + 1, + 1, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + }; + + OpTester tester("DynamicTimeWarping", 1, onnxruntime::kMSDomain); + tester.AddInput("input", {6, 10}, X); + tester.AddOutput("output", {2, 12}, Y); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/tensor_op_test.cc b/onnxruntime/test/contrib_ops/tensor_op_test.cc index 44cb49580ce8b..323a8b2cb00ef 100644 --- a/onnxruntime/test/contrib_ops/tensor_op_test.cc +++ b/onnxruntime/test/contrib_ops/tensor_op_test.cc @@ -6,6 +6,7 @@ #include "test/common/tensor_op_test_utils.h" #include "test/providers/provider_test_utils.h" #include 
"test/util/include/default_providers.h" +#include "test/common/cuda_op_test_utils.h" using namespace ONNX_NAMESPACE; using namespace onnxruntime::test; @@ -14,6 +15,76 @@ namespace test { using ExpectResult = OpTester::ExpectResult; +TEST(UnfoldTensorOpTest, LastDim) { + if (NeedSkipIfCudaArchLowerThan(530)) { + return; + } + + std::vector X = { + 1.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 8.0f, + 6.0f, 7.0f, 8.0f, 9.0f}; + + std::vector output = { + 1.0f, 2.0f, 3.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 6.0f, 7.0f, 8.0f, + 6.0f, 7.0f, 8.0f, 7.0f, 8.0f, 9.0f}; + + OpTester tester("UnfoldTensor", 1, onnxruntime::kMSDomain); + + tester.AddAttribute("size", 3LL); + tester.AddInput("input", {3, 4}, X); + tester.AddOutput("output", {3, 2, 3}, output); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(UnfoldTensorOpTest, NormalDim) { + if (NeedSkipIfCudaArchLowerThan(530)) { + return; + } + + std::vector X = { + 1, 2, 3, 4, 2, 2, 3, 4, 3, 2, 3, 4, + 4, 6, 7, 8, 5, 6, 7, 8, 6, 6, 7, 8, + 6, 7, 8, 9, 7, 7, 8, 9, 8, 7, 8, 9, + 9, 7, 8, 9, 10, 7, 8, 9, 11, 7, 8, 9}; + + std::vector output = { + 1, 2, 3, + 2, 2, 2, + 3, 3, 3, + 4, 4, 4, + + 3, 4, 5, + 2, 6, 6, + 3, 7, 7, + 4, 8, 8, + + 6, 7, 8, + 7, 7, 7, + 8, 8, 8, + 9, 9, 9, + + 8, 9, 10, + 7, 7, 7, + 8, 8, 8, + 9, 9, 9}; + + OpTester tester("UnfoldTensor", 1, onnxruntime::kMSDomain); + tester.AddAttribute("dim", 1LL); + tester.AddAttribute("size", 3LL); + tester.AddAttribute("step", 2LL); + tester.AddInput("input", {2, 6, 4}, X); + tester.AddOutput("output", {2, 2, 4, 3}, output); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + TEST(CropContribOpTest, CropBorderOnly) { constexpr int N = 2, C = 1, H = 3, W = 4; std::vector X = {1.0f, 2.0f, 3.0f, 4.0f, diff --git a/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py b/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py index 66200af06f511..77ce09d7e793b 100644 --- a/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py +++ b/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py @@ -10,7 +10,7 @@ import pytest import torch -from onnxruntime import InferenceSession, SessionOptions +from onnxruntime import InferenceSession, SessionOptions, get_available_providers class TestTimestampProcessor(unittest.TestCase): @@ -52,12 +52,13 @@ def run_timestamp(self, provider: str): ort_transcription = processor.batch_decode( ort_out_tensor[0][0].view(1, -1), skip_special_tokens=True, output_offsets=True ) + print(ort_transcription) expected_transcription = [ { - "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", + "text": "<|0.00|> Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|5.44|>", "offsets": [ { - "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", + "text": "<|0.00|> Mr. 
Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|5.44|>", "timestamp": (0.0, 5.44), } ], @@ -70,6 +71,12 @@ def test_timestamp_cpu(self): provider = "CPUExecutionProvider" self.run_timestamp(provider) + @pytest.mark.slow + def test_timestamp_cuda(self): + cuda_provider = "CUDAExecutionProvider" + if cuda_provider in get_available_providers(): + self.run_timestamp(cuda_provider) + if __name__ == "__main__": unittest.main()
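
A short reference for the DynamicTimeWarping contract defined above (a sketch, not the CUDA kernel's code): from (r, c) the allowed moves are (r+1, c), (r+1, c+1) and (r, c+1), and the op returns the int32 path of shape [2, x] with the lowest total cost from (0, 0) to (M-1, N-1). The NumPy function below (the name dtw_path is assumed, not from the patch) implements exactly that contract and can be used to cross-check the kernel on small cost matrices such as the 6x10 input in dynamic_time_warping_op_test.cc.

    import numpy as np

    def dtw_path(cost: np.ndarray) -> np.ndarray:
        # cost: (M, N) float matrix, same meaning as the op's "input".
        M, N = cost.shape
        acc = np.full((M + 1, N + 1), np.inf)
        acc[0, 0] = 0.0
        trace = np.zeros((M + 1, N + 1), dtype=np.int32)
        for r in range(1, M + 1):
            for c in range(1, N + 1):
                moves = (acc[r - 1, c - 1], acc[r - 1, c], acc[r, c - 1])  # diagonal, from above, from the left
                best = int(np.argmin(moves))
                acc[r, c] = cost[r - 1, c - 1] + moves[best]
                trace[r, c] = best
        # Backtrace from (M-1, N-1) to (0, 0); path length x satisfies max(M, N) <= x < M + N.
        r, c, rows, cols = M, N, [], []
        while r > 0 and c > 0:
            rows.append(r - 1)
            cols.append(c - 1)
            if trace[r, c] == 0:
                r, c = r - 1, c - 1
            elif trace[r, c] == 1:
                r -= 1
            else:
                c -= 1
        return np.array([rows[::-1], cols[::-1]], dtype=np.int32)  # shape [2, x]

In the whisper timestamp flow the cost is typically the negated (optionally normalized) cross_qk attention, so the minimum-cost path aligns decoded tokens (rows) with audio frames (columns).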
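
A sketch of how the pieces added above are meant to be consumed at inference time. The model path, the exact set of graph inputs and the shapes shown here are assumptions that depend on the export flags (for example --output_cross_qk and --output_no_speech_probs); use_ep_level_unified_stream, cross_qk_layer_head, cross_qk and no_speech_probs are the option/input/output names introduced by this patch.

    import numpy as np
    import onnxruntime as ort

    # Hypothetical output of convert_to_onnx.py --chain_model --use_gpu --output_cross_qk ...
    model_path = "whisper_beamsearch.onnx"

    providers = [
        # New CUDA EP option from this patch: use a single EP-level stream, which is the
        # intended mode when several sessions of this model run in parallel.
        ("CUDAExecutionProvider", {"use_ep_level_unified_stream": "1"}),
        "CPUExecutionProvider",
    ]
    session = ort.InferenceSession(model_path, ort.SessionOptions(), providers=providers)

    feeds = {
        "input_features": np.zeros((1, 80, 3000), dtype=np.float32),  # (batch_size, feature_size, sequence_length)
        "max_length": np.array([128], dtype=np.int32),
        "min_length": np.array([0], dtype=np.int32),
        "num_beams": np.array([5], dtype=np.int32),
        "num_return_sequences": np.array([1], dtype=np.int32),
        "length_penalty": np.array([1.0], dtype=np.float32),
        "repetition_penalty": np.array([1.0], dtype=np.float32),
        # Keep only these (layer, head) pairs in cross_qk; shape is (num_layer_head, 2).
        "cross_qk_layer_head": np.array([[0, 0], [1, 1]], dtype=np.int32),
    }
    results = dict(zip([o.name for o in session.get_outputs()], session.run(None, feeds)))
    # results["sequences"]:       (batch_size, num_return_sequences, max_length)
    # results["cross_qk"]:        (batch_size, num_return_sequences, num_layer_head, max_length, frames)
    # results["no_speech_probs"]: (batch_size,), only present when exported with --output_no_speech_probs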