Skip to content

Commit

Permalink
Support outputting cross QK in the masked decoder multihead attention kernel
Browse files Browse the repository at this point in the history
Support cross qk in beam search for whisper model and related features
Make whisper exporting tools support cross qk and some related features,
    * extra_decoding_ids
    * no_speech_prob
Implement DTW kernel, UnfoldTensor kernel with unit test and shape inference
Several fixes related to multiple sessions running in parallel, such as:
    * guard multihead_attention, fused_fp16_runner_
    * some memory allocation with stream awareness
    * add use_ep_level_unified_stream option
Make the timestamp Logits Processor run on GPU now that beam scoring has moved to GPU,
add GPU test for whisper timestamp Logits Processor
  • Loading branch information
zhanghuanrong committed Sep 19, 2023
1 parent d522cc7 commit cb25aa7
Show file tree
Hide file tree
Showing 59 changed files with 2,007 additions and 275 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,5 @@ struct OrtCUDAProviderOptionsV2 {
int tunable_op_max_tuning_duration_ms = 0; // Max tuning duration time limit for TunableOp.
int enable_skip_layer_norm_strict_mode = 0; // flag specifying if SkipLayerNorm is in strict mode. If true, use LayerNormalization kernel.
// The strict mode has better accuracy but lower performance.
int use_ep_level_unified_stream = 0; // flag specifying if ep level stream is used or not
};
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ Status CheckInputs(const T* query,
}
if (past_key_dims[2] != past_value_dims[2]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' and 'past_value' shall have same dim 2 (past_sequence_length)");
"Input 'past_key' and 'past_value' shall have same dim 2 (past_sequence_length). ",
past_key_dims[2], " vs ", past_value_dims[2]);
}
if (past_key_dims[3] != head_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
Expand Down
14 changes: 12 additions & 2 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,12 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
update_decoder_feeds_func_ ? update_decoder_feeds_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds<float>,
expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer<float>,
expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer<MLFloat16>,
create_beam_scorer_func_};
create_beam_scorer_func_,
update_decoder_cross_qk_func_ ? update_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::UpdateDecoderCrossQK,
finalize_decoder_cross_qk_func_ ? finalize_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::FinalizeDecoderCrossQK,
cuda_device_prop_,
cuda_device_arch_};

#ifdef USE_CUDA
ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_));
#endif
Expand All @@ -340,7 +345,12 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
update_decoder_feeds_fp16_func_ ? update_decoder_feeds_fp16_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds<MLFloat16>,
expand_buffer_float_func_,
expand_buffer_float16_func_,
create_beam_scorer_func_};
create_beam_scorer_func_,
update_decoder_cross_qk_func_ ? update_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::UpdateDecoderCrossQK,
finalize_decoder_cross_qk_func_ ? finalize_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::FinalizeDecoderCrossQK,
cuda_device_prop_,
cuda_device_arch_};

#ifdef USE_CUDA
ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_));
#endif
Expand Down
10 changes: 9 additions & 1 deletion onnxruntime/contrib_ops/cpu/transformers/beam_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,16 @@ class BeamSearch : public IControlFlowKernel {
const GenerationDeviceHelper::UpdateDecoderFeedsFunc<MLFloat16>& update_decoder_feeds_fp16_func,
const GenerationDeviceHelper::ExpandBufferFunc<int32_t>& expand_buffer_int32_func,
const GenerationDeviceHelper::ExpandBufferFunc<float>& expand_buffer_float_func,
const GenerationDeviceHelper::ExpandBufferFunc<MLFloat16>& expand_buffer_float16_func) {
const GenerationDeviceHelper::ExpandBufferFunc<MLFloat16>& expand_buffer_float16_func,
const GenerationDeviceHelper::UpdateDecoderCrossQKFunc& update_decoder_cross_qk_func,
const GenerationDeviceHelper::FinalizeDecoderCrossQKFunc& finalize_decoder_cross_qk_func) {
update_decoder_feeds_func_ = update_decoder_feeds_func;
update_decoder_feeds_fp16_func_ = update_decoder_feeds_fp16_func;
expand_buffer_int32_func_ = expand_buffer_int32_func;
expand_buffer_float_func_ = expand_buffer_float_func;
expand_buffer_float16_func_ = expand_buffer_float16_func;
update_decoder_cross_qk_func_ = update_decoder_cross_qk_func;
finalize_decoder_cross_qk_func_ = finalize_decoder_cross_qk_func;
}

#ifdef USE_CUDA
Expand Down Expand Up @@ -175,6 +179,10 @@ class BeamSearch : public IControlFlowKernel {
BeamSearchParameters parameters_;

bool has_init_decoder_ = false;

GenerationDeviceHelper::UpdateDecoderCrossQKFunc update_decoder_cross_qk_func_;

GenerationDeviceHelper::FinalizeDecoderCrossQKFunc finalize_decoder_cross_qk_func_;
};

} // namespace transformers
Expand Down
73 changes: 39 additions & 34 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,34 +17,35 @@ struct BeamSearchState : IBeamSearchState<T> {
BeamSearchState(const IGenerationParameters& parameters,
AllocatorPtr allocator,
int has_decoder_masked_attention,
bool use_position) {
bool use_position,
Stream* stream) {
size_t batch_beam_size = SafeInt<size_t>(parameters.batch_size) * parameters.num_beams;

size_t next_token_size = SafeInt<size_t>(batch_beam_size) * parameters.vocab_size;
this->next_token_logits = AllocateBuffer<T>(allocator, next_token_logits_buffer_, next_token_size);
this->next_token_scores = AllocateBuffer<float>(allocator, next_token_scores_buffer_, next_token_size);
this->next_tokens = AllocateBuffer<int32_t>(allocator, next_tokens_buffer_, SafeInt<size_t>(2) * batch_beam_size);
this->next_indices = AllocateBuffer<int32_t>(allocator, next_indices_buffer_, SafeInt<size_t>(2) * batch_beam_size);
this->next_scores = AllocateBuffer<float>(allocator, next_scores_buffer_, SafeInt<size_t>(2) * batch_beam_size);
this->next_token_logits = AllocateBuffer<T>(allocator, next_token_logits_buffer_, next_token_size, stream);
this->next_token_scores = AllocateBuffer<float>(allocator, next_token_scores_buffer_, next_token_size, stream);
this->next_tokens = AllocateBuffer<int32_t>(allocator, next_tokens_buffer_, SafeInt<size_t>(2) * batch_beam_size, stream);
this->next_indices = AllocateBuffer<int32_t>(allocator, next_indices_buffer_, SafeInt<size_t>(2) * batch_beam_size, stream);
this->next_scores = AllocateBuffer<float>(allocator, next_scores_buffer_, SafeInt<size_t>(2) * batch_beam_size, stream);

constexpr size_t max_parts_of_vocab = 128;
size_t topk_buffer_size = SafeInt<size_t>(batch_beam_size) * (max_parts_of_vocab + 1) * parameters.num_beams * 2 * 2;
this->topk_buffer = AllocateBuffer<float>(allocator, topk_temp_buffer_, topk_buffer_size);
this->topk_buffer = AllocateBuffer<float>(allocator, topk_temp_buffer_, topk_buffer_size, stream);

if (allocator->Info().device.Type() == OrtDevice::GPU) {
size_t sequences_elements = SafeInt<size_t>(2) * batch_beam_size * parameters.max_length;
this->sequences_device = AllocateBuffer<int32_t>(allocator, sequences_device_buffer_, sequences_elements);
this->sequences_device = AllocateBuffer<int32_t>(allocator, sequences_device_buffer_, sequences_elements, stream);
}

if (use_position) {
this->next_positions = AllocateBuffer<int32_t>(allocator, next_positions_buffer_, batch_beam_size);
this->next_positions = AllocateBuffer<int32_t>(allocator, next_positions_buffer_, batch_beam_size, stream);
}

this->beam_scores = AllocateBuffer<float>(allocator, beam_scores_buffer_, batch_beam_size);
this->beam_scores = AllocateBuffer<float>(allocator, beam_scores_buffer_, batch_beam_size, stream);

if (parameters.output_scores) {
size_t elements = SafeInt<size_t>(parameters.max_length - parameters.sequence_length) * parameters.batch_size * parameters.num_beams * parameters.vocab_size;
this->scores = AllocateBuffer<float>(allocator, scores_buffer_, elements);
this->scores = AllocateBuffer<float>(allocator, scores_buffer_, elements, stream);
this->remaining_scores = this->scores;
}

Expand All @@ -68,35 +69,38 @@ struct BeamSearchState : IBeamSearchState<T> {
}

private:
BufferUniquePtr next_token_logits_buffer_;
BufferUniquePtr next_token_scores_buffer_;
BufferUniquePtr next_tokens_buffer_;
BufferUniquePtr next_indices_buffer_;
BufferUniquePtr next_scores_buffer_;
BufferUniquePtr next_positions_buffer_;
BufferUniquePtr beam_scores_buffer_;
BufferUniquePtr scores_buffer_;
BufferUniquePtr topk_temp_buffer_;
BufferUniquePtr sequences_device_buffer_;
IAllocatorUniquePtr<void> next_token_logits_buffer_;
IAllocatorUniquePtr<void> next_token_scores_buffer_;
IAllocatorUniquePtr<void> next_tokens_buffer_;
IAllocatorUniquePtr<void> next_indices_buffer_;
IAllocatorUniquePtr<void> next_scores_buffer_;
IAllocatorUniquePtr<void> next_positions_buffer_;
IAllocatorUniquePtr<void> beam_scores_buffer_;
IAllocatorUniquePtr<void> scores_buffer_;
IAllocatorUniquePtr<void> topk_temp_buffer_;
IAllocatorUniquePtr<void> sequences_device_buffer_;
};

struct BeamSearchCpuState : IBeamSearchCpuState {
Sequences sequences;

BeamSearchCpuState(const IGenerationParameters& parameters, AllocatorPtr allocator, bool is_cuda)
BeamSearchCpuState(const IGenerationParameters& parameters, AllocatorPtr allocator, bool is_cuda, Stream* stream)
: parameters_{parameters} {
sequence_lengths = AllocateBuffer<int32_t>(allocator, sequence_lengths_buffer_, batch_beam_size_);
sequence_lengths = AllocateBuffer<int32_t>(allocator, sequence_lengths_buffer_, batch_beam_size_, stream);

size_t sequences_bytes = SafeInt<size_t>(2) * batch_beam_size_ * parameters.max_length;
sequences_space = AllocateBuffer<int32_t>(allocator, sequences_space_buffer_, sequences_bytes, true /* fill */);
sequences_space = AllocateBuffer<int32_t>(allocator, sequences_space_buffer_, sequences_bytes, stream, true /* fill */);
sequences.Init(sequences_space, batch_beam_size_, parameters.sequence_length, parameters.max_length);

if (is_cuda) {
// buffers used by CUDA operator but not by CPU operator.
topk_scores = AllocateBuffer<float>(allocator, topk_scores_buffer_, 2 * static_cast<size_t>(batch_beam_size_));
topk_tokens = AllocateBuffer<int32_t>(allocator, topk_tokens_buffer_, 2 * static_cast<size_t>(batch_beam_size_));
topk_indices = AllocateBuffer<int32_t>(allocator, topk_indices_buffer_, 2 * static_cast<size_t>(batch_beam_size_));
final_beam_scores = AllocateBuffer<float>(allocator, final_beam_scores_buffer_, batch_beam_size_);
topk_scores = AllocateBuffer<float>(allocator, topk_scores_buffer_, 2 * static_cast<size_t>(batch_beam_size_), stream);
topk_tokens = AllocateBuffer<int32_t>(allocator, topk_tokens_buffer_, 2 * static_cast<size_t>(batch_beam_size_), stream);
topk_indices = AllocateBuffer<int32_t>(allocator, topk_indices_buffer_, 2 * static_cast<size_t>(batch_beam_size_), stream);
final_beam_scores = AllocateBuffer<float>(allocator, final_beam_scores_buffer_, batch_beam_size_, stream);

size_t next_token_size = SafeInt<size_t>(batch_beam_size_) * parameters.vocab_size;
next_token_scores = AllocateBuffer<float>(allocator, next_token_scores_buffer_, next_token_size, stream);
}
}

Expand Down Expand Up @@ -124,12 +128,13 @@ struct BeamSearchCpuState : IBeamSearchCpuState {
const IGenerationParameters& parameters_;
const int batch_beam_size_{parameters_.batch_size * parameters_.num_beams};

BufferUniquePtr final_beam_scores_buffer_;
BufferUniquePtr sequence_lengths_buffer_;
BufferUniquePtr topk_scores_buffer_;
BufferUniquePtr topk_tokens_buffer_;
BufferUniquePtr topk_indices_buffer_;
BufferUniquePtr sequences_space_buffer_;
IAllocatorUniquePtr<void> final_beam_scores_buffer_;
IAllocatorUniquePtr<void> sequence_lengths_buffer_;
IAllocatorUniquePtr<void> topk_scores_buffer_;
IAllocatorUniquePtr<void> topk_tokens_buffer_;
IAllocatorUniquePtr<void> topk_indices_buffer_;
IAllocatorUniquePtr<void> sequences_space_buffer_;
IAllocatorUniquePtr<void> next_token_scores_buffer_;
};

// Base class of beam search implementation that is common for GPT-2, T5, and Whisper.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ Status BeamSearchGpt<T>::Execute(const FeedsFetchesManager* init_run_feeds_fetch

BeamSearchCpuState cpu_state{*parameters,
this->cpu_allocator_,
this->IsCuda()};
this->IsCuda(),
this->ort_stream_};

// buffer in GPU for input_ids, position_ids and attention_mask
IAllocatorUniquePtr<char> buffer;
Expand All @@ -240,7 +241,8 @@ Status BeamSearchGpt<T>::Execute(const FeedsFetchesManager* init_run_feeds_fetch
BeamSearchState<T> beam_state{*parameters,
this->temp_space_allocator_,
gpt_subgraph_.has_decoder_masked_attention_,
true /* use_position */};
true /* use_position */,
this->ort_stream_};

init_beam_state_func_(&beam_state,
cpu_state.sequence_lengths,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches

BeamSearchCpuState cpu_state{*parameters,
this->cpu_allocator_,
this->IsCuda()};
this->IsCuda(),
this->ort_stream_};

IAllocatorUniquePtr<char> buffer;

Expand Down Expand Up @@ -195,7 +196,8 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
BeamSearchState<T> beam_state{*parameters,
this->temp_space_allocator_,
decoder_subgraph_.has_decoder_masked_attention_,
false /* use_position */};
false /* use_position */,
this->ort_stream_};

init_beam_state_func_(&beam_state,
cpu_state.sequence_lengths,
Expand Down
Loading

0 comments on commit cb25aa7

Please sign in to comment.