Skip to content

Commit

Permalink
Changes according to PR
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghuanrong committed Oct 4, 2023
1 parent 373327f commit ccc23de
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ Status LaunchDynamicTimeWarping(
ORT_RETURN_IF_ERROR(CUDA_CALL(cudaGetLastError()));

ORT_RETURN_IF_ERROR(CUDA_CALL(cudaMemcpyAsync(&result_len, result_len_device_buf, sizeof(size_t), cudaMemcpyDeviceToHost, stream)));
return CUDA_CALL(cudaGetLastError());
ORT_RETURN_IF_ERROR(CUDA_CALL(cudaGetLastError()));
return CUDA_CALL(cudaStreamSynchronize(stream));
}

} // namespace cuda
Expand Down
12 changes: 6 additions & 6 deletions onnxruntime/contrib_ops/cuda/tensor/unfold_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace contrib {
namespace cuda {

template <typename T>
__global__ void UnfoldTensorKenel(
__global__ void UnfoldTensorKernel(
const T* input,
T* output,
int64_t N,
Expand Down Expand Up @@ -71,27 +71,27 @@ Status LaunchUnfoldTensor(
dim3 grid((unsigned)SafeInt<unsigned>(num_blocks));
switch (element_size) {
case 1:
UnfoldTensorKenel<int8_t><<<grid, block, 0, stream>>>(
UnfoldTensorKernel<int8_t><<<grid, block, 0, stream>>>(
(const int8_t*)input, (int8_t*)output, N, unfold_size,
tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src);
break;
case 2:
UnfoldTensorKenel<int16_t><<<grid, block, 0, stream>>>(
UnfoldTensorKernel<int16_t><<<grid, block, 0, stream>>>(
(const int16_t*)input, (int16_t*)output, N, unfold_size,
tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src);
break;
case 4:
UnfoldTensorKenel<int32_t><<<grid, block, 0, stream>>>(
UnfoldTensorKernel<int32_t><<<grid, block, 0, stream>>>(
(const int32_t*)input, (int32_t*)output, N, unfold_size,
tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src);
break;
case 8:
UnfoldTensorKenel<int64_t><<<grid, block, 0, stream>>>(
UnfoldTensorKernel<int64_t><<<grid, block, 0, stream>>>(
(const int64_t*)input, (int64_t*)output, N, unfold_size,
tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src);
break;
case 16:
UnfoldTensorKenel<float4><<<grid, block, 0, stream>>>(
UnfoldTensorKernel<float4><<<grid, block, 0, stream>>>(
(const float4*)input, (float4*)output, N, unfold_size,
tailing_dims_size, stride_leading_dst, stride_fold_dim_src, stride_leading_src);
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def parse_arguments(argv=None):
"--no_speech_token_id",
default=50362,
type=int,
help="specify no_speech_token_id. Default is 1000. if >= 0, will be add into beam search attr",
help="specify no_speech_token_id. Default is 50362. if >= 0, will be add into beam search attr",
)

parser.add_argument(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from onnx import TensorProto, helper
from transformers import WhisperConfig

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from benchmark_helper import Precision # noqa: E402
from convert_generation import ( # noqa: E402
get_shared_initializers,
Expand Down
15 changes: 15 additions & 0 deletions onnxruntime/test/python/transformers/test_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,21 @@ def test_logits_processor(self):
logits_processor = ["--use_logits_processor"]
self.run_configs(logits_processor)

@pytest.mark.slow
def test_cross_qk_overall(self):
    """Run the generation configs with every cross-QK-related export option enabled at once.

    Exercises chained-model export with cross-attention QK collection/output,
    forced decoder ids, extra decoding ids, no-speech probability output, and
    all vocab-mask / logits-processor switches in a single combined config.
    """
    # Same flag set as before, grouped by feature for readability.
    flags = []
    flags += ["--chain_model", "--collect_cross_qk", "--output_cross_qk"]
    flags += ["--use_forced_decoder_ids", "--extra_decoding_ids"]
    flags += ["--output_no_speech_probs"]
    flags += ["--use_vocab_mask", "--use_prefix_vocab_mask", "--use_logits_processor"]
    self.run_configs(flags)


if __name__ == "__main__":
unittest.main()

0 comments on commit ccc23de

Please sign in to comment.