Skip to content

Commit

Permalink
[NVIDIA] Fix Gather behavior with out of range indices to be consiste…
Browse files Browse the repository at this point in the history
…nt with Gather v8 requirements (#694)

* [NVIDIA] Fix Gather behavior with out of range indices to be consistent with Gather v8 requirements

* [NVIDIA] Add Gather test for out of range indices
  • Loading branch information
apavliuk55 authored Jul 27, 2023
1 parent 34f5c54 commit 42f03c5
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 130 deletions.
174 changes: 68 additions & 106 deletions modules/nvidia_plugin/src/kernels/gather.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace nvidia_gpu {

namespace kernel {

template <bool IsBenchmarkMode, typename DataType, typename IndexType>
template <typename DataType, typename IndexType>
static inline __device__ void gather(unsigned data_length,
size_t index_range,
unsigned els_per_thread,
Expand All @@ -30,26 +30,23 @@ static inline __device__ void gather(unsigned data_length,
if (dict_index < 0) {
dict_index += index_range;
}
if (dict_index >= index_range) {
if (IsBenchmarkMode) {
dict_index = 0;
} else {
// TODO: find a way to handle an error raised in a kernel (assertion or trap) properly
__trap();
}
}
unsigned thread_offset;
for (int el = 0; el < els_per_thread; ++el) {
thread_offset = chunk + el;
if (thread_offset >= data_length) {
return;
}
dst_data[data_length * (indices_index + dict * indices_size) + thread_offset] =
src_dict[data_length * (dict_index + dict * index_range) + thread_offset];
const auto dst_index = data_length * (indices_index + dict * indices_size) + thread_offset;
if (dict_index < index_range) {
const auto src_index = data_length * (dict_index + dict * index_range) + thread_offset;
dst_data[dst_index] = src_dict[src_index];
} else {
dst_data[dst_index] = static_cast<DataType>(0.0f);
}
}
}

template <bool IsBenchmarkMode, typename DataType, typename IndexType>
template <typename DataType, typename IndexType>
static __global__ void chunks_gather(unsigned data_length,
unsigned indices_size,
size_t index_range,
Expand All @@ -64,19 +61,19 @@ static __global__ void chunks_gather(unsigned data_length,
const auto indices_index = blockIdx.x % indices_size;
const auto batch = blockIdx.x / indices_size;
const auto chunk = (blockIdx.z * blockDim.x + threadIdx.x) * els_per_thread;
gather<IsBenchmarkMode>(data_length,
index_range,
els_per_thread,
indices_size,
indices_index,
dict,
chunk,
src_dict + batch * dicts_batch_stride,
src_index + batch * indices_batch_stride,
dst_data + batch * out_batch_stride);
gather(data_length,
index_range,
els_per_thread,
indices_size,
indices_index,
dict,
chunk,
src_dict + batch * dicts_batch_stride,
src_index + batch * indices_batch_stride,
dst_data + batch * out_batch_stride);
}

template <bool IsBenchmarkMode, typename DataType, typename IndexType>
template <typename DataType, typename IndexType>
static __global__ void dicts_gather(unsigned num_dicts,
unsigned indices_size,
size_t index_range,
Expand All @@ -95,16 +92,16 @@ static __global__ void dicts_gather(unsigned num_dicts,
}
const auto indices_index = blockIdx.x % indices_size;
const auto batch = blockIdx.x / indices_size;
gather<IsBenchmarkMode>(data_length,
index_range,
els_per_thread,
indices_size,
indices_index,
dict,
chunk,
src_dict + batch * dicts_batch_stride,
src_index + batch * indices_batch_stride,
dst_data + batch * out_batch_stride);
gather(data_length,
index_range,
els_per_thread,
indices_size,
indices_index,
dict,
chunk,
src_dict + batch * dicts_batch_stride,
src_index + batch * indices_batch_stride,
dst_data + batch * out_batch_stride);
}

Gather::Gather(Type_t element_type,
Expand Down Expand Up @@ -143,16 +140,12 @@ Gather::Gather(Type_t element_type,
TypeValidator<ElementTypesSwitch<Type_t::i64, Type_t::i32>>::check(indices_type_);
}

void Gather::operator()(const cudaStream_t stream,
bool is_benchmark_mode,
const void* src_dict,
const void* src_index,
void* dst_data) const {
void Gather::operator()(const cudaStream_t stream, const void* src_dict, const void* src_index, void* dst_data) const {
switch (indices_type_) {
case Type_t::i64:
return CallByDataType<int64_t>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
return CallByDataType<int64_t>(stream, src_dict, src_index, dst_data);
case Type_t::i32:
return CallByDataType<int32_t>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
return CallByDataType<int32_t>(stream, src_dict, src_index, dst_data);
default:
throw_ov_exception(
fmt::format("Index element type = {} is not supported by Gather operation !!", indices_type_));
Expand All @@ -161,105 +154,74 @@ void Gather::operator()(const cudaStream_t stream,

template <typename IndexType>
void Gather::CallByDataType(const cudaStream_t stream,
bool is_benchmark_mode,
const void* src_dict,
const void* src_index,
void* dst_data) const {
switch (element_type_) {
case Type_t::boolean:
return Call<bool, IndexType>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
return Call<bool, IndexType>(stream, src_dict, src_index, dst_data);
#ifdef CUDA_HAS_BF16_TYPE
case Type_t::bf16:
return Call<__nv_bfloat16, IndexType>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
return Call<__nv_bfloat16, IndexType>(stream, src_dict, src_index, dst_data);
#endif
case Type_t::f16:
return Call<__half, IndexType>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
return Call<__half, IndexType>(stream, src_dict, src_index, dst_data);
case Type_t::f32:
return Call<float, IndexType>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
return Call<float, IndexType>(stream, src_dict, src_index, dst_data);
case Type_t::f64:
return Call<double, IndexType>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
return Call<double, IndexType>(stream, src_dict, src_index, dst_data);
case Type_t::i8:
return Call<int8_t, IndexType>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
return Call<int8_t, IndexType>(stream, src_dict, src_index, dst_data);
case Type_t::i16:
return Call<int16_t, IndexType>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
return Call<int16_t, IndexType>(stream, src_dict, src_index, dst_data);
case Type_t::i32:
return Call<int32_t, IndexType>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
return Call<int32_t, IndexType>(stream, src_dict, src_index, dst_data);
case Type_t::i64:
return Call<int64_t, IndexType>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
return Call<int64_t, IndexType>(stream, src_dict, src_index, dst_data);
case Type_t::u8:
return Call<uint8_t, IndexType>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
return Call<uint8_t, IndexType>(stream, src_dict, src_index, dst_data);
case Type_t::u16:
return Call<uint16_t, IndexType>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
return Call<uint16_t, IndexType>(stream, src_dict, src_index, dst_data);
case Type_t::u32:
return Call<uint32_t, IndexType>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
return Call<uint32_t, IndexType>(stream, src_dict, src_index, dst_data);
case Type_t::u64:
return Call<uint64_t, IndexType>(stream, is_benchmark_mode, src_dict, src_index, dst_data);
return Call<uint64_t, IndexType>(stream, src_dict, src_index, dst_data);
default:
throw_ov_exception(
fmt::format("Index element type = {} is not supported by Gather operation !!", indices_type_));
}
}

template <typename DataType, typename IndexType>
void Gather::Call(const cudaStream_t stream,
bool is_benchmark_mode,
const void* src_dict,
const void* src_index,
void* dst_data) const {
void Gather::Call(const cudaStream_t stream, const void* src_dict, const void* src_index, void* dst_data) const {
dim3 grid{grid_dim_x_, grid_dim_y_, blocks_per_grid_};

const auto src_dict_typed = static_cast<const DataType*>(src_dict);
const auto src_index_typed = static_cast<const IndexType*>(src_index);
auto dst_data_typed = static_cast<DataType*>(dst_data);

if (is_benchmark_mode) {
if (gather_chunks_) {
kernel::chunks_gather<true><<<grid, threads_per_block_, 0, stream>>>(data_length_,
indices_size_,
index_range_,
dicts_batch_stride_,
indices_batch_stride_,
out_batch_stride_,
els_per_thread_chunks_,
src_dict_typed,
src_index_typed,
dst_data_typed);
} else {
kernel::dicts_gather<true><<<grid, threads_per_block_, 0, stream>>>(num_dicts_,
indices_size_,
index_range_,
dicts_batch_stride_,
indices_batch_stride_,
out_batch_stride_,
els_per_thread_dicts_,
src_dict_typed,
src_index_typed,
dst_data_typed);
}
if (gather_chunks_) {
kernel::chunks_gather<<<grid, threads_per_block_, 0, stream>>>(data_length_,
indices_size_,
index_range_,
dicts_batch_stride_,
indices_batch_stride_,
out_batch_stride_,
els_per_thread_chunks_,
src_dict_typed,
src_index_typed,
dst_data_typed);
} else {
if (gather_chunks_) {
kernel::chunks_gather<false><<<grid, threads_per_block_, 0, stream>>>(data_length_,
indices_size_,
index_range_,
dicts_batch_stride_,
indices_batch_stride_,
out_batch_stride_,
els_per_thread_chunks_,
src_dict_typed,
src_index_typed,
dst_data_typed);
} else {
kernel::dicts_gather<false><<<grid, threads_per_block_, 0, stream>>>(num_dicts_,
indices_size_,
index_range_,
dicts_batch_stride_,
indices_batch_stride_,
out_batch_stride_,
els_per_thread_dicts_,
src_dict_typed,
src_index_typed,
dst_data_typed);
}
kernel::dicts_gather<<<grid, threads_per_block_, 0, stream>>>(num_dicts_,
indices_size_,
index_range_,
dicts_batch_stride_,
indices_batch_stride_,
out_batch_stride_,
els_per_thread_dicts_,
src_dict_typed,
src_index_typed,
dst_data_typed);
}
// TODO: find a way to handle an error raised in a kernel (assertion or trap) properly in CUDA Plugin
cudaError_t err = cudaGetLastError();
Expand Down
18 changes: 3 additions & 15 deletions modules/nvidia_plugin/src/kernels/gather.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,26 +32,14 @@ class Gather {
unsigned els_per_thread_chunks,
unsigned els_per_thread_dicts);

void operator()(const cudaStream_t stream,
bool is_benchmark_mode,
const void* src_dict,
const void* src_index,
void* dst_data) const;
void operator()(const cudaStream_t stream, const void* src_dict, const void* src_index, void* dst_data) const;

private:
template <typename IndexType>
void CallByDataType(const cudaStream_t stream,
bool is_benchmark_mode,
const void* src_dict,
const void* src_index,
void* dst_data) const;
void CallByDataType(const cudaStream_t stream, const void* src_dict, const void* src_index, void* dst_data) const;

template <typename DataType, typename IndexType>
void Call(const cudaStream_t stream,
bool is_benchmark_mode,
const void* src_dict,
const void* src_index,
void* dst_data) const;
void Call(const cudaStream_t stream, const void* src_dict, const void* src_index, void* dst_data) const;

Type_t element_type_;
Type_t indices_type_;
Expand Down
10 changes: 3 additions & 7 deletions modules/nvidia_plugin/src/ops/gather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

#include <cuda_operation_registry.hpp>
#include <error.hpp>
#include <openvino/core/except.hpp>
#include <numeric>
#include <openvino/core/except.hpp>
#include <openvino/op/gather.hpp>

#include "converters.hpp"
Expand Down Expand Up @@ -52,7 +52,7 @@ GatherOp::GatherOp(const CreationContext& context,
case ov::element::Type_t::dynamic:
case ov::element::Type_t::u1:
throw_ov_exception(fmt::format("Params element type = {} is not supported by Gather operation!",
static_cast<ov::element::Type_t>(element_type)));
static_cast<ov::element::Type_t>(element_type)));
}
OPENVINO_ASSERT(node.get_output_element_type(0) == element_type, "Node name: ", GetName());

Expand Down Expand Up @@ -175,11 +175,7 @@ void GatherOp::Execute(const InferenceRequestContext& context,
OPENVINO_ASSERT(inputs.size() == 3, "Node name: ", GetName());
OPENVINO_ASSERT(outputs.size() == 1, "Node name: ", GetName());

(*gather_kernel_)(context.getThreadContext().stream().get(),
context.isBenchmarkMode(),
inputs[0].get(),
inputs[1].get(),
outputs[0].get());
(*gather_kernel_)(context.getThreadContext().stream().get(), inputs[0].get(), inputs[1].get(), outputs[0].get());
}

bool GatherOp::IsCudaGraphCompatible() const { return true; }
Expand Down
Loading

0 comments on commit 42f03c5

Please sign in to comment.