Support output cross qk in masked decoder multihead attention kernel

Support cross QK output in beam search for the Whisper model and related features.
Make the Whisper export tools support cross QK and related features:
    * extra_decoding_ids
    * no_speech_prob
Implement the DTW and UnfoldTensor kernels, with unit tests and shape inference
(a minimal sketch of the DTW algorithm follows below).
Several fixes for running multiple sessions in parallel, including:
    * guard multihead_attention's fused_fp16_runner_
    * make some memory allocations stream-aware
    * add the use_ep_level_unified_stream option
Run the timestamp logits processor on GPU now that beam scoring has moved to GPU,
and add a GPU test for the Whisper timestamp logits processor.
zhanghuanrong committed Sep 20, 2023
1 parent d522cc7 commit 67c43a7
Showing 59 changed files with 2,054 additions and 279 deletions.
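
Background on the DTW kernel introduced by this commit: Whisper derives word-level timestamps by running dynamic time warping over the cross QK attention weights, aligning text-token positions against audio frames. The following is a minimal CPU sketch of that algorithm under a stated assumption (a cost matrix already derived from cross QK); it illustrates the technique, not the kernel added here.

// Minimal DTW sketch: given a (text_len x frame_len) cost matrix derived from
// cross QK attention, find the monotonic alignment path of minimal total cost.
#include <algorithm>
#include <limits>
#include <utility>
#include <vector>

std::vector<std::pair<int, int>> DtwPath(const std::vector<std::vector<float>>& cost) {
  const int n = static_cast<int>(cost.size());     // text positions
  const int m = static_cast<int>(cost[0].size());  // audio frames
  const float kInf = std::numeric_limits<float>::infinity();

  // acc[i][j]: minimal accumulated cost to reach cell (i-1, j-1); the
  // infinite border removes special cases from the recurrence.
  std::vector<std::vector<float>> acc(n + 1, std::vector<float>(m + 1, kInf));
  acc[0][0] = 0.0f;
  for (int i = 1; i <= n; ++i) {
    for (int j = 1; j <= m; ++j) {
      acc[i][j] = cost[i - 1][j - 1] +
                  std::min({acc[i - 1][j - 1], acc[i - 1][j], acc[i][j - 1]});
    }
  }

  // Backtrace from the bottom-right corner, preferring the diagonal move.
  std::vector<std::pair<int, int>> path;
  for (int i = n, j = m; i > 0 && j > 0;) {
    path.emplace_back(i - 1, j - 1);
    const float d = acc[i - 1][j - 1], u = acc[i - 1][j], l = acc[i][j - 1];
    if (d <= u && d <= l) { --i; --j; }
    else if (u <= l)      { --i; }
    else                  { --j; }
  }
  std::reverse(path.begin(), path.end());
  return path;
}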
include/onnxruntime/core/providers/cuda/cuda_provider_options.h
@@ -32,4 +32,5 @@ struct OrtCUDAProviderOptionsV2 {
int tunable_op_max_tuning_duration_ms = 0; // Max tuning duration time limit for TunableOp.
int enable_skip_layer_norm_strict_mode = 0; // flag specifying if SkipLayerNorm is in strict mode. If true, use LayerNormalization kernel.
// The strict mode has better accuracy but lower performance.
int use_ep_level_unified_stream = 0; // flag specifying if ep level stream is used or not

};
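
For application code, the new flag can presumably be enabled through the V2 provider-options key/value API. This sketch assumes the string key matches the field name above ("use_ep_level_unified_stream"); verify it against the CUDA EP's option parser.

// Sketch: enable the EP-level unified stream on the CUDA execution provider.
#include <onnxruntime_cxx_api.h>

void AppendCudaWithUnifiedStream(Ort::SessionOptions& session_options) {
  const OrtApi& api = Ort::GetApi();

  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
  Ort::ThrowOnError(api.CreateCUDAProviderOptions(&cuda_options));

  // Assumed key name; V2 options are configured as string key/value pairs.
  const char* keys[] = {"use_ep_level_unified_stream"};
  const char* values[] = {"1"};
  Ort::ThrowOnError(api.UpdateCUDAProviderOptions(cuda_options, keys, values, 1));

  Ort::ThrowOnError(api.SessionOptionsAppendExecutionProvider_CUDA_V2(
      session_options, cuda_options));
  api.ReleaseCUDAProviderOptions(cuda_options);
}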
@@ -103,7 +103,8 @@ Status CheckInputs(const T* query,
}
if (past_key_dims[2] != past_value_dims[2]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' and 'past_value' shall have same dim 2 (past_sequence_length)");
"Input 'past_key' and 'past_value' shall have same dim 2 (past_sequence_length). ",
past_key_dims[2], " vs ", past_value_dims[2]);
}
if (past_key_dims[3] != head_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
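
The improved message above relies on ORT_MAKE_STATUS streaming every trailing argument into the error text, so the two mismatching dims can simply be appended. A standalone sketch of that variadic message-building idiom (MakeErrorMessage is an illustrative stand-in, not an ORT helper):

// C++17 fold expression: every argument is streamed into the message in order.
#include <sstream>
#include <string>

template <typename... Args>
std::string MakeErrorMessage(Args&&... args) {
  std::ostringstream oss;
  (oss << ... << args);
  return oss.str();
}

// MakeErrorMessage("Input 'past_key' and 'past_value' shall have same dim 2 "
//                  "(past_sequence_length). ", 31, " vs ", 17)
//   -> "... (past_sequence_length). 31 vs 17"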
14 changes: 12 additions & 2 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
@@ -319,7 +319,12 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
update_decoder_feeds_func_ ? update_decoder_feeds_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds<float>,
expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer<float>,
expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer<MLFloat16>,
create_beam_scorer_func_};
create_beam_scorer_func_,
update_decoder_cross_qk_func_ ? update_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::UpdateDecoderCrossQK,
finalize_decoder_cross_qk_func_ ? finalize_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::FinalizeDecoderCrossQK,
cuda_device_prop_,
cuda_device_arch_};

#ifdef USE_CUDA
ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_));
#endif
@@ -340,7 +345,12 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
update_decoder_feeds_fp16_func_ ? update_decoder_feeds_fp16_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds<MLFloat16>,
expand_buffer_float_func_,
expand_buffer_float16_func_,
create_beam_scorer_func_};
create_beam_scorer_func_,
update_decoder_cross_qk_func_ ? update_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::UpdateDecoderCrossQK,
finalize_decoder_cross_qk_func_ ? finalize_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::FinalizeDecoderCrossQK,
cuda_device_prop_,
cuda_device_arch_};

#ifdef USE_CUDA
ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_));
#endif
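
Both new slots follow the same convention as the surrounding arguments: use the registered device-specific helper if present, otherwise fall back to the GenerationCpuDeviceHelper implementation. A standalone sketch of that pattern (the real helper signatures carry decoder state, cross QK buffers, and a stream; the names here are illustrative):

#include <functional>
#include <iostream>

// Illustrative stand-in for GenerationDeviceHelper::UpdateDecoderCrossQKFunc.
using HelperFunc = std::function<void(int step)>;

void CpuUpdateCrossQK(int step)  { std::cout << "CPU helper, step " << step << "\n"; }
void CudaUpdateCrossQK(int step) { std::cout << "CUDA helper, step " << step << "\n"; }

int main() {
  HelperFunc registered;  // empty unless an EP installed an override

  // The ternary used in BeamSearch::Compute: override if set, CPU default otherwise.
  auto pick = [&]() { return registered ? registered : HelperFunc(CpuUpdateCrossQK); };

  pick()(0);                       // no override registered -> CPU helper
  registered = CudaUpdateCrossQK;  // a CUDA build registers its helper
  pick()(1);                       // -> CUDA helper
  return 0;
}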
10 changes: 9 additions & 1 deletion onnxruntime/contrib_ops/cpu/transformers/beam_search.h
@@ -88,12 +88,16 @@ class BeamSearch : public IControlFlowKernel {
const GenerationDeviceHelper::UpdateDecoderFeedsFunc<MLFloat16>& update_decoder_feeds_fp16_func,
const GenerationDeviceHelper::ExpandBufferFunc<int32_t>& expand_buffer_int32_func,
const GenerationDeviceHelper::ExpandBufferFunc<float>& expand_buffer_float_func,
const GenerationDeviceHelper::ExpandBufferFunc<MLFloat16>& expand_buffer_float16_func) {
const GenerationDeviceHelper::ExpandBufferFunc<MLFloat16>& expand_buffer_float16_func,
const GenerationDeviceHelper::UpdateDecoderCrossQKFunc& update_decoder_cross_qk_func,
const GenerationDeviceHelper::FinalizeDecoderCrossQKFunc& finalize_decoder_cross_qk_func) {
update_decoder_feeds_func_ = update_decoder_feeds_func;
update_decoder_feeds_fp16_func_ = update_decoder_feeds_fp16_func;
expand_buffer_int32_func_ = expand_buffer_int32_func;
expand_buffer_float_func_ = expand_buffer_float_func;
expand_buffer_float16_func_ = expand_buffer_float16_func;
update_decoder_cross_qk_func_ = update_decoder_cross_qk_func;
finalize_decoder_cross_qk_func_ = finalize_decoder_cross_qk_func;
}

#ifdef USE_CUDA
@@ -175,6 +179,10 @@ class BeamSearch : public IControlFlowKernel {
BeamSearchParameters parameters_;

bool has_init_decoder_ = false;

GenerationDeviceHelper::UpdateDecoderCrossQKFunc update_decoder_cross_qk_func_;

GenerationDeviceHelper::FinalizeDecoderCrossQKFunc finalize_decoder_cross_qk_func_;
};

} // namespace transformers
73 changes: 39 additions & 34 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h
@@ -17,34 +17,35 @@ struct BeamSearchState : IBeamSearchState<T> {
BeamSearchState(const IGenerationParameters& parameters,
AllocatorPtr allocator,
int has_decoder_masked_attention,
bool use_position) {
bool use_position,
Stream* stream) {
size_t batch_beam_size = SafeInt<size_t>(parameters.batch_size) * parameters.num_beams;

size_t next_token_size = SafeInt<size_t>(batch_beam_size) * parameters.vocab_size;
this->next_token_logits = AllocateBuffer<T>(allocator, next_token_logits_buffer_, next_token_size);
this->next_token_scores = AllocateBuffer<float>(allocator, next_token_scores_buffer_, next_token_size);
this->next_tokens = AllocateBuffer<int32_t>(allocator, next_tokens_buffer_, SafeInt<size_t>(2) * batch_beam_size);
this->next_indices = AllocateBuffer<int32_t>(allocator, next_indices_buffer_, SafeInt<size_t>(2) * batch_beam_size);
this->next_scores = AllocateBuffer<float>(allocator, next_scores_buffer_, SafeInt<size_t>(2) * batch_beam_size);
this->next_token_logits = AllocateBuffer<T>(allocator, next_token_logits_buffer_, next_token_size, stream);
this->next_token_scores = AllocateBuffer<float>(allocator, next_token_scores_buffer_, next_token_size, stream);
this->next_tokens = AllocateBuffer<int32_t>(allocator, next_tokens_buffer_, SafeInt<size_t>(2) * batch_beam_size, stream);
this->next_indices = AllocateBuffer<int32_t>(allocator, next_indices_buffer_, SafeInt<size_t>(2) * batch_beam_size, stream);
this->next_scores = AllocateBuffer<float>(allocator, next_scores_buffer_, SafeInt<size_t>(2) * batch_beam_size, stream);

constexpr size_t max_parts_of_vocab = 128;
size_t topk_buffer_size = SafeInt<size_t>(batch_beam_size) * (max_parts_of_vocab + 1) * parameters.num_beams * 2 * 2;
this->topk_buffer = AllocateBuffer<float>(allocator, topk_temp_buffer_, topk_buffer_size);
this->topk_buffer = AllocateBuffer<float>(allocator, topk_temp_buffer_, topk_buffer_size, stream);

if (allocator->Info().device.Type() == OrtDevice::GPU) {
size_t sequences_elements = SafeInt<size_t>(2) * batch_beam_size * parameters.max_length;
this->sequences_device = AllocateBuffer<int32_t>(allocator, sequences_device_buffer_, sequences_elements);
this->sequences_device = AllocateBuffer<int32_t>(allocator, sequences_device_buffer_, sequences_elements, stream);
}

if (use_position) {
this->next_positions = AllocateBuffer<int32_t>(allocator, next_positions_buffer_, batch_beam_size);
this->next_positions = AllocateBuffer<int32_t>(allocator, next_positions_buffer_, batch_beam_size, stream);
}

this->beam_scores = AllocateBuffer<float>(allocator, beam_scores_buffer_, batch_beam_size);
this->beam_scores = AllocateBuffer<float>(allocator, beam_scores_buffer_, batch_beam_size, stream);

if (parameters.output_scores) {
size_t elements = SafeInt<size_t>(parameters.max_length - parameters.sequence_length) * parameters.batch_size * parameters.num_beams * parameters.vocab_size;
this->scores = AllocateBuffer<float>(allocator, scores_buffer_, elements);
this->scores = AllocateBuffer<float>(allocator, scores_buffer_, elements, stream);
this->remaining_scores = this->scores;
}

@@ -68,35 +69,38 @@
}

private:
BufferUniquePtr next_token_logits_buffer_;
BufferUniquePtr next_token_scores_buffer_;
BufferUniquePtr next_tokens_buffer_;
BufferUniquePtr next_indices_buffer_;
BufferUniquePtr next_scores_buffer_;
BufferUniquePtr next_positions_buffer_;
BufferUniquePtr beam_scores_buffer_;
BufferUniquePtr scores_buffer_;
BufferUniquePtr topk_temp_buffer_;
BufferUniquePtr sequences_device_buffer_;
IAllocatorUniquePtr<void> next_token_logits_buffer_;
IAllocatorUniquePtr<void> next_token_scores_buffer_;
IAllocatorUniquePtr<void> next_tokens_buffer_;
IAllocatorUniquePtr<void> next_indices_buffer_;
IAllocatorUniquePtr<void> next_scores_buffer_;
IAllocatorUniquePtr<void> next_positions_buffer_;
IAllocatorUniquePtr<void> beam_scores_buffer_;
IAllocatorUniquePtr<void> scores_buffer_;
IAllocatorUniquePtr<void> topk_temp_buffer_;
IAllocatorUniquePtr<void> sequences_device_buffer_;
};
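
The stream parameter threaded through AllocateBuffer above is what makes parallel sessions safe: a freed device block can still be referenced by kernels queued on its stream, so the pool must not hand it out again until that stream has caught up. A standalone CUDA-runtime sketch of the idea (an illustration of the technique only; ORT's arena implements a more general version):

// Sketch: a free list whose blocks are tagged with an event recorded on the
// freeing stream; a block is reused only once that event has completed.
#include <cuda_runtime.h>
#include <cstddef>
#include <vector>

struct PooledBlock {
  void* ptr;
  size_t bytes;
  cudaEvent_t ready;  // completes when the freeing stream has drained
};

class StreamAwarePool {
 public:
  void* Alloc(size_t bytes, cudaStream_t /*stream*/) {
    for (auto it = free_list_.begin(); it != free_list_.end(); ++it) {
      // Safe to reuse only if all work queued before the Free() has finished.
      if (it->bytes >= bytes && cudaEventQuery(it->ready) == cudaSuccess) {
        void* p = it->ptr;
        cudaEventDestroy(it->ready);
        free_list_.erase(it);
        return p;
      }
    }
    void* p = nullptr;
    cudaMalloc(&p, bytes);  // no reusable block: allocate fresh memory
    return p;
  }

  void Free(void* ptr, size_t bytes, cudaStream_t stream) {
    cudaEvent_t ev;
    cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
    cudaEventRecord(ev, stream);  // marks the point after all queued work
    free_list_.push_back({ptr, bytes, ev});
  }

 private:
  std::vector<PooledBlock> free_list_;
};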

struct BeamSearchCpuState : IBeamSearchCpuState {
Sequences sequences;

BeamSearchCpuState(const IGenerationParameters& parameters, AllocatorPtr allocator, bool is_cuda)
BeamSearchCpuState(const IGenerationParameters& parameters, AllocatorPtr allocator, bool is_cuda, Stream* stream)
: parameters_{parameters} {
sequence_lengths = AllocateBuffer<int32_t>(allocator, sequence_lengths_buffer_, batch_beam_size_);
sequence_lengths = AllocateBuffer<int32_t>(allocator, sequence_lengths_buffer_, batch_beam_size_, stream);

size_t sequences_bytes = SafeInt<size_t>(2) * batch_beam_size_ * parameters.max_length;
sequences_space = AllocateBuffer<int32_t>(allocator, sequences_space_buffer_, sequences_bytes, true /* fill */);
sequences_space = AllocateBuffer<int32_t>(allocator, sequences_space_buffer_, sequences_bytes, stream, true /* fill */);
sequences.Init(sequences_space, batch_beam_size_, parameters.sequence_length, parameters.max_length);

if (is_cuda) {
// buffers used by CUDA operator but not by CPU operator.
topk_scores = AllocateBuffer<float>(allocator, topk_scores_buffer_, 2 * static_cast<size_t>(batch_beam_size_));
topk_tokens = AllocateBuffer<int32_t>(allocator, topk_tokens_buffer_, 2 * static_cast<size_t>(batch_beam_size_));
topk_indices = AllocateBuffer<int32_t>(allocator, topk_indices_buffer_, 2 * static_cast<size_t>(batch_beam_size_));
final_beam_scores = AllocateBuffer<float>(allocator, final_beam_scores_buffer_, batch_beam_size_);
topk_scores = AllocateBuffer<float>(allocator, topk_scores_buffer_, 2 * static_cast<size_t>(batch_beam_size_), stream);
topk_tokens = AllocateBuffer<int32_t>(allocator, topk_tokens_buffer_, 2 * static_cast<size_t>(batch_beam_size_), stream);
topk_indices = AllocateBuffer<int32_t>(allocator, topk_indices_buffer_, 2 * static_cast<size_t>(batch_beam_size_), stream);
final_beam_scores = AllocateBuffer<float>(allocator, final_beam_scores_buffer_, batch_beam_size_, stream);

size_t next_token_size = SafeInt<size_t>(batch_beam_size_) * parameters.vocab_size;
next_token_scores = AllocateBuffer<float>(allocator, next_token_scores_buffer_, next_token_size, stream);
}
}

@@ -124,12 +128,13 @@ struct BeamSearchCpuState : IBeamSearchCpuState {
const IGenerationParameters& parameters_;
const int batch_beam_size_{parameters_.batch_size * parameters_.num_beams};

BufferUniquePtr final_beam_scores_buffer_;
BufferUniquePtr sequence_lengths_buffer_;
BufferUniquePtr topk_scores_buffer_;
BufferUniquePtr topk_tokens_buffer_;
BufferUniquePtr topk_indices_buffer_;
BufferUniquePtr sequences_space_buffer_;
IAllocatorUniquePtr<void> final_beam_scores_buffer_;
IAllocatorUniquePtr<void> sequence_lengths_buffer_;
IAllocatorUniquePtr<void> topk_scores_buffer_;
IAllocatorUniquePtr<void> topk_tokens_buffer_;
IAllocatorUniquePtr<void> topk_indices_buffer_;
IAllocatorUniquePtr<void> sequences_space_buffer_;
IAllocatorUniquePtr<void> next_token_scores_buffer_;
};

// Base class of beam search implementation that is common for GPT-2, T5, and Whisper.
onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h
@@ -215,7 +215,8 @@ Status BeamSearchGpt<T>::Execute(const FeedsFetchesManager* init_run_feeds_fetch

BeamSearchCpuState cpu_state{*parameters,
this->cpu_allocator_,
this->IsCuda()};
this->IsCuda(),
this->ort_stream_};

// buffer in GPU for input_ids, position_ids and attention_mask
IAllocatorUniquePtr<char> buffer;
@@ -240,7 +241,8 @@
BeamSearchState<T> beam_state{*parameters,
this->temp_space_allocator_,
gpt_subgraph_.has_decoder_masked_attention_,
true /* use_position */};
true /* use_position */,
this->ort_stream_};

init_beam_state_func_(&beam_state,
cpu_state.sequence_lengths,
onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
@@ -144,7 +144,8 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches

BeamSearchCpuState cpu_state{*parameters,
this->cpu_allocator_,
this->IsCuda()};
this->IsCuda(),
this->ort_stream_};

IAllocatorUniquePtr<char> buffer;

@@ -195,7 +196,8 @@
BeamSearchState<T> beam_state{*parameters,
this->temp_space_allocator_,
decoder_subgraph_.has_decoder_masked_attention_,
false /* use_position */};
false /* use_position */,
this->ort_stream_};

init_beam_state_func_(&beam_state,
cpu_state.sequence_lengths,
(Remaining changed files not shown.)