Skip to content

Commit

Permalink
Changes to scatter_nd ops
Browse files Browse the repository at this point in the history
* Rewrite CPU impl to be single-threaded and use vectorization; this avoids race conditions and removes use of the generator.
* Remove scatter_nd_mul and scatter_nd_div to reduce binary size until
  we figure out a better way to reduce the templating pain
* Modify scatter_nd to add for repeated indices as opposed to update
  (this is the appropriate gradient for gather_nd, for example)
* Clean up docstrings.
Change: 138452341
  • Loading branch information
ebrevdo authored and tensorflower-gardener committed Nov 9, 2016
1 parent aac685b commit fd05b5e
Show file tree
Hide file tree
Showing 12 changed files with 412 additions and 651 deletions.
154 changes: 61 additions & 93 deletions tensorflow/core/kernels/scatter_nd_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,43 +146,48 @@ class ScatterNdOp : public OpKernel {
&num_updates, &slice_size);
if (!c->status().ok()) return;

Tensor scratch;
OP_REQUIRES_OK(c, c->allocate_temp(DT_INT32, TensorShape(), &scratch));

auto scratch_scalar = scratch.scalar<Index>();
auto indices_flat = indices.flat_inner_dims<Index>();
auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size});

Tensor* out = nullptr;
OP_REQUIRES_OK(c, c->allocate_output(0, shape, &out));
functor::SetZeroFunctor<Device, T> fill;
fill(c->eigen_device<Device>(), out->flat<T>());
auto output_matrix = out->template shaped<T, 2>(
{shape.num_elements() / slice_size, slice_size});

Index bad_i = -1;
switch (indices_nd) {
#define PARAMS_CASE(IXDIM) \
case IXDIM: { \
Tensor* out = nullptr; \
OP_REQUIRES_OK(c, c->allocate_output(0, shape, &out)); \
functor::SetZeroFunctor<Device, T> fill; \
fill(c->eigen_device<Device>(), out->flat<T>()); \
if (shape.num_elements() > 0) { \
auto output_flat = out->flat_outer_dims<T, (IXDIM) + 1>(); \
functor::ScatterNdFunctor<Device, T, Index, \
scatter_nd_op::UpdateOp::ASSIGN, (IXDIM)> \
functor; \
bad_i = functor(c->eigen_device<Device>(), slice_size, scratch_scalar, \
output_flat, indices_flat, updates_flat, output_flat); \
} \

if (shape.num_elements() > 0) {
switch (indices_nd) {
#define PARAMS_CASE(IXDIM) \
case IXDIM: { \
typename Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix; \
for (int i = 0; i < IXDIM; ++i) { \
output_shape_prefix[i] = shape.dim_size(i); \
} \
functor::ScatterNdFunctor<Device, T, Index, scatter_nd_op::UpdateOp::ADD, \
IXDIM> \
functor; \
bad_i = \
functor(c->eigen_device<Device>(), slice_size, output_shape_prefix, \
output_matrix, indices_flat, updates_flat, output_matrix); \
} break
PARAMS_CASE(0);
PARAMS_CASE(1);
PARAMS_CASE(2);
PARAMS_CASE(3);
PARAMS_CASE(4);
PARAMS_CASE(5);
// TODO(simister): Re-enable this once binary size is under control.
// PARAMS_CASE(0);
PARAMS_CASE(1);
PARAMS_CASE(2);
PARAMS_CASE(3);
PARAMS_CASE(4);
PARAMS_CASE(5);
#undef PARAMS_CASE
default:
OP_REQUIRES(c, false,
errors::InvalidArgument(
"Only indices.shape[-1] values between 0 and 5 "
"are currently supported. Requested rank: ",
indices_nd));
default:
OP_REQUIRES(c, false,
errors::InvalidArgument(
"Only indices.shape[-1] values between 1 and 5 "
"are currently supported. Requested rank: ",
indices_nd));
}
}
OP_REQUIRES(
c, bad_i < 0,
Expand Down Expand Up @@ -236,24 +241,27 @@ class ScatterNdUpdateOp : public OpKernel {
&indices_nd, &num_updates, &slice_size);
if (!c->status().ok()) return;

Tensor scratch;
OP_REQUIRES_OK(c, c->allocate_temp(DT_INT32, TensorShape(), &scratch));

auto scratch_scalar = scratch.scalar<Index>();
auto indices_flat = indices.flat_inner_dims<Index>();
auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size});

auto params_matrix = params.template shaped<T, 2>(
{params_shape.num_elements() / slice_size, slice_size});
Index bad_i = -1;
c->forward_ref_input_to_ref_output(0, 0);

switch (indices_nd) {
#define PARAMS_CASE(IXDIM) \
case IXDIM: { \
auto params_flat = params.flat_outer_dims<T, (IXDIM) + 1>(); \
functor::ScatterNdFunctor<Device, T, Index, op, IXDIM> functor; \
bad_i = functor(c->eigen_device<Device>(), slice_size, scratch_scalar, \
params_flat, indices_flat, updates_flat, params_flat); \
#define PARAMS_CASE(IXDIM) \
case IXDIM: { \
typename Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix; \
for (int i = 0; i < IXDIM; ++i) { \
output_shape_prefix[i] = params_shape.dim_size(i); \
} \
functor::ScatterNdFunctor<Device, T, Index, op, IXDIM> functor; \
bad_i = \
functor(c->eigen_device<Device>(), slice_size, output_shape_prefix, \
params_matrix, indices_flat, updates_flat, params_matrix); \
} break
PARAMS_CASE(0);
// TODO(simister): Re-enable this once binary size is under control.
// PARAMS_CASE(0);
PARAMS_CASE(1);
PARAMS_CASE(2);
PARAMS_CASE(3);
Expand Down Expand Up @@ -306,11 +314,13 @@ class ScatterNdUpdateOp : public OpKernel {
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdAdd", \
scatter_nd_op::UpdateOp::ADD); \
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdSub", \
scatter_nd_op::UpdateOp::SUB); \
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdMul", \
scatter_nd_op::UpdateOp::MUL); \
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdDiv", \
scatter_nd_op::UpdateOp::DIV);
scatter_nd_op::UpdateOp::SUB);
// TODO(simister): Find a way to reduce amount of templated generated code
// to reduce build size, then re-enable these additional operations.
// REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdMul", \
// scatter_nd_op::UpdateOp::MUL); \
// REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdDiv", \
// scatter_nd_op::UpdateOp::DIV);

#define REGISTER_SCATTER_ND(type, dev) \
REGISTER_SCATTER_ND_KERNEL(type, dev, "ScatterNd");
Expand All @@ -329,8 +339,9 @@ class ScatterNdUpdateOp : public OpKernel {
#define REGISTER_SCATTER_ND_CPU(type) REGISTER_SCATTER_ND(type, CPU);

TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_ADD_SUB_CPU);
TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU);
TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_CPU);
// TODO(simister): Re-enable all types after binary size is under control.
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_CPU);

