Skip to content

Commit

Permalink
fix build error on pure cpu.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghuanrong committed Sep 21, 2023
1 parent 67c43a7 commit 373327f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 11 deletions.
8 changes: 2 additions & 6 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -321,9 +321,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer<MLFloat16>,
create_beam_scorer_func_,
update_decoder_cross_qk_func_ ? update_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::UpdateDecoderCrossQK,
finalize_decoder_cross_qk_func_ ? finalize_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::FinalizeDecoderCrossQK,
cuda_device_prop_,
cuda_device_arch_};
finalize_decoder_cross_qk_func_ ? finalize_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::FinalizeDecoderCrossQK};

#ifdef USE_CUDA
ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_));
Expand All @@ -347,9 +345,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
expand_buffer_float16_func_,
create_beam_scorer_func_,
update_decoder_cross_qk_func_ ? update_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::UpdateDecoderCrossQK,
finalize_decoder_cross_qk_func_ ? finalize_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::FinalizeDecoderCrossQK,
cuda_device_prop_,
cuda_device_arch_};
finalize_decoder_cross_qk_func_ ? finalize_decoder_cross_qk_func_ : GenerationCpuDeviceHelper::FinalizeDecoderCrossQK};

#ifdef USE_CUDA
ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ class BeamSearchWhisper : public BeamSearchBase<T> {
const GenerationDeviceHelper::ExpandBufferFunc<MLFloat16>& expand_buffer_float16_func,
const GenerationDeviceHelper::CreateBeamScorer& create_beam_scorer_func,
const GenerationDeviceHelper::UpdateDecoderCrossQKFunc& update_decoder_cross_qk_func,
const GenerationDeviceHelper::FinalizeDecoderCrossQKFunc& finalize_decoder_cross_qk_func,
const void* cuda_device_prop,
int cuda_device_arch)
const GenerationDeviceHelper::FinalizeDecoderCrossQKFunc& finalize_decoder_cross_qk_func)
: BeamSearchBase<T>(context, decoder_session_state, thread_pool,
ort_stream, cuda_dumper, params,
topk_func, process_logits_func, device_copy_func, device_copy_int32_func),
Expand All @@ -56,8 +54,8 @@ class BeamSearchWhisper : public BeamSearchBase<T> {
create_beam_scorer_func_(create_beam_scorer_func),
update_decoder_cross_qk_func_(update_decoder_cross_qk_func),
finalize_decoder_cross_qk_func_(finalize_decoder_cross_qk_func),
cuda_device_prop_(cuda_device_prop),
cuda_device_arch_(cuda_device_arch) {}
cuda_device_prop_(nullptr),
cuda_device_arch_(0) {}

#ifdef USE_CUDA
Status InitializeCuda(
Expand Down

0 comments on commit 373327f

Please sign in to comment.