diff --git a/modules/nvidia_plugin/src/kernels/gather.cu b/modules/nvidia_plugin/src/kernels/gather.cu
index baab09576..c487d9d26 100644
--- a/modules/nvidia_plugin/src/kernels/gather.cu
+++ b/modules/nvidia_plugin/src/kernels/gather.cu
@@ -15,7 +15,7 @@ namespace nvidia_gpu {
 
 namespace kernel {
 
-template <bool IsBenchmarkMode, typename DataType, typename IndexType>
+template <typename DataType, typename IndexType>
 static inline __device__ void gather(unsigned data_length,
                                      size_t index_range,
                                      unsigned els_per_thread,
@@ -30,26 +30,23 @@ static inline __device__ void gather(unsigned data_length,
     if (dict_index < 0) {
         dict_index += index_range;
     }
-    if (dict_index >= index_range) {
-        if (IsBenchmarkMode) {
-            dict_index = 0;
-        } else {
-            // TODO: find a way to handle an error raised in a kernel (assertion or trap) properly
-            __trap();
-        }
-    }
     unsigned thread_offset;
     for (int el = 0; el < els_per_thread; ++el) {
         thread_offset = chunk + el;
         if (thread_offset >= data_length) {
             return;
         }
-        dst_data[data_length * (indices_index + dict * indices_size) + thread_offset] =
-            src_dict[data_length * (dict_index + dict * index_range) + thread_offset];
+        const auto dst_index = data_length * (indices_index + dict * indices_size) + thread_offset;
+        if (dict_index < index_range) {
+            const auto src_index = data_length * (dict_index + dict * index_range) + thread_offset;
+            dst_data[dst_index] = src_dict[src_index];
+        } else {
+            dst_data[dst_index] = static_cast<DataType>(0.0f);
+        }
     }
 }
 
-template <bool IsBenchmarkMode, typename DataType, typename IndexType>
+template <typename DataType, typename IndexType>
 static __global__ void chunks_gather(unsigned data_length,
                                      unsigned indices_size,
                                      size_t index_range,
@@ -64,19 +61,19 @@ static __global__ void chunks_gather(unsigned data_length,
     const auto indices_index = blockIdx.x % indices_size;
     const auto batch = blockIdx.x / indices_size;
     const auto chunk = (blockIdx.z * blockDim.x + threadIdx.x) * els_per_thread;
-    gather<IsBenchmarkMode>(data_length,
-                            index_range,
-                            els_per_thread,
-                            indices_size,
-                            indices_index,
-                            dict,
-                            chunk,
-                            src_dict + batch * dicts_batch_stride,
-                            src_index + batch * indices_batch_stride,
-                            dst_data + batch * out_batch_stride);
+    gather(data_length,
+           index_range,
+           els_per_thread,
+           indices_size,
+           indices_index,
+           dict,
+           chunk,
+           src_dict + batch * dicts_batch_stride,
+           src_index + batch * indices_batch_stride,
+           dst_data + batch * out_batch_stride);
 }
 
-template <bool IsBenchmarkMode, typename DataType, typename IndexType>
+template <typename DataType, typename IndexType>
 static __global__ void dicts_gather(unsigned num_dicts,
                                     unsigned indices_size,
                                     size_t index_range,
@@ -95,16 +92,16 @@ static __global__ void dicts_gather(unsigned num_dicts,
     }
     const auto indices_index = blockIdx.x % indices_size;
     const auto batch = blockIdx.x / indices_size;
-    gather<IsBenchmarkMode>(data_length,
-                            index_range,
-                            els_per_thread,
-                            indices_size,
-                            indices_index,
-                            dict,
-                            chunk,
-                            src_dict + batch * dicts_batch_stride,
-                            src_index + batch * indices_batch_stride,
-                            dst_data + batch * out_batch_stride);
+    gather(data_length,
+           index_range,
+           els_per_thread,
+           indices_size,
+           indices_index,
+           dict,
+           chunk,
+           src_dict + batch * dicts_batch_stride,
+           src_index + batch * indices_batch_stride,
+           dst_data + batch * out_batch_stride);
 }
 
 Gather::Gather(Type_t element_type,
@@ -143,16 +140,12 @@ Gather::Gather(Type_t element_type,
     TypeValidator<ElementTypesSwitch<Type_t::i64, Type_t::i32>>::check(indices_type_);
 }
 
-void Gather::operator()(const cudaStream_t stream,
-                        bool is_benchmark_mode,
-                        const void* src_dict,
-                        const void* src_index,
-                        void* dst_data) const {
+void Gather::operator()(const cudaStream_t stream, const void* src_dict, const void* src_index, void* dst_data) const {
     switch (indices_type_) {
         case Type_t::i64:
-            return CallByDataType<int64_t>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
+            return CallByDataType<int64_t>(stream, src_dict, src_index, dst_data);
         case Type_t::i32:
-            return CallByDataType<int32_t>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
+            return CallByDataType<int32_t>(stream, src_dict, src_index, dst_data);
         default:
             throw_ov_exception(
                 fmt::format("Index element type = {} is not supported by Gather operation !!", indices_type_));
@@ -161,39 +154,38 @@ void Gather::operator()(const cudaStream_t stream,
 
 template <typename IndexType>
 void Gather::CallByDataType(const cudaStream_t stream,
-                            bool is_benchmark_mode,
                             const void* src_dict,
                             const void* src_index,
                             void* dst_data) const {
     switch (element_type_) {
         case Type_t::boolean:
-            return Call<bool, IndexType>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
+            return Call<bool, IndexType>(stream, src_dict, src_index, dst_data);
 #ifdef CUDA_HAS_BF16_TYPE
         case Type_t::bf16:
-            return Call<__nv_bfloat16, IndexType>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
+            return Call<__nv_bfloat16, IndexType>(stream, src_dict, src_index, dst_data);
 #endif
         case Type_t::f16:
-            return Call<__half, IndexType>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
+            return Call<__half, IndexType>(stream, src_dict, src_index, dst_data);
         case Type_t::f32:
-            return Call<float, IndexType>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
+            return Call<float, IndexType>(stream, src_dict, src_index, dst_data);
         case Type_t::f64:
-            return Call<double, IndexType>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
+            return Call<double, IndexType>(stream, src_dict, src_index, dst_data);
         case Type_t::i8:
-            return Call<int8_t, IndexType>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
+            return Call<int8_t, IndexType>(stream, src_dict, src_index, dst_data);
         case Type_t::i16:
-            return Call<int16_t, IndexType>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
+            return Call<int16_t, IndexType>(stream, src_dict, src_index, dst_data);
         case Type_t::i32:
-            return Call<int32_t, IndexType>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
+            return Call<int32_t, IndexType>(stream, src_dict, src_index, dst_data);
         case Type_t::i64:
-            return Call<int64_t, IndexType>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
+            return Call<int64_t, IndexType>(stream, src_dict, src_index, dst_data);
         case Type_t::u8:
-            return Call<uint8_t, IndexType>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
+            return Call<uint8_t, IndexType>(stream, src_dict, src_index, dst_data);
         case Type_t::u16:
-            return Call<uint16_t, IndexType>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
+            return Call<uint16_t, IndexType>(stream, src_dict, src_index, dst_data);
         case Type_t::u32:
-            return Call<uint32_t, IndexType>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
+            return Call<uint32_t, IndexType>(stream, src_dict, src_index, dst_data);
         case Type_t::u64:
-            return Call<uint64_t, IndexType>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
+            return Call<uint64_t, IndexType>(stream, src_dict, src_index, dst_data);
         default:
             throw_ov_exception(
                 fmt::format("Index element type = {} is not supported by Gather operation !!", indices_type_));
@@ -201,65 +193,35 @@ void Gather::CallByDataType(const cudaStream_t stream,
 }
 
 template <typename DataType, typename IndexType>
-void Gather::Call(const cudaStream_t stream,
-                  bool is_benchmark_mode,
-                  const void* src_dict,
-                  const void* src_index,
-                  void* dst_data) const {
+void Gather::Call(const cudaStream_t stream, const void* src_dict, const void* src_index, void* dst_data) const {
     dim3 grid{grid_dim_x_, grid_dim_y_, blocks_per_grid_};
 
     const auto src_dict_typed = static_cast<const DataType*>(src_dict);
     const auto src_index_typed = static_cast<const IndexType*>(src_index);
     auto dst_data_typed = static_cast<DataType*>(dst_data);
-    if (is_benchmark_mode) {
-        if (gather_chunks_) {
-            kernel::chunks_gather<true><<<grid, threads_per_block_, 0, stream>>>(data_length_,
-                                                                                 indices_size_,
-                                                                                 index_range_,
-                                                                                 dicts_batch_stride_,
-                                                                                 indices_batch_stride_,
-                                                                                 out_batch_stride_,
-                                                                                 els_per_thread_chunks_,
-                                                                                 src_dict_typed,
-                                                                                 src_index_typed,
-                                                                                 dst_data_typed);
-        } else {
-            kernel::dicts_gather<true><<<grid, threads_per_block_, 0, stream>>>(num_dicts_,
-                                                                                indices_size_,
-                                                                                index_range_,
-                                                                                dicts_batch_stride_,
-                                                                                indices_batch_stride_,
-                                                                                out_batch_stride_,
-                                                                                els_per_thread_dicts_,
-                                                                                src_dict_typed,
-                                                                                src_index_typed,
-                                                                                dst_data_typed);
-        }
+    if (gather_chunks_) {
+        kernel::chunks_gather<<<grid, threads_per_block_, 0, stream>>>(data_length_,
+                                                                       indices_size_,
+                                                                       index_range_,
+                                                                       dicts_batch_stride_,
+                                                                       indices_batch_stride_,
+                                                                       out_batch_stride_,
+                                                                       els_per_thread_chunks_,
+                                                                       src_dict_typed,
+                                                                       src_index_typed,
+                                                                       dst_data_typed);
     } else {
-        if (gather_chunks_) {
-            kernel::chunks_gather<false><<<grid, threads_per_block_, 0, stream>>>(data_length_,
-                                                                                  indices_size_,
-                                                                                  index_range_,
-                                                                                  dicts_batch_stride_,
-                                                                                  indices_batch_stride_,
-                                                                                  out_batch_stride_,
-                                                                                  els_per_thread_chunks_,
-                                                                                  src_dict_typed,
-                                                                                  src_index_typed,
-                                                                                  dst_data_typed);
-        } else {
-            kernel::dicts_gather<false><<<grid, threads_per_block_, 0, stream>>>(num_dicts_,
-                                                                                 indices_size_,
-                                                                                 index_range_,
-                                                                                 dicts_batch_stride_,
-                                                                                 indices_batch_stride_,
-                                                                                 out_batch_stride_,
-                                                                                 els_per_thread_dicts_,
-                                                                                 src_dict_typed,
-                                                                                 src_index_typed,
-                                                                                 dst_data_typed);
-        }
+        kernel::dicts_gather<<<grid, threads_per_block_, 0, stream>>>(num_dicts_,
+                                                                      indices_size_,
+                                                                      index_range_,
+                                                                      dicts_batch_stride_,
+                                                                      indices_batch_stride_,
+                                                                      out_batch_stride_,
+                                                                      els_per_thread_dicts_,
+                                                                      src_dict_typed,
+                                                                      src_index_typed,
+                                                                      dst_data_typed);
     }
 
     // TODO: find a way to handle an error raised in a kernel (assertion or trap) properly in CUDA Plugin
     cudaError_t err = cudaGetLastError();
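Not part of the patch, but as a reading aid for the kernel change above: a simplified host-side model of what the rewritten device-side gather now computes for a single batch and dictionary slice (function name and container types are illustrative, not plugin code). Indices inside [0, index_range) copy from the dictionary, negative indices are wrapped once, and anything still out of range zero-fills the destination instead of hitting __trap().

```cpp
// Illustrative reference only -- mirrors the new kernel semantics, not plugin code.
#include <cstddef>
#include <vector>

template <typename DataType, typename IndexType>
std::vector<DataType> gather_reference(const std::vector<DataType>& dict,  // index_range * data_length elements
                                       const std::vector<IndexType>& indices,
                                       std::size_t index_range,
                                       std::size_t data_length) {
    std::vector<DataType> out(indices.size() * data_length);
    for (std::size_t i = 0; i < indices.size(); ++i) {
        auto idx = indices[i];
        if (idx < 0) {
            idx += static_cast<IndexType>(index_range);  // wrap negative indices, as the kernel does
        }
        const bool in_range = idx >= 0 && static_cast<std::size_t>(idx) < index_range;
        for (std::size_t el = 0; el < data_length; ++el) {
            // out-of-range index: zero-fill the output slice instead of trapping
            out[i * data_length + el] =
                in_range ? dict[static_cast<std::size_t>(idx) * data_length + el] : static_cast<DataType>(0);
        }
    }
    return out;
}
```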
diff --git a/modules/nvidia_plugin/src/kernels/gather.hpp b/modules/nvidia_plugin/src/kernels/gather.hpp
index 8be30d3c2..b581be501 100644
--- a/modules/nvidia_plugin/src/kernels/gather.hpp
+++ b/modules/nvidia_plugin/src/kernels/gather.hpp
@@ -32,26 +32,14 @@ class Gather {
            unsigned els_per_thread_chunks,
            unsigned els_per_thread_dicts);
 
-    void operator()(const cudaStream_t stream,
-                    bool is_benchmark_mode,
-                    const void* src_dict,
-                    const void* src_index,
-                    void* dst_data) const;
+    void operator()(const cudaStream_t stream, const void* src_dict, const void* src_index, void* dst_data) const;
 
private:
     template <typename IndexType>
-    void CallByDataType(const cudaStream_t stream,
-                        bool is_benchmark_mode,
-                        const void* src_dict,
-                        const void* src_index,
-                        void* dst_data) const;
+    void CallByDataType(const cudaStream_t stream, const void* src_dict, const void* src_index, void* dst_data) const;
 
     template <typename DataType, typename IndexType>
-    void Call(const cudaStream_t stream,
-              bool is_benchmark_mode,
-              const void* src_dict,
-              const void* src_index,
-              void* dst_data) const;
+    void Call(const cudaStream_t stream, const void* src_dict, const void* src_index, void* dst_data) const;
 
     Type_t element_type_;
     Type_t indices_type_;
diff --git a/modules/nvidia_plugin/src/ops/gather.cpp b/modules/nvidia_plugin/src/ops/gather.cpp
index 09d79ee11..f666cc35e 100644
--- a/modules/nvidia_plugin/src/ops/gather.cpp
+++ b/modules/nvidia_plugin/src/ops/gather.cpp
@@ -8,8 +8,8 @@
 
 #include 
 #include 
-#include 
 #include 
+#include 
 #include 
 
 #include "converters.hpp"
@@ -52,7 +52,7 @@ GatherOp::GatherOp(const CreationContext& context,
         case ov::element::Type_t::dynamic:
         case ov::element::Type_t::u1:
             throw_ov_exception(fmt::format("Params element type = {} is not supported by Gather operation!",
-                                          static_cast(element_type)));
+                                           static_cast(element_type)));
     }
 
     OPENVINO_ASSERT(node.get_output_element_type(0) == element_type, "Node name: ", GetName());
@@ -175,11 +175,7 @@ void GatherOp::Execute(const InferenceRequestContext& context,
     OPENVINO_ASSERT(inputs.size() == 3, "Node name: ", GetName());
     OPENVINO_ASSERT(outputs.size() == 1, "Node name: ", GetName());
 
-    (*gather_kernel_)(context.getThreadContext().stream().get(),
-                      context.isBenchmarkMode(),
-                      inputs[0].get(),
-                      inputs[1].get(),
-                      outputs[0].get());
+    (*gather_kernel_)(context.getThreadContext().stream().get(), inputs[0].get(), inputs[1].get(), outputs[0].get());
 }
 
 bool GatherOp::IsCudaGraphCompatible() const { return true; }
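With the benchmark flag gone from every signature, call sites shrink to the stream plus the three buffers, as the GatherOp::Execute hunk above shows. A minimal hypothetical call site, assuming the ov::nvidia_gpu::kernel namespace and include path used elsewhere in the plugin:

```cpp
// Hypothetical usage sketch, not code from the patch.
#include <cuda_runtime.h>

#include "kernels/gather.hpp"  // assumed include path within the plugin

void run_gather(const ov::nvidia_gpu::kernel::Gather& gather_kernel,
                cudaStream_t stream,
                const void* src_dict,
                const void* src_index,
                void* dst_data) {
    // operator()(stream, src_dict, src_index, dst_data) -- no is_benchmark_mode argument anymore
    gather_kernel(stream, src_dict, src_index, dst_data);
}
```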
diff --git a/modules/nvidia_plugin/tests/functional/shared_tests_instances/single_layer_tests/gather.cpp b/modules/nvidia_plugin/tests/functional/shared_tests_instances/single_layer_tests/gather.cpp
index e5ffba855..659a31f52 100644
--- a/modules/nvidia_plugin/tests/functional/shared_tests_instances/single_layer_tests/gather.cpp
+++ b/modules/nvidia_plugin/tests/functional/shared_tests_instances/single_layer_tests/gather.cpp
@@ -78,7 +78,7 @@ struct GatherTestParams {
 };
 
 template <typename IndexType>
-std::vector<IndexType> generate_indices(const GatherTestParams& test_params) {
+std::vector<IndexType> generate_indices(const GatherTestParams& test_params, bool add_out_of_range = false) {
     static std::random_device r_device;
     static std::default_random_engine r_engine{r_device()};
 
@@ -93,7 +93,15 @@ std::vector<IndexType> generate_indices(const GatherTestParams& test_params) {
 
     const auto indices_size = ov::shape_size(test_params.indices_shape_);
     std::vector<IndexType> indices(indices_size);
-    std::generate(indices.begin(), indices.end(), [&]() { return distr(r_engine); });
+    auto gen_function = [&]() { return distr(r_engine); };
+    if (!add_out_of_range) {
+        std::generate(indices.begin(), indices.end(), gen_function);
+    } else {
+        if (indices_size > 0) {
+            indices[0] = test_params.params_shape_[normalized_axis];
+        }
+        std::generate(indices.begin() + 1, indices.end(), gen_function);
+    }
     return indices;
 }
 
@@ -113,6 +121,20 @@ INSTANTIATE_TEST_CASE_P(smoke_Gather_v1_01,
                                            ::testing::Values(smoke_01_params_v1_v7.device_)),
                         CudaGatherLayerTest::getTestCaseName);
 
+INSTANTIATE_TEST_CASE_P(smoke_Gather_v1_01_out_of_range_index,
+                        CudaGatherLayerTest,
+                        ::testing::Combine(::testing::Values(generate_indices<int>(smoke_01_params_v1_v7, true)),
+                                           ::testing::Values(smoke_01_params_v1_v7.indices_shape_),
+                                           ::testing::Values(smoke_01_params_v1_v7.axis_),
+                                           ::testing::Values(smoke_01_params_v1_v7.params_shape_),
+                                           ::testing::ValuesIn(smoke_01_params_v1_v7.net_precisions_),
+                                           ::testing::Values(smoke_01_params_v1_v7.input_precision_),
+                                           ::testing::Values(smoke_01_params_v1_v7.output_precision_),
+                                           ::testing::Values(smoke_01_params_v1_v7.input_layout_),
+                                           ::testing::Values(smoke_01_params_v1_v7.output_layout_),
+                                           ::testing::Values(smoke_01_params_v1_v7.device_)),
+                        CudaGatherLayerTest::getTestCaseName);
+
 INSTANTIATE_TEST_CASE_P(smoke_Gather_v7_01,
                         Gather7LayerTest,
                         ::testing::Combine(::testing::Values(smoke_01_params_v1_v7.params_shape_),
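To make the expectation behind the new smoke_Gather_v1_01_out_of_range_index case concrete, here is a small self-contained sketch (shapes and values are illustrative and are not taken from smoke_01_params_v1_v7): a slice addressed by an out-of-range index comes back zero-filled rather than aborting the kernel.

```cpp
// Standalone illustration of the behaviour the new test case relies on.
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
    const std::vector<float> params = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};  // shape {3, 2}, gather axis 0
    const std::vector<int> indices = {3, 0};                           // 3 is one past the last valid index
    const std::size_t data_length = 2;

    std::vector<float> result;
    for (int idx : indices) {
        for (std::size_t el = 0; el < data_length; ++el) {
            const bool in_range = idx >= 0 && idx < 3;
            result.push_back(in_range ? params[idx * data_length + el] : 0.f);  // zero-fill out-of-range
        }
    }
    assert((result == std::vector<float>{0.f, 0.f, 1.f, 2.f}));  // first slice is zero-filled
    return 0;
}
```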