From 7db65daffa07691b6fffd63c7de0d3ab9fb7eb1e Mon Sep 17 00:00:00 2001 From: kailums Date: Thu, 27 Jun 2024 04:09:41 +0000 Subject: [PATCH] fix for review comments --- onnxruntime/core/providers/cuda/tensor/split_impl.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/cuda/tensor/split_impl.cu b/onnxruntime/core/providers/cuda/tensor/split_impl.cu index 60815741a6c5a..00f94694f83c0 100644 --- a/onnxruntime/core/providers/cuda/tensor/split_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/split_impl.cu @@ -227,7 +227,7 @@ Status Split3Inner(cudaStream_t stream, const size_t element_size, const int64_t // determine threads based on the size of the output auto threadsPerBlock = kNumThreadsPerBlock; - if ((inner_size_in_byte / VEC_SIZE) < 128) { + if ((inner_size_in_byte / VEC_SIZE) <= 128) { // use less threads when the size is small threadsPerBlock = 128; } @@ -247,16 +247,16 @@ Status Split3Inner(cudaStream_t stream, const size_t element_size, const int64_t CASE_ELEMENT_TYPE(int4); break; case 8: - CASE_ELEMENT_TYPE(int2); + CASE_ELEMENT_TYPE(int64_t); break; case 4: - CASE_ELEMENT_TYPE(int1); + CASE_ELEMENT_TYPE(int32_t); break; case 2: - CASE_ELEMENT_TYPE(short1); + CASE_ELEMENT_TYPE(int16_t); break; default: - CASE_ELEMENT_TYPE(char1); + CASE_ELEMENT_TYPE(int8_t); break; #undef CASE_ELEMENT_TYPE }