Skip to content

Commit

Permalink
Whisper Crash Fix (#19345)
Browse files Browse the repository at this point in the history
### Description
There is a current bug in the BeamSearch implementation of T5, GPT, and
Whisper due to an interaction between two PRs merged in the past 7
months.

First PR/code change is the addition of BeamSearchScorer GPU
implementation. This PR accelerates some operations by executing them in
the GPU and not the CPU. The approach for this code change didn't
utilize a cudaStream when copying one particular variable from GPU to
CPU (see nullptr value here:
[[link](https://github.com/microsoft/onnxruntime/blob/b65d3d0a5374daa3bc9272c2c02763a8428660db/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h#L213)]).

The second PR/code change was the alteration to utilize a cudaStream to
initialize various memory buffers in BeamSearch (see `stream` included
as the last argument in these allocations
[[link](https://github.com/microsoft/onnxruntime/blob/d1431e1b78fb81bf90fdc58c9118cb011171f387/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h#L25)]).

During the in-between period of these two PRs, I believe neither
allocation utilized a stream and were thus synchronized. Once the latter
PR was merged, the copy became desynchronized with the initialization
due to different streams.

The fix for this is to reintroduce the same stream into the copy
operation added in the first PR.



### Motivation and Context
This does not happen reliably on every hardware with every script due to
the race condition nature, but the bug completely breaks ORT execution
with a BeamSearch model.

---------

Co-authored-by: Peter McAughan <[email protected]>
  • Loading branch information
petermcaughan and Peter McAughan authored Jan 31, 2024
1 parent dd1f6cc commit 4562c91
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ Status BeamSearchGpt<T>::Execute(const FeedsFetchesManager* init_run_feeds_fetch
cpu_state.sequences.InitDevice(beam_state.sequences_device);
ORT_RETURN_IF_ERROR(this->device_copy_int32_func_(beam_state.sequences_device.subspan(0, beam_state.sequences_device.size() / 2),
cpu_state.sequences_space.subspan(0, cpu_state.sequences_space.size() / 2),
nullptr,
this->ort_stream_,
DeviceCopyDirection::hostToDevice));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
cpu_state.sequences.InitDevice(beam_state.sequences_device);
ORT_RETURN_IF_ERROR(this->device_copy_int32_func_(beam_state.sequences_device.subspan(0, beam_state.sequences_device.size() / 2),
cpu_state.sequences_space.subspan(0, cpu_state.sequences_space.size() / 2),
nullptr,
this->ort_stream_,
DeviceCopyDirection::hostToDevice));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ Status BeamSearchWhisper<T>::Execute(const FeedsFetchesManager& encoder_feeds_fe
cpu_state.sequences.InitDevice(beam_state.sequences_device);
ORT_RETURN_IF_ERROR(this->device_copy_int32_func_(beam_state.sequences_device.subspan(0, beam_state.sequences_device.size() / 2),
cpu_state.sequences_space.subspan(0, cpu_state.sequences_space.size() / 2),
nullptr,
this->ort_stream_,
DeviceCopyDirection::hostToDevice));
}

Expand Down

0 comments on commit 4562c91

Please sign in to comment.