Add GatherND operator (#1089)
Add ref and gpu implementations for ONNX op GatherND

Resolves #1032
turneram authored Apr 29, 2022
1 parent 4c72cc9 commit 4ec35e5
Showing 17 changed files with 671 additions and 3 deletions.
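For context: ONNX GatherND gathers slices of a data tensor using an indices tensor whose innermost dimension holds index tuples of length k. A minimal sketch of the semantics in plain C++ (not the MIGraphX API), assuming batch_dims = 0 and k equal to the rank of data, so each index tuple selects a single element:

// Minimal sketch (plain C++, not the MIGraphX API): ONNX GatherND semantics
// for batch_dims = 0 where each index tuple has length k equal to the data
// rank, so every tuple picks out a single element.
#include <array>
#include <cstddef>
#include <iostream>

int main()
{
    // data has shape {2, 2}
    std::array<std::array<float, 2>, 2> data{{{1.0f, 2.0f}, {3.0f, 4.0f}}};
    // indices has shape {2, 2}: two index tuples, each of length k = 2
    std::array<std::array<std::size_t, 2>, 2> indices{{{0, 0}, {1, 1}}};

    // output has shape {2}: one gathered element per index tuple
    for(const auto& idx : indices)
        std::cout << data[idx[0]][idx[1]] << ' '; // prints: 1 4
    std::cout << '\n';
}
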
1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -109,6 +109,7 @@ register_migraphx_ops(
flatten
floor
gather
gathernd
get_tuple_elem
greater
gru
131 changes: 131 additions & 0 deletions src/include/migraphx/op/gathernd.hpp
@@ -0,0 +1,131 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
#define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP

#include <migraphx/check_shapes.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct gathernd
{
int batch_dims = 0;

template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.batch_dims, "batch_dims"));
}

std::string name() const { return "gathernd"; }

shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(2);
auto r = inputs.front().lens().size();
auto q = inputs.back().lens().size();
auto k = inputs.back().lens().back();
if(k > r - batch_dims)
{
MIGRAPHX_THROW("GATHERND: Indices of length " + std::to_string(k) +
" cannot be used to access data of rank " +
std::to_string(r - batch_dims));
}
auto indices_lens_iter = inputs.back().lens().begin();
auto output_lens_size = q + r - k - batch_dims - 1;
std::vector<std::size_t> output_lens(output_lens_size);
std::copy(indices_lens_iter, indices_lens_iter + (q - 1), output_lens.begin());
if(k < r - batch_dims)
{
auto data_lens = inputs.front().lens();
std::copy(
data_lens.begin() + batch_dims + k, data_lens.end(), output_lens.begin() + q - 1);
}
shape output_shape{inputs.front().type(), output_lens};
return output_shape;
}

argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto data) {
args[1].visit([&](auto indices) {
auto indices_shape = indices.get_shape();
auto indices_shape_lens = indices_shape.lens();
auto data_shape = data.get_shape();
auto data_shape_lens = data_shape.lens();
auto k = indices_shape.lens().back();
const auto num_slice_dims = k;
std::size_t num_slices = std::accumulate(indices_shape_lens.begin(),
indices_shape_lens.end() - 1,
1,
std::multiplies<std::size_t>());
std::size_t slice_size = std::accumulate(data_shape_lens.begin() + k + batch_dims,
data_shape_lens.end(),
1,
std::multiplies<std::size_t>());
std::size_t num_batches = std::accumulate(data_shape_lens.begin(),
data_shape_lens.begin() + batch_dims,
1,
std::multiplies<std::size_t>());
std::size_t data_batch_stride =
std::accumulate(data_shape_lens.begin() + batch_dims,
data_shape_lens.end(),
1,
std::multiplies<std::size_t>());
auto num_slices_per_batch = num_slices / num_batches;

std::vector<std::size_t> sizes_from_slice_dims(num_slice_dims);
{
auto running_product = slice_size;
for(std::size_t i = 0; i < num_slice_dims; ++i)
{
sizes_from_slice_dims[num_slice_dims - 1 - i] = running_product;
running_product *= data_shape_lens[batch_dims + num_slice_dims - 1 - i];
}
}

std::vector<std::size_t> input_slice_offsets(num_slices);
par_for(num_slices, [&](const auto i) {
std::size_t batch_idx = i / num_slices_per_batch;

auto slice_indices = indices.begin() + (i * num_slice_dims);
std::size_t relative_slice_offset = 0;
for(size_t dim_idx = 0; dim_idx < num_slice_dims; ++dim_idx)
{
int64_t index = *(slice_indices + dim_idx);
const std::size_t input_dim_idx = batch_dims + dim_idx;
const auto input_dim = data_shape_lens[input_dim_idx];
if(index < -static_cast<int64_t>(input_dim) or
index >= static_cast<int64_t>(input_dim))
MIGRAPHX_THROW("GatherND: index " + std::to_string(index) +
" is out of bounds for dim of len " +
std::to_string(input_dim));
if(index < 0)
index += input_dim;

relative_slice_offset += index * sizes_from_slice_dims[dim_idx];
}

input_slice_offsets[i] =
(batch_idx * data_batch_stride) + relative_slice_offset;
});

par_for(num_slices * slice_size, [&](const auto i) {
auto slice_offset = input_slice_offsets[i / slice_size];
output[i] = data[slice_offset + i % slice_size];
});
});
});

return result;
}
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
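
To make the shape rule in compute_shape concrete: the output keeps the leading q - 1 dimensions of indices and the trailing r - k - batch_dims dimensions of data, giving rank q + r - k - batch_dims - 1. A small standalone sketch (not part of the commit) with made-up shapes:

// Standalone sketch of the same shape rule, with concrete numbers.
#include <cstddef>
#include <iostream>
#include <vector>

std::vector<std::size_t> gathernd_output_lens(const std::vector<std::size_t>& data_lens,
                                              const std::vector<std::size_t>& indices_lens,
                                              std::size_t batch_dims)
{
    std::size_t k = indices_lens.back();
    // leading q - 1 dims come from indices
    std::vector<std::size_t> out(indices_lens.begin(), indices_lens.end() - 1);
    // trailing r - k - batch_dims dims come from data
    out.insert(out.end(), data_lens.begin() + batch_dims + k, data_lens.end());
    return out;
}

int main()
{
    // data shape {2, 3, 4} (r = 3), indices shape {2, 1} (q = 2, k = 1), batch_dims = 0
    auto lens = gathernd_output_lens({2, 3, 4}, {2, 1}, 0);
    for(auto l : lens)
        std::cout << l << ' '; // prints: 2 3 4
    std::cout << '\n';
}
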
1 change: 1 addition & 0 deletions src/include/migraphx/operators.hpp
@@ -35,6 +35,7 @@
#include <migraphx/op/flatten.hpp>
#include <migraphx/op/floor.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/op/gathernd.hpp>
#include <migraphx/op/get_tuple_elem.hpp>
#include <migraphx/op/greater.hpp>
#include <migraphx/op/gru.hpp>
1 change: 1 addition & 0 deletions src/onnx/parse_generic_op.cpp
@@ -28,6 +28,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Flatten", "flatten"},
{"Floor", "floor"},
{"Gather", "gather"},
{"GatherND", "gathernd"},
{"Identity", "identity"},
{"IsNaN", "isnan"},
{"LeakyRelu", "leaky_relu"},
75 changes: 75 additions & 0 deletions src/targets/gpu/jit/gathernd.cpp
@@ -0,0 +1,75 @@
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>

#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

// NOLINTNEXTLINE
static const char* const gathernd_kernel = R"__migraphx__(
#include <migraphx/kernels/gathernd.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void gathernd_kernel(void* in_data, void* in_indices, void* output)
{
make_tensors()(in_data, in_indices, output)([](auto&&... xs) {
auto settings = make_gathernd_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{BATCH_DIMS}));
gathernd(xs..., settings);
});
}
}
} // namespace migraphx
)__migraphx__";

struct gathernd_compiler : compiler<gathernd_compiler>
{
std::vector<std::string> names() const { return {"gathernd"}; }

operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
auto out_s = inputs.back();
options.set_launch_params(v, compute_global_for(ctx, out_s.elements()));
options.inputs = inputs;
options.output = out_s;
options.kernel_name = "gathernd_kernel";
options.virtual_inputs = inputs;

// batch_dims
assert(v.contains("batch_dims"));
auto batch_dims = v.at("batch_dims").to<int64_t>();
options.params += " -DBATCH_DIMS=" + std::to_string(batch_dims);

return compile_hip_code_object(gathernd_kernel, options);
}

compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
}
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
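
Note that batch_dims reaches the kernel as a preprocessor define rather than a runtime argument. A rough standalone illustration of that pattern (placeholder names, not the MIGraphX kernel machinery), assuming the value is supplied at compile time with something like -DBATCH_DIMS=1:

#include <cstdint>
#include <iostream>

#ifndef BATCH_DIMS
#define BATCH_DIMS 0 // fallback so this sketch still compiles without the flag
#endif

// Stand-in for the settings struct the kernel builds from the define
struct demo_settings
{
    std::int64_t batch_dims;
};

int main()
{
    // Mirrors the spirit of make_gathernd_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{BATCH_DIMS})):
    // the batch_dims value is baked in when the kernel is compiled.
    constexpr demo_settings settings{std::int64_t{BATCH_DIMS}};
    std::cout << "compiled with batch_dims = " << settings.batch_dims << '\n';
}
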
10 changes: 10 additions & 0 deletions src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp
@@ -21,6 +21,16 @@ struct greater
}
};

template <class InputIt, class T, class BinaryOperation>
constexpr T accumulate(InputIt first, InputIt last, T init, BinaryOperation op)
{
for(; first != last; ++first)
{
init = op(std::move(init), *first);
}
return init;
}

template <class InputIt, class OutputIt>
constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first)
{
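
The constexpr accumulate added above mirrors std::accumulate so the device kernel can reduce shape lengths. A host-side sketch of the same product-of-dims reduction the gathernd kernel performs (made-up shape values, using std::accumulate in place of the device version):

#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main()
{
    // indices shape {2, 3, 2}: the last dimension is the index tuple length k
    std::vector<std::size_t> indices_lens{2, 3, 2};
    // num_slices = product of all indices dims except the last = 2 * 3 = 6
    std::size_t num_slices = std::accumulate(indices_lens.begin(),
                                             indices_lens.end() - 1,
                                             std::size_t{1},
                                             std::multiplies<std::size_t>{});
    std::cout << "num_slices = " << num_slices << '\n'; // prints: num_slices = 6
}
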
81 changes: 81 additions & 0 deletions src/targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp
@@ -0,0 +1,81 @@
#ifndef MIGRAPHX_GUARD_KERNELS_GATHERND_HPP
#define MIGRAPHX_GUARD_KERNELS_GATHERND_HPP

#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>

namespace migraphx {

template <class T>
struct gathernd_settings
{
T batch_dims{};
};

template <class... Ts>
constexpr gathernd_settings<Ts...> make_gathernd_settings(Ts... xs)
{
return {xs...};
}

template <class T, class U, class V, class Settings>
__device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t, Settings s)
{
auto ind = make_index();
auto batch_dims = s.batch_dims;
auto output_shape = output_t.get_shape();
auto indices_shape = indices_t.get_shape();
auto data_shape = data_t.get_shape();

auto indices_shape_lens = indices_shape.lens;
auto data_shape_lens = data_shape.lens;
auto num_slice_dims = indices_shape_lens.back();
std::size_t num_slices = accumulate(indices_shape_lens.begin(),
indices_shape_lens.end() - 1,
1,
std::multiplies<std::size_t>());
std::size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims,
data_shape_lens.end(),
1,
std::multiplies<std::size_t>());
const std::size_t num_batches = accumulate(data_shape_lens.begin(),
data_shape_lens.begin() + batch_dims,
1,
std::multiplies<std::size_t>());
const std::size_t data_batch_stride = accumulate(data_shape_lens.begin() + batch_dims,
data_shape_lens.end(),
1,
std::multiplies<std::size_t>());
const auto num_slices_per_batch = num_slices / num_batches;

ind.global_stride(output_shape.elements(), [&](auto i) {
const auto* indices_ptr = indices_t.data();
const std::size_t j = i / slice_size;
const std::size_t batch_idx = j / num_slices_per_batch;

auto* slice_indices = indices_ptr + (j * num_slice_dims);
std::size_t relative_slice_offset = 0;
for(std::size_t idx = 0; idx < num_slice_dims; ++idx)
{
int64_t index = slice_indices[idx];
const std::size_t input_dim_idx = batch_dims + idx;
const auto input_dim = data_shape_lens[input_dim_idx];
assert(index >= -static_cast<int64_t>(input_dim) and
index < static_cast<int64_t>(input_dim));
if(index < 0)
index += input_dim;
std::size_t size_from_slice_dims =
accumulate(data_shape_lens.begin() + batch_dims + idx + 1,
data_shape_lens.begin() + batch_dims + num_slice_dims,
slice_size,
std::multiplies<std::size_t>());
relative_slice_offset += index * size_from_slice_dims;
}

auto slice_offset = (batch_idx * data_batch_stride) + relative_slice_offset;
output_t[i] = data_t[slice_offset + i % slice_size];
});
}

} // namespace migraphx
#endif
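
To trace the per-element mapping the kernel computes, here is a host-side sketch with made-up shapes — data {2, 3, 4}, indices {2, 1, 1}, batch_dims = 1 — showing how each output element index resolves to a data offset using the same slice_size, num_slices_per_batch, and data_batch_stride quantities:

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    // data shape {2, 3, 4} flattened row-major; indices shape {2, 1, 1}; batch_dims = 1
    std::vector<float> data(2 * 3 * 4);
    for(std::size_t i = 0; i < data.size(); ++i)
        data[i] = static_cast<float>(i);
    std::vector<std::int64_t> indices{1, 2}; // one length-1 index tuple per batch

    // Derived quantities, matching the kernel above
    const std::size_t slice_size = 4;            // product of data dims after batch_dims + k
    const std::size_t num_slices_per_batch = 1;  // num_slices / num_batches = 2 / 2
    const std::size_t data_batch_stride = 3 * 4; // product of data dims from batch_dims on

    // Output shape is {2, 1, 4} = 8 elements
    for(std::size_t i = 0; i < 8; ++i)
    {
        std::size_t j         = i / slice_size; // which slice this element belongs to
        std::size_t batch_idx = j / num_slices_per_batch;
        std::int64_t index    = indices[j]; // k = 1, so a single index per slice
        std::size_t slice_offset = batch_idx * data_batch_stride + index * slice_size;
        std::cout << data[slice_offset + i % slice_size] << ' ';
    }
    std::cout << '\n'; // prints row 1 of batch 0 and row 2 of batch 1: 4 5 6 7 20 21 22 23
}
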
19 changes: 19 additions & 0 deletions test/onnx/gathernd_batch_dims_test.onnx
@@ -0,0 +1,19 @@
(binary ONNX protobuf: graph gathernd_batch_dims_test containing a single GatherND node with a batch_dims attribute, inputs data and indices, output y)
16 changes: 16 additions & 0 deletions test/onnx/gathernd_test.onnx
@@ -0,0 +1,16 @@
(binary ONNX protobuf: graph gathernd_test containing a single GatherND node, inputs data and indices, output y)