diff --git a/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.cu b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.cu
index 2072f001bc3a1..7c3f2963207e6 100644
--- a/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/tensor/dynamic_time_warping_impl.cu
@@ -133,7 +133,8 @@ Status LaunchDynamicTimeWarping(
 
   ORT_RETURN_IF_ERROR(CUDA_CALL(cudaGetLastError()));
   ORT_RETURN_IF_ERROR(CUDA_CALL(cudaMemcpyAsync(&result_len, result_len_device_buf, sizeof(size_t), cudaMemcpyDeviceToHost, stream)));
-  return CUDA_CALL(cudaGetLastError());
+  ORT_RETURN_IF_ERROR(CUDA_CALL(cudaGetLastError()));
+  return CUDA_CALL(cudaStreamSynchronize(stream));
 }
 
 }  // namespace cuda
diff --git a/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.cu b/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.cu
index 996f340b483a3..1d385bc96f7e5 100644
--- a/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/tensor/unfold_impl.cu
@@ -13,7 +13,7 @@ namespace contrib {
 namespace cuda {
 
 template <typename T>
-__global__ void UnfoldTensorKenel(
+__global__ void UnfoldTensorKernel(
     const T* input,
     T* output,
     int64_t N,
@@ -71,27 +71,27 @@ Status LaunchUnfoldTensor(
   dim3 grid((unsigned)SafeInt<unsigned>(num_blocks));
   switch (element_size) {
     case 1:
-      UnfoldTensorKenel<<<grid, block, 0, stream>>>(
+      UnfoldTensorKernel<<<grid, block, 0, stream>>>(
           (const int8_t*)input, (int8_t*)output, N, unfold_size,
           tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src);
       break;
     case 2:
-      UnfoldTensorKenel<<<grid, block, 0, stream>>>(
+      UnfoldTensorKernel<<<grid, block, 0, stream>>>(
           (const int16_t*)input, (int16_t*)output, N, unfold_size,
           tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src);
       break;
     case 4:
-      UnfoldTensorKenel<<<grid, block, 0, stream>>>(
+      UnfoldTensorKernel<<<grid, block, 0, stream>>>(
           (const int32_t*)input, (int32_t*)output, N, unfold_size,
           tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src);
       break;
     case 8:
-      UnfoldTensorKenel<<<grid, block, 0, stream>>>(
+      UnfoldTensorKernel<<<grid, block, 0, stream>>>(
           (const int64_t*)input,
           (int64_t*)output, N, unfold_size, tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src);
       break;
     case 16:
-      UnfoldTensorKenel<<<grid, block, 0, stream>>>(
+      UnfoldTensorKernel<<<grid, block, 0, stream>>>(
           (const float4*)input, (float4*)output, N, unfold_size,
           tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src);
       break;
diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
index c5f43ce782331..2ad248b0a22be 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
@@ -197,7 +197,7 @@ def parse_arguments(argv=None):
         "--no_speech_token_id",
         default=50362,
         type=int,
-        help="specify no_speech_token_id. Default is 1000. if >= 0, will be add into beam search attr",
+        help="specify no_speech_token_id. Default is 50362. if >= 0, will be add into beam search attr",
     )
 
     parser.add_argument(
diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
index a1ed0c7ed5ca2..c0db8fb86c571 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
@@ -6,7 +6,6 @@
 from onnx import TensorProto, helper
 from transformers import WhisperConfig
 
-sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
 from benchmark_helper import Precision  # noqa: E402
 from convert_generation import (  # noqa: E402
     get_shared_initializers,
diff --git a/onnxruntime/test/python/transformers/test_generation.py b/onnxruntime/test/python/transformers/test_generation.py
index 55c51435823c6..c9db1fbc02931 100644
--- a/onnxruntime/test/python/transformers/test_generation.py
+++ b/onnxruntime/test/python/transformers/test_generation.py
@@ -378,6 +378,21 @@ def test_logits_processor(self):
         logits_processor = ["--use_logits_processor"]
         self.run_configs(logits_processor)
 
+    @pytest.mark.slow
+    def test_cross_qk_overall(self):
+        decoder_input_ids = [
+            "--chain_model",
+            "--collect_cross_qk",
+            "--output_cross_qk",
+            "--use_forced_decoder_ids",
+            "--extra_decoding_ids",
+            "--output_no_speech_probs",
+            "--use_vocab_mask",
+            "--use_prefix_vocab_mask",
+            "--use_logits_processor",
+        ]
+        self.run_configs(decoder_input_ids)
+
 
 if __name__ == "__main__":
     unittest.main()