// Registers GPU kernels.
#if GOOGLE_CUDA
Expand All @@ -356,47 +367,4 @@ TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_CPU);
#undef REGISTER_SCATTER_ND_KERNEL
#undef REGISTER_SCATTER_ND_KERNEL_INDEX

#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {

#define DECLARE_GPU_SPECS_OP(T, Index, op, NDIM) \
template <> \
Index ScatterNdFunctor<GPUDevice, T, Index, op, NDIM>::operator()( \
OpKernelContext* c, const GPUDevice& d, \
typename TTypes<T, IXDIM>::Tensor params, \
typename TTypes<Index, 2>::ConstTensor indices, \
typename TTypes<T, 2>::ConstTensor updates); \
extern template struct ScatterNdFunctor<GPUDevice, T, Index, op>;

#define DECLARE_GPU_SPECS_OPS(T, Index, op) \
DECLARE_GPU_SPECS_OP(T, Index, op, 0); \
DECLARE_GPU_SPECS_OP(T, Index, op, 1); \
DECLARE_GPU_SPECS_OP(T, Index, op, 2); \
DECLARE_GPU_SPECS_OP(T, Index, op, 3); \
DECLARE_GPU_SPECS_OP(T, Index, op, 4); \
DECLARE_GPU_SPECS_OP(T, Index, op, 5)

#define DECLARE_GPU_SPECS_INDEX(T, Index) \
DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::ASSIGN); \
DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::ADD); \
DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::SUB); \
DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::MUL); \
DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::DIV);

#define DECLARE_GPU_SPECS(T) \
DECLARE_GPU_SPECS_INDEX(T, int32); \
DECLARE_GPU_SPECS_INDEX(T, int64);

// TODO(simister): Re-enable when GPU support is working.
// TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_GPU_SPECS);

#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_INDEX
#undef DECLARE_GPU_SPECS_OPS
#undef DECLARE_GPU_SPECS_OP

} // namespace functor
#endif // GOOGLE_CUDA

} // namespace tensorflow
13 changes: 7 additions & 6 deletions tensorflow/core/kernels/scatter_nd_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,13 @@ template <typename Device, typename T, typename Index,
scatter_nd_op::UpdateOp op, int IXDIM>
struct ScatterNdFunctor {
// Returns -1 on success or a nonnegative i s.t. indices[i] is a bad index.
Index operator()(const Device& d, const Index slice_size,
typename TTypes<Index>::Scalar Tscratch,
typename TTypes<T, IXDIM + 1>::Tensor Tparams,
typename TTypes<Index, 2>::ConstTensor Tindices,
typename TTypes<T, 2>::ConstTensor Tupdates,
typename TTypes<T, IXDIM + 1>::Tensor Toutput);
Index operator()(
const Device& d, const Index slice_size,
const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix,
typename TTypes<T, 2>::Tensor Tparams,
typename TTypes<Index, 2>::ConstTensor Tindices,
typename TTypes<T, 2>::ConstTensor Tupdates,
typename TTypes<T, 2>::Tensor Toutput);
};

} // namespace functor
Expand Down
Loading

0 comments on commit fd05b5e

Please sign in to comment.