From 4562c910fed0d8446497388415d103be40b65157 Mon Sep 17 00:00:00 2001 From: petermcaughan Date: Tue, 30 Jan 2024 21:53:18 -0800 Subject: [PATCH] Whisper Crash Fix (#19345) ### Description There is a current bug in the BeamSearch implementation of T5, GPT, and Whisper due to an interaction between two PRs merged in the past 7 months. First PR/code change is the addition of BeamSearchScorer GPU implementation. This PR accelerates some operations by executing them in the GPU and not the CPU. The approach for this code change didn't utilize a cudaStream when copying one particular variable from GPU to CPU (see nullptr value here: [[link](https://github.com/microsoft/onnxruntime/blob/b65d3d0a5374daa3bc9272c2c02763a8428660db/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h#L213)]). The second PR/code change was the alteration to utilize a cudaStream to initialize various memory buffers in BeamSearch (see `stream` included as the last argument in these allocations [[link](https://github.com/microsoft/onnxruntime/blob/d1431e1b78fb81bf90fdc58c9118cb011171f387/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h#L25)]). During the in-between period of these two PRs, I believe neither allocation utilized a stream and were thus synchronized. Once the latter PR was merged, the copy became desynchronized with the initialization due to different streams. The fix for this is to reintroduce the same stream into the copy operation added in the first PR. ### Motivation and Context This does not happen reliably on every hardware with every script due to the race condition nature, but the bug completely breaks ORT execution with a BeamSearch model. --------- Co-authored-by: Peter McAughan --- onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h | 2 +- onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h | 2 +- .../contrib_ops/cpu/transformers/beam_search_impl_whisper.h | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h index dc72a038c3d58..b18e122980eda 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h @@ -258,7 +258,7 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch cpu_state.sequences.InitDevice(beam_state.sequences_device); ORT_RETURN_IF_ERROR(this->device_copy_int32_func_(beam_state.sequences_device.subspan(0, beam_state.sequences_device.size() / 2), cpu_state.sequences_space.subspan(0, cpu_state.sequences_space.size() / 2), - nullptr, + this->ort_stream_, DeviceCopyDirection::hostToDevice)); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h index cd891a9508019..8f5cdc97f27e5 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h @@ -214,7 +214,7 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches cpu_state.sequences.InitDevice(beam_state.sequences_device); ORT_RETURN_IF_ERROR(this->device_copy_int32_func_(beam_state.sequences_device.subspan(0, beam_state.sequences_device.size() / 2), cpu_state.sequences_space.subspan(0, cpu_state.sequences_space.size() / 2), - nullptr, + this->ort_stream_, DeviceCopyDirection::hostToDevice)); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h index 4d6643c68a98b..72e6d3930a548 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h @@ -226,7 +226,7 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe cpu_state.sequences.InitDevice(beam_state.sequences_device); ORT_RETURN_IF_ERROR(this->device_copy_int32_func_(beam_state.sequences_device.subspan(0, beam_state.sequences_device.size() / 2), cpu_state.sequences_space.subspan(0, cpu_state.sequences_space.size() / 2), - nullptr, + this->ort_stream_, DeviceCopyDirection::hostToDevice)); }