From fd05b5ebc56316eb6ac9fcb74234979fee2fc5f9 Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Mon, 7 Nov 2016 16:01:57 -0800 Subject: [PATCH] Changes to scatter_nd ops * Rewrite CPU impl to be single-threaded and use vectorization; avoids race conditions. Removes use of the generator. * Remove scatter_nd_mul and scatter_nd_div to reduce binary size until we figure out a better way to reduce the templating pain * Modify scatter_nd to add for repeated indices as opposed to update (this is the appropriate gradient for gather_nd, for example) * Clean up docstrings. Change: 138452341 --- tensorflow/core/kernels/scatter_nd_op.cc | 154 ++++------ tensorflow/core/kernels/scatter_nd_op.h | 13 +- .../core/kernels/scatter_nd_op_cpu_impl.h | 201 ++++++------- .../core/kernels/scatter_nd_op_cpu_impl_0.cc | 7 +- tensorflow/core/kernels/scatter_nd_op_test.cc | 52 ++-- tensorflow/core/ops/array_ops.cc | 30 +- .../core/ops/compat/ops_history.v0.pbtxt | 120 -------- tensorflow/core/ops/ops.pbtxt | 134 --------- tensorflow/core/ops/state_ops.cc | 280 ++++++++++-------- .../kernel_tests/scatter_nd_ops_test.py | 64 ++-- tensorflow/python/ops/standard_ops.py | 6 +- tensorflow/python/ops/state_ops.py | 2 - 12 files changed, 412 insertions(+), 651 deletions(-) diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc index 83b38d73381855..5aeb3d2c0ea774 100644 --- a/tensorflow/core/kernels/scatter_nd_op.cc +++ b/tensorflow/core/kernels/scatter_nd_op.cc @@ -146,43 +146,48 @@ class ScatterNdOp : public OpKernel { &num_updates, &slice_size); if (!c->status().ok()) return; - Tensor scratch; - OP_REQUIRES_OK(c, c->allocate_temp(DT_INT32, TensorShape(), &scratch)); - - auto scratch_scalar = scratch.scalar(); auto indices_flat = indices.flat_inner_dims(); auto updates_flat = updates.shaped({num_updates, slice_size}); + Tensor* out = nullptr; + OP_REQUIRES_OK(c, c->allocate_output(0, shape, &out)); + functor::SetZeroFunctor fill; + fill(c->eigen_device(), out->flat()); + auto output_matrix = out->template shaped( + {shape.num_elements() / slice_size, slice_size}); + Index bad_i = -1; - switch (indices_nd) { -#define PARAMS_CASE(IXDIM) \ - case IXDIM: { \ - Tensor* out = nullptr; \ - OP_REQUIRES_OK(c, c->allocate_output(0, shape, &out)); \ - functor::SetZeroFunctor fill; \ - fill(c->eigen_device(), out->flat()); \ - if (shape.num_elements() > 0) { \ - auto output_flat = out->flat_outer_dims(); \ - functor::ScatterNdFunctor \ - functor; \ - bad_i = functor(c->eigen_device(), slice_size, scratch_scalar, \ - output_flat, indices_flat, updates_flat, output_flat); \ - } \ + + if (shape.num_elements() > 0) { + switch (indices_nd) { +#define PARAMS_CASE(IXDIM) \ + case IXDIM: { \ + typename Eigen::array output_shape_prefix; \ + for (int i = 0; i < IXDIM; ++i) { \ + output_shape_prefix[i] = shape.dim_size(i); \ + } \ + functor::ScatterNdFunctor \ + functor; \ + bad_i = \ + functor(c->eigen_device(), slice_size, output_shape_prefix, \ + output_matrix, indices_flat, updates_flat, output_matrix); \ } break - PARAMS_CASE(0); - PARAMS_CASE(1); - PARAMS_CASE(2); - PARAMS_CASE(3); - PARAMS_CASE(4); - PARAMS_CASE(5); + // TODO(simister): Re-enable this once binary size is under control. + // PARAMS_CASE(0); + PARAMS_CASE(1); + PARAMS_CASE(2); + PARAMS_CASE(3); + PARAMS_CASE(4); + PARAMS_CASE(5); #undef PARAMS_CASE - default: - OP_REQUIRES(c, false, - errors::InvalidArgument( - "Only indices.shape[-1] values between 0 and 5 " - "are currently supported. Requested rank: ", - indices_nd)); + default: + OP_REQUIRES(c, false, + errors::InvalidArgument( + "Only indices.shape[-1] values between 1 and 5 " + "are currently supported. Requested rank: ", + indices_nd)); + } } OP_REQUIRES( c, bad_i < 0, @@ -236,24 +241,27 @@ class ScatterNdUpdateOp : public OpKernel { &indices_nd, &num_updates, &slice_size); if (!c->status().ok()) return; - Tensor scratch; - OP_REQUIRES_OK(c, c->allocate_temp(DT_INT32, TensorShape(), &scratch)); - - auto scratch_scalar = scratch.scalar(); auto indices_flat = indices.flat_inner_dims(); auto updates_flat = updates.shaped({num_updates, slice_size}); - + auto params_matrix = params.template shaped( + {params_shape.num_elements() / slice_size, slice_size}); Index bad_i = -1; c->forward_ref_input_to_ref_output(0, 0); + switch (indices_nd) { -#define PARAMS_CASE(IXDIM) \ - case IXDIM: { \ - auto params_flat = params.flat_outer_dims(); \ - functor::ScatterNdFunctor functor; \ - bad_i = functor(c->eigen_device(), slice_size, scratch_scalar, \ - params_flat, indices_flat, updates_flat, params_flat); \ +#define PARAMS_CASE(IXDIM) \ + case IXDIM: { \ + typename Eigen::array output_shape_prefix; \ + for (int i = 0; i < IXDIM; ++i) { \ + output_shape_prefix[i] = params_shape.dim_size(i); \ + } \ + functor::ScatterNdFunctor functor; \ + bad_i = \ + functor(c->eigen_device(), slice_size, output_shape_prefix, \ + params_matrix, indices_flat, updates_flat, params_matrix); \ } break - PARAMS_CASE(0); + // TODO(simister): Re-enable this once binary size is under control. + // PARAMS_CASE(0); PARAMS_CASE(1); PARAMS_CASE(2); PARAMS_CASE(3); @@ -306,11 +314,13 @@ class ScatterNdUpdateOp : public OpKernel { REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdAdd", \ scatter_nd_op::UpdateOp::ADD); \ REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdSub", \ - scatter_nd_op::UpdateOp::SUB); \ - REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdMul", \ - scatter_nd_op::UpdateOp::MUL); \ - REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdDiv", \ - scatter_nd_op::UpdateOp::DIV); + scatter_nd_op::UpdateOp::SUB); +// TODO(simister): Find a way to reduce amount of templated generated code +// to reduce build size, then re-enable these additional operations. +// REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdMul", \ +// scatter_nd_op::UpdateOp::MUL); \ +// REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdDiv", \ +// scatter_nd_op::UpdateOp::DIV); #define REGISTER_SCATTER_ND(type, dev) \ REGISTER_SCATTER_ND_KERNEL(type, dev, "ScatterNd"); @@ -329,8 +339,9 @@ class ScatterNdUpdateOp : public OpKernel { #define REGISTER_SCATTER_ND_CPU(type) REGISTER_SCATTER_ND(type, CPU); TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_ADD_SUB_CPU); -TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU); -TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_CPU); +// TODO(simister): Re-enable all types after binary size is under control. +TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU); +TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_CPU); // Registers GPU kernels. #if GOOGLE_CUDA @@ -356,47 +367,4 @@ TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_CPU); #undef REGISTER_SCATTER_ND_KERNEL #undef REGISTER_SCATTER_ND_KERNEL_INDEX -#if GOOGLE_CUDA -// Forward declarations of the functor specializations for GPU. -namespace functor { - -#define DECLARE_GPU_SPECS_OP(T, Index, op, NDIM) \ - template <> \ - Index ScatterNdFunctor::operator()( \ - OpKernelContext* c, const GPUDevice& d, \ - typename TTypes::Tensor params, \ - typename TTypes::ConstTensor indices, \ - typename TTypes::ConstTensor updates); \ - extern template struct ScatterNdFunctor; - -#define DECLARE_GPU_SPECS_OPS(T, Index, op) \ - DECLARE_GPU_SPECS_OP(T, Index, op, 0); \ - DECLARE_GPU_SPECS_OP(T, Index, op, 1); \ - DECLARE_GPU_SPECS_OP(T, Index, op, 2); \ - DECLARE_GPU_SPECS_OP(T, Index, op, 3); \ - DECLARE_GPU_SPECS_OP(T, Index, op, 4); \ - DECLARE_GPU_SPECS_OP(T, Index, op, 5) - -#define DECLARE_GPU_SPECS_INDEX(T, Index) \ - DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::ASSIGN); \ - DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::ADD); \ - DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::SUB); \ - DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::MUL); \ - DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::DIV); - -#define DECLARE_GPU_SPECS(T) \ - DECLARE_GPU_SPECS_INDEX(T, int32); \ - DECLARE_GPU_SPECS_INDEX(T, int64); - -// TODO(simister): Re-enable when GPU support is working. -// TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_GPU_SPECS); - -#undef DECLARE_GPU_SPECS -#undef DECLARE_GPU_SPECS_INDEX -#undef DECLARE_GPU_SPECS_OPS -#undef DECLARE_GPU_SPECS_OP - -} // namespace functor -#endif // GOOGLE_CUDA - } // namespace tensorflow diff --git a/tensorflow/core/kernels/scatter_nd_op.h b/tensorflow/core/kernels/scatter_nd_op.h index 51917b5a0de0d5..10ee94c0bba9e7 100644 --- a/tensorflow/core/kernels/scatter_nd_op.h +++ b/tensorflow/core/kernels/scatter_nd_op.h @@ -48,12 +48,13 @@ template struct ScatterNdFunctor { // Returns -1 on success or a nonnegative i s.t. indices[i] is a bad index. - Index operator()(const Device& d, const Index slice_size, - typename TTypes::Scalar Tscratch, - typename TTypes::Tensor Tparams, - typename TTypes::ConstTensor Tindices, - typename TTypes::ConstTensor Tupdates, - typename TTypes::Tensor Toutput); + Index operator()( + const Device& d, const Index slice_size, + const Eigen::array output_shape_prefix, + typename TTypes::Tensor Tparams, + typename TTypes::ConstTensor Tindices, + typename TTypes::ConstTensor Tupdates, + typename TTypes::Tensor Toutput); }; } // namespace functor diff --git a/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h b/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h index d2a7746c35e011..442721d37ba2a0 100644 --- a/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h +++ b/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h @@ -42,147 +42,113 @@ typedef Eigen::ThreadPoolDevice CPUDevice; class OpKernelContext; // Specialization of UpdateExecutor to CPU -namespace generator { +namespace update_executor { -template +template class UpdateExecutor { public: - static void Update(T* input, const T* updates, T* output, Index slice_size); + EIGEN_STRONG_INLINE static void Execute(Input value, Update update, + Output output); }; -template -class UpdateExecutor { +template +class UpdateExecutor { public: - static void Update(T* /* unused */, const T* updates, T* output, - Index slice_size) { - std::copy_n(updates, slice_size, output); + EIGEN_STRONG_INLINE static void Execute(Input /* input */, Update update, + Output output) { + output = update; } }; -template -class UpdateExecutor { +template +class UpdateExecutor { public: - static void Update(T* input, const T* updates, T* output, Index slice_size) { - std::transform(input, input + slice_size, updates, output, std::plus()); + EIGEN_STRONG_INLINE static void Execute(Input /* input */, Update update, + Output output) { + output += update; } }; -template -class UpdateExecutor { +template +class UpdateExecutor { public: - static void Update(T* input, const T* updates, T* output, Index slice_size) { - std::transform(input, input + slice_size, updates, output, std::minus()); + EIGEN_STRONG_INLINE static void Execute(Input /* input */, Update update, + Output output) { + output -= update; } }; -template -class UpdateExecutor { +template +class UpdateExecutor { public: - static void Update(T* input, const T* updates, T* output, Index slice_size) { - std::transform(input, input + slice_size, updates, output, - std::multiplies()); + EIGEN_STRONG_INLINE static void Execute(Input input, Update update, + Output output) { + output = input * update; } }; -template -class UpdateExecutor { +template +class UpdateExecutor { public: - static void Update(T* input, const T* updates, T* output, Index slice_size) { - std::transform(input, input + slice_size, updates, output, - std::divides()); + EIGEN_STRONG_INLINE static void Execute(Input input, Update update, + Output output) { + output = input / update; } }; -template -class ScatterNdSliceGenerator { - public: - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ScatterNdSliceGenerator( - const Index slice_size, typename TTypes::Tensor Tparams, - typename TTypes::ConstTensor Tindices, - typename TTypes::ConstTensor Tupdates, - typename TTypes::Tensor Toutput, - std::atomic* error_loc) - : slice_size_(slice_size), - Tparams_(Tparams), - Tindices_(Tindices), - Tupdates_(Tupdates), - Toutput_(Toutput), - error_loc_(error_loc) {} - - EIGEN_DEVICE_FUNC bool GenerateIndices( - const Index loc, Eigen::array* ix) const { - (*ix)[IXDIM] = 0; - bool out_of_bounds = false; - for (int i = 0; i < IXDIM; ++i) { - const Index ix_i = internal::SubtleMustCopy(Tindices_(loc, i)); - (*ix)[i] = ix_i; - out_of_bounds |= !FastBoundsCheck(ix_i, Tparams_.dimension(i)); - } - return out_of_bounds; - } - - EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int32 - operator()(const Eigen::array& loc_array) const { - auto loc = loc_array[0]; - Eigen::array ix_params; - Eigen::array ix_updates; - ix_updates[0] = loc; - ix_updates[1] = 0; - const bool out_of_bounds = GenerateIndices(loc, &ix_params); - if (TF_PREDICT_FALSE(out_of_bounds)) { - error_loc_->store(loc); - } else { - UpdateExecutor::Update(&Tparams_(ix_params), - &Tupdates_(ix_updates), - &Toutput_(ix_params), slice_size_); - } - return static_cast(0); // Return something... - } - - protected: - const Index slice_size_; - mutable typename TTypes::Tensor Tparams_; - const typename TTypes::ConstTensor Tindices_; - const typename TTypes::ConstTensor Tupdates_; - mutable typename TTypes::Tensor Toutput_; - std::atomic* error_loc_; -}; - -} // namespace generator +} // namespace update_executor namespace functor { // Implementation of update functor for CPU. -template -struct ScatterNdFunctor { - Index operator()(const CPUDevice& d, const Index slice_size, - typename TTypes::Scalar Tscratch, - typename TTypes::Tensor Tparams, - typename TTypes::ConstTensor Tindices, - typename TTypes::ConstTensor Tupdates, - typename TTypes::Tensor Toutput) { - std::atomic error_loc(-1); +template +struct ScatterNdFunctor { + Index operator()( + const CPUDevice& d, const Index slice_size, + const Eigen::array output_shape_prefix, + typename TTypes::Tensor Tparams, + typename TTypes::ConstTensor Tindices, + typename TTypes::ConstTensor Tupdates, + typename TTypes::Tensor Toutput) { + // error_loc is -1 if there's no out-of-bounds index, + // otherwise it is the location of an OOB index in Tindices. + Index error_loc = -1; const Eigen::DenseIndex batch_size = Tindices.dimension(0); -#if !defined(EIGEN_HAS_INDEX_LIST) - Eigen::Tensor::Dimensions reshape_dims{{ 1 }}; - Eigen::array broadcast_dims{{ batch_size }}; -#else - Eigen::IndexList > reshape_dims; - Eigen::IndexList broadcast_dims; - broadcast_dims.set(0, batch_size); -#endif - - generator::ScatterNdSliceGenerator generator( - slice_size, Tparams, Tindices, Tupdates, Toutput, &error_loc); - Tscratch.device(d) = Tscratch.reshape(reshape_dims) - .broadcast(broadcast_dims) - .generate(generator) - .sum(); - - // error_loc() returns -1 if there's no out-of-bounds index, - // otherwise it returns the location of an OOB index in Tindices. - return error_loc.load(); + + Index batch_strides[IXDIM]; + for (int dim = IXDIM - 1; dim >= 0; --dim) { + if (dim == IXDIM - 1) { + batch_strides[dim] = 1; + } else { + batch_strides[dim] = + batch_strides[dim + 1] * output_shape_prefix[dim + 1]; + } + } + + for (Eigen::DenseIndex loc = 0; loc < batch_size; ++loc) { + Index i = 0; + bool out_of_bounds = false; + for (int dim = 0; dim < IXDIM; ++dim) { + const Index ix_d = internal::SubtleMustCopy(Tindices(loc, dim)); + out_of_bounds |= !FastBoundsCheck(ix_d, output_shape_prefix[dim]); + i += ix_d * batch_strides[dim]; + } + if (TF_PREDICT_FALSE(out_of_bounds)) { + error_loc = loc; + break; + } else { + auto input_chip = Toutput.template chip<0>(i); + auto output_chip = input_chip.device(d); + auto update_chip = Tupdates.template chip<0>(loc); + update_executor::UpdateExecutor< + decltype(input_chip), decltype(update_chip), decltype(output_chip), + OP>::Execute(input_chip, update_chip, output_chip); + } + } + + return error_loc; } }; @@ -190,11 +156,12 @@ struct ScatterNdFunctor { template Index \ ScatterNdFunctor::operator()( \ const CPUDevice& d, const Index slice_size, \ - typename TTypes::Scalar Tscratch, \ - typename TTypes::Tensor Tparams, \ + const Eigen::array \ + output_shape_prefix, \ + typename TTypes::Tensor Tparams, \ typename TTypes::ConstTensor Tindices, \ typename TTypes::ConstTensor Tupdates, \ - typename TTypes::Tensor Toutput) + typename TTypes::Tensor Toutput) #define REGISTER_SCATTER_ND_INDEX(type, op) \ REGISTER_SCATTER_ND_FULL(type, int32, op); \ @@ -205,9 +172,11 @@ struct ScatterNdFunctor { #define REGISTER_SCATTER_ND_MATH(type) \ REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ADD); \ - REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::SUB); \ - REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::MUL); \ - REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::DIV); + REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::SUB); +// TODO(simister): Re-enable after identifying a way to reduce the binary size +// due to too many template instantiations. +// REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::MUL); \ +// REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::DIV); TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_UPDATE); TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_MATH) diff --git a/tensorflow/core/kernels/scatter_nd_op_cpu_impl_0.cc b/tensorflow/core/kernels/scatter_nd_op_cpu_impl_0.cc index e978c5c348ae1b..04574ccf1bc442 100644 --- a/tensorflow/core/kernels/scatter_nd_op_cpu_impl_0.cc +++ b/tensorflow/core/kernels/scatter_nd_op_cpu_impl_0.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#define CPU_PROVIDED_IXDIM 0 -#include "tensorflow/core/kernels/scatter_nd_op_cpu_impl.h" -#undef CPU_PROVIDED_IXDIM +// TODO(simister): Re-enable once binary size is under control. +// #define CPU_PROVIDED_IXDIM 0 +// #include "tensorflow/core/kernels/scatter_nd_op_cpu_impl.h" +// #undef CPU_PROVIDED_IXDIM diff --git a/tensorflow/core/kernels/scatter_nd_op_test.cc b/tensorflow/core/kernels/scatter_nd_op_test.cc index d6743a68674917..d69909b6ded9a4 100644 --- a/tensorflow/core/kernels/scatter_nd_op_test.cc +++ b/tensorflow/core/kernels/scatter_nd_op_test.cc @@ -48,31 +48,32 @@ class ScatterNdUpdateOpTest : public OpsTestBase { } }; -TEST_F(ScatterNdUpdateOpTest, Simple_StringType) { - MakeOp(DT_STRING_REF, DT_INT32); - AddInputFromArray(TensorShape({1}), {"Brain"}); - AddInputFromArray(TensorShape({1}), {0}); - AddInputFromArray(TensorShape({1}), {"TensorFlow"}); - TF_ASSERT_OK(RunOpKernel()); - // Check the new state of the input - Tensor params_tensor = *mutable_input(0).tensor; - Tensor expected(allocator(), DT_STRING, TensorShape({1})); - test::FillValues(&expected, {"TensorFlow"}); - test::ExpectTensorEqual(expected, params_tensor); -} - -TEST_F(ScatterNdUpdateOpTest, Simple_BoolType) { - MakeOp(DT_BOOL_REF, DT_INT32); - AddInputFromArray(TensorShape({1}), {false}); - AddInputFromArray(TensorShape({1}), {0}); - AddInputFromArray(TensorShape({1}), {true}); - TF_ASSERT_OK(RunOpKernel()); - // Check the new state of the input - Tensor params_tensor = *mutable_input(0).tensor; - Tensor expected(allocator(), DT_BOOL, TensorShape({1})); - test::FillValues(&expected, {true}); - test::ExpectTensorEqual(expected, params_tensor); -} +// TODO(simister): Re-enable this once binary size is under control. +// TEST_F(ScatterNdUpdateOpTest, Simple_StringType) { +// MakeOp(DT_STRING_REF, DT_INT32); +// AddInputFromArray(TensorShape({1}), {"Brain"}); +// AddInputFromArray(TensorShape({1}), {0}); +// AddInputFromArray(TensorShape({1}), {"TensorFlow"}); +// TF_ASSERT_OK(RunOpKernel()); +// // Check the new state of the input +// Tensor params_tensor = *mutable_input(0).tensor; +// Tensor expected(allocator(), DT_STRING, TensorShape({1})); +// test::FillValues(&expected, {"TensorFlow"}); +// test::ExpectTensorEqual(expected, params_tensor); +// } + +// TEST_F(ScatterNdUpdateOpTest, Simple_BoolType) { +// MakeOp(DT_BOOL_REF, DT_INT32); +// AddInputFromArray(TensorShape({1}), {false}); +// AddInputFromArray(TensorShape({1}), {0}); +// AddInputFromArray(TensorShape({1}), {true}); +// TF_ASSERT_OK(RunOpKernel()); +// // Check the new state of the input +// Tensor params_tensor = *mutable_input(0).tensor; +// Tensor expected(allocator(), DT_BOOL, TensorShape({1})); +// test::FillValues(&expected, {true}); +// test::ExpectTensorEqual(expected, params_tensor); +// } TEST_F(ScatterNdUpdateOpTest, Simple_TwoD32) { MakeOp(DT_FLOAT_REF, DT_INT32); @@ -111,6 +112,7 @@ TEST_F(ScatterNdUpdateOpTest, Simple_Two64) { 10002, 0, 0, 0, 777, 778, 779}); test::ExpectTensorEqual(expected, params_tensor); } + /*TEST_F(ScatterNdUpdateOpTest, Simple_ZeroElements) { MakeOp(DT_FLOAT_REF, DT_INT32); diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index ce1f76503c8e62..e84feaedf19a8f 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -4395,12 +4395,16 @@ REGISTER_OP("ScatterNd") .Attr("T: type") .Attr("Tindices: {int32, int64}") .Doc( - R"doc(Creates a new tensor by applying sparse `updates` to individual values or slices within a zero tensor of the given `shape` tensor according to indices. -This operator is the inverse of the [tf.gather_nd](#gather_nd) operator which extracts values or slices from a given tensor. + R"doc(Creates a new tensor by applying sparse `updates` to individual +values or slices within a zero tensor of the given `shape` tensor according to +indices. This operator is the inverse of the [tf.gather_nd](#gather_nd) +operator which extracts values or slices from a given tensor. -TODO(simister): Add a link to Variable.__getitem__ documentation on slice syntax. +TODO(simister): Add a link to Variable.__getitem__ documentation on slice +syntax. -`shape` is a `TensorShape` with rank `P` and `indices` is a `Tensor` of rank `Q`. +`shape` is a `TensorShape` with rank `P` and `indices` is a `Tensor` of rank +`Q`. `indices` must be integer tensor, containing indices into `shape`. It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. @@ -4415,7 +4419,9 @@ dimension of `shape`. [d_0, ..., d_{Q-2}, shape[K], ..., shape[P-1]]. ``` -The simplest form of scatter is to insert individual elements in a tensor by index. For example, say we want to insert 4 scattered elements in a rank-1 tensor with 8 elements. +The simplest form of scatter is to insert individual elements in a tensor by +index. For example, say we want to insert 4 scattered elements in a rank-1 +tensor with 8 elements.
@@ -4434,7 +4440,9 @@ The resulting tensor would look like this: [0, 11, 0, 10, 9, 0, 0, 12] -We can also, insert entire slices of a higher rank tensor all at once. For example, if we wanted to insert two slices in the first dimension of a rank-3 tensor with two matrices of new values. +We can also, insert entire slices of a higher rank tensor all at once. For +example, if we wanted to insert two slices in the first dimension of a +rank-3 tensor with two matrices of new values.
@@ -4459,10 +4467,14 @@ The resulting tensor would look like this: [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]] -indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref. -updates: A Tensor. Must have the same type as tensor. A tensor of updated values to store in ref. +indices: A Tensor. Must be one of the following types: int32, int64. + A tensor of indices into ref. +updates: A Tensor. Must have the same type as tensor. A tensor of updated values + to store in ref. shape: A vector. The shape of the resulting tensor. -output: A new tensor with the given shape and updates applied according to the indices.)doc"); +output: A new tensor with the given shape and updates applied according + to the indices. +)doc"); REGISTER_OP("FakeQuantWithMinMaxArgs") .Attr("min: float = -6.0") diff --git a/tensorflow/core/ops/compat/ops_history.v0.pbtxt b/tensorflow/core/ops/compat/ops_history.v0.pbtxt index dd44b73f40a7a0..f190c60d1bc910 100644 --- a/tensorflow/core/ops/compat/ops_history.v0.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v0.pbtxt @@ -24864,126 +24864,6 @@ op { } } } -op { - name: "ScatterNdDiv" - input_arg { - name: "ref" - type_attr: "T" - is_ref: true - } - input_arg { - name: "indices" - type_attr: "Tindices" - } - input_arg { - name: "updates" - type_attr: "T" - } - output_arg { - name: "output_ref" - type_attr: "T" - is_ref: true - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_FLOAT - type: DT_DOUBLE - type: DT_INT64 - type: DT_INT32 - type: DT_UINT8 - type: DT_UINT16 - type: DT_INT16 - type: DT_INT8 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - type: DT_QINT8 - type: DT_QUINT8 - type: DT_QINT32 - type: DT_HALF - } - } - } - attr { - name: "Tindices" - type: "type" - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - attr { - name: "use_locking" - type: "bool" - default_value { - b: false - } - } -} -op { - name: "ScatterNdMul" - input_arg { - name: "ref" - type_attr: "T" - is_ref: true - } - input_arg { - name: "indices" - type_attr: "Tindices" - } - input_arg { - name: "updates" - type_attr: "T" - } - output_arg { - name: "output_ref" - type_attr: "T" - is_ref: true - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_FLOAT - type: DT_DOUBLE - type: DT_INT64 - type: DT_INT32 - type: DT_UINT8 - type: DT_UINT16 - type: DT_INT16 - type: DT_INT8 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - type: DT_QINT8 - type: DT_QUINT8 - type: DT_QINT32 - type: DT_HALF - } - } - } - attr { - name: "Tindices" - type: "type" - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - attr { - name: "use_locking" - type: "bool" - default_value { - b: false - } - } -} op { name: "ScatterNdSub" input_arg { diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 4e970465a0bd96..8ab182304426ab 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -15579,140 +15579,6 @@ op { summary: "Applies sparse addition between `updates` and individual values or slices within a given variable according to `indices`." description: "`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.\n\n`indices` must be integer tensor, containing indices into `ref`.\nIt must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.\n\nThe innermost dimension of `indices` (with length `K`) corresponds to\nindices into elements (if `K = P`) or slices (if `K < P`) along the `K`th\ndimension of `ref`.\n\n`updates` is `Tensor` of rank `Q-1+P-K` with shape:\n\n```\n[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].\n```\n\nFor example, say we want to add 4 scattered elements to a rank-1 tensor to 8 elements. In Python, that addition would look like this:\n\n ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])\n indices = tf.constant([[4], [3], [1], [7]])\n updates = tf.constant([9, 10, 11, 12])\n add = tf.scatter_nd_add(ref, indices, updates)\n with tf.Session() as sess:\n print sess.run(add)\n\nThe resulting update to ref would look like this:\n\n [1, 13, 3, 14, 14, 6, 7, 20]\n\nSee [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices." } -op { - name: "ScatterNdDiv" - input_arg { - name: "ref" - description: "A mutable Tensor. Should be from a Variable node." - type_attr: "T" - is_ref: true - } - input_arg { - name: "indices" - description: "A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref." - type_attr: "Tindices" - } - input_arg { - name: "updates" - description: "A Tensor. Must have the same type as ref. A tensor of updated values to subtract from ref." - type_attr: "T" - } - output_arg { - name: "output_ref" - description: "Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done." - type_attr: "T" - is_ref: true - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_FLOAT - type: DT_DOUBLE - type: DT_INT64 - type: DT_INT32 - type: DT_UINT8 - type: DT_UINT16 - type: DT_INT16 - type: DT_INT8 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - type: DT_QINT8 - type: DT_QUINT8 - type: DT_QINT32 - type: DT_HALF - } - } - } - attr { - name: "Tindices" - type: "type" - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - attr { - name: "use_locking" - type: "bool" - default_value { - b: false - } - description: "An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention." - } - summary: "Applies sparse subtraction between `updates` and individual values or slices within a given variable according to `indices`." - description: "`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.\n\n`indices` must be integer tensor, containing indices into `ref`.\nIt must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.\n\nThe innermost dimension of `indices` (with length `K`) corresponds to\nindices into elements (if `K = P`) or slices (if `K < P`) along the `K`th\ndimension of `ref`.\n\n`updates` is `Tensor` of rank `Q-1+P-K` with shape:\n\n```\n[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].\n```\n\nFor example, say we want to divide a rank-1 tensor with 8 elements by 4 scattered elements. In Python, that division would look like this:\n\n ref = tf.Variable([10, 20, 30, 40, 50, 60, 70, 80])\n indices = tf.constant([[4], [3], [1], [7]])\n updates = tf.constant([2, 3, 4, 5])\n sub = tf.scatter_nd_div(ref, indices, updates)\n with tf.Session() as sess:\n print sess.run(sub)\n\nThe resulting update to ref would look like this:\n\n [10, 5, 30, 13, 25, 60, 70, 16]\n\nSee [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices." -} -op { - name: "ScatterNdMul" - input_arg { - name: "ref" - description: "A mutable Tensor. Should be from a Variable node." - type_attr: "T" - is_ref: true - } - input_arg { - name: "indices" - description: "A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref." - type_attr: "Tindices" - } - input_arg { - name: "updates" - description: "A Tensor. Must have the same type as ref. A tensor of updated values to subtract from ref." - type_attr: "T" - } - output_arg { - name: "output_ref" - description: "Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done." - type_attr: "T" - is_ref: true - } - attr { - name: "T" - type: "type" - allowed_values { - list { - type: DT_FLOAT - type: DT_DOUBLE - type: DT_INT64 - type: DT_INT32 - type: DT_UINT8 - type: DT_UINT16 - type: DT_INT16 - type: DT_INT8 - type: DT_COMPLEX64 - type: DT_COMPLEX128 - type: DT_QINT8 - type: DT_QUINT8 - type: DT_QINT32 - type: DT_HALF - } - } - } - attr { - name: "Tindices" - type: "type" - allowed_values { - list { - type: DT_INT32 - type: DT_INT64 - } - } - } - attr { - name: "use_locking" - type: "bool" - default_value { - b: false - } - description: "An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention." - } - summary: "Applies sparse subtraction between `updates` and individual values or slices within a given variable according to `indices`." - description: "`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.\n\n`indices` must be integer tensor, containing indices into `ref`.\nIt must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.\n\nThe innermost dimension of `indices` (with length `K`) corresponds to\nindices into elements (if `K = P`) or slices (if `K < P`) along the `K`th\ndimension of `ref`.\n\n`updates` is `Tensor` of rank `Q-1+P-K` with shape:\n\n```\n[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].\n```\n\nFor example, say we want to multiply 4 scattered elements with a rank-1 tensor with 8 elements. In Python, that multiplication would look like this:\n\n ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])\n indices = tf.constant([[4], [3], [1], [7]])\n updates = tf.constant([9, 10, 11, 12])\n sub = tf.scatter_nd_mul(ref, indices, updates)\n with tf.Session() as sess:\n print sess.run(sub)\n\nThe resulting update to ref would look like this:\n\n [1, 22, 3, 40, 45, 6, 7, 96]\n\nSee [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices." -} op { name: "ScatterNdSub" input_arg { diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc index 9339b9b82143e5..f6d21b909a87bc 100644 --- a/tensorflow/core/ops/state_ops.cc +++ b/tensorflow/core/ops/state_ops.cc @@ -453,8 +453,9 @@ REGISTER_OP("ScatterNdUpdate") .Attr("T: type") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = true") - .Doc( - R"doc(Applies sparse `updates` to individual values or slices within a given variable according to `indices`. + .Doc(R"doc( +Applies sparse `updates` to individual values or slices within a given +variable according to `indices`. `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. @@ -471,7 +472,8 @@ dimension of `ref`. [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. ``` -For example, say we want to update 4 scattered elements to a rank-1 tensor to 8 elements. In Python, that update would look like this: +For example, say we want to update 4 scattered elements to a rank-1 tensor to +8 elements. In Python, that update would look like this: ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) indices = tf.constant([[4], [3], [1] ,[7]]) @@ -484,13 +486,20 @@ The resulting update to ref would look like this: [1, 11, 3, 10, 9, 6, 7, 12] -See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices. +See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to +slices. ref: A mutable Tensor. Should be from a Variable node. -indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref. -updates: A Tensor. Must have the same type as ref. A tensor of updated values to add to ref. -use_locking: An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. -output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc"); +indices: A Tensor. Must be one of the following types: int32, int64. + A tensor of indices into ref. +updates: A Tensor. Must have the same type as ref. A tensor of updated + values to add to ref. +use_locking: An optional bool. Defaults to True. If True, the assignment will + be protected by a lock; otherwise the behavior is undefined, + but may exhibit less contention. +output_ref: Same as ref. Returned as a convenience for operations that want to + use the updated values after the update is done. +)doc"); REGISTER_OP("ScatterNdAdd") .Input("ref: Ref(T)") @@ -500,8 +509,9 @@ REGISTER_OP("ScatterNdAdd") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = false") - .Doc( - R"doc(Applies sparse addition between `updates` and individual values or slices within a given variable according to `indices`. + .Doc(R"doc( +Applies sparse addition between `updates` and individual values or slices +within a given variable according to `indices`. `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. @@ -518,7 +528,8 @@ dimension of `ref`. [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. ``` -For example, say we want to add 4 scattered elements to a rank-1 tensor to 8 elements. In Python, that addition would look like this: +For example, say we want to add 4 scattered elements to a rank-1 tensor to 8 +elements. In Python, that addition would look like this: ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) indices = tf.constant([[4], [3], [1], [7]]) @@ -531,13 +542,20 @@ The resulting update to ref would look like this: [1, 13, 3, 14, 14, 6, 7, 20] -See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices. +See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to +slices. ref: A mutable Tensor. Should be from a Variable node. -indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref. -updates: A Tensor. Must have the same type as ref. A tensor of updated values to add to ref. -use_locking: An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. -output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc"); +indices: A Tensor. Must be one of the following types: int32, int64. + A tensor of indices into ref. +updates: A Tensor. Must have the same type as ref. A tensor of updated values + to add to ref. +use_locking: An optional bool. Defaults to True. If True, the assignment will + be protected by a lock; otherwise the behavior is undefined, + but may exhibit less contention. +output_ref: Same as ref. Returned as a convenience for operations that want + to use the updated values after the update is done. +)doc"); REGISTER_OP("ScatterNdSub") .Input("ref: Ref(T)") @@ -547,8 +565,9 @@ REGISTER_OP("ScatterNdSub") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") .Attr("use_locking: bool = false") - .Doc( - R"doc(Applies sparse subtraction between `updates` and individual values or slices within a given variable according to `indices`. + .Doc(R"doc( +Applies sparse subtraction between `updates` and individual values or slices +within a given variable according to `indices`. `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. @@ -565,7 +584,8 @@ dimension of `ref`. [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. ``` -For example, say we want to subtract 4 scattered elements from a rank-1 tensor with 8 elements. In Python, that subtraction would look like this: +For example, say we want to subtract 4 scattered elements from a rank-1 tensor +with 8 elements. In Python, that subtraction would look like this: ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) indices = tf.constant([[4], [3], [1], [7]]) @@ -578,107 +598,133 @@ The resulting update to ref would look like this: [1, -9, 3, -6, -4, 6, 7, -4] -See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices. +See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to +slices. ref: A mutable Tensor. Should be from a Variable node. -indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref. -updates: A Tensor. Must have the same type as ref. A tensor of updated values to subtract from ref. -use_locking: An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. -output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc"); - -REGISTER_OP("ScatterNdMul") - .Input("ref: Ref(T)") - .Input("indices: Tindices") - .Input("updates: T") - .Output("output_ref: Ref(T)") - .Attr("T: numbertype") - .Attr("Tindices: {int32, int64}") - .Attr("use_locking: bool = false") - .Doc( - R"doc(Applies sparse subtraction between `updates` and individual values or slices within a given variable according to `indices`. - -`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. - -`indices` must be integer tensor, containing indices into `ref`. -It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. - -The innermost dimension of `indices` (with length `K`) corresponds to -indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th -dimension of `ref`. - -`updates` is `Tensor` of rank `Q-1+P-K` with shape: - -``` -[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. -``` - -For example, say we want to multiply 4 scattered elements with a rank-1 tensor with 8 elements. In Python, that multiplication would look like this: - - ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) - indices = tf.constant([[4], [3], [1], [7]]) - updates = tf.constant([9, 10, 11, 12]) - sub = tf.scatter_nd_mul(ref, indices, updates) - with tf.Session() as sess: - print sess.run(sub) - -The resulting update to ref would look like this: - - [1, 22, 3, 40, 45, 6, 7, 96] - -See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices. - -ref: A mutable Tensor. Should be from a Variable node. -indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref. -updates: A Tensor. Must have the same type as ref. A tensor of updated values to subtract from ref. -use_locking: An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. -output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc"); - -REGISTER_OP("ScatterNdDiv") - .Input("ref: Ref(T)") - .Input("indices: Tindices") - .Input("updates: T") - .Output("output_ref: Ref(T)") - .Attr("T: numbertype") - .Attr("Tindices: {int32, int64}") - .Attr("use_locking: bool = false") - .Doc( - R"doc(Applies sparse subtraction between `updates` and individual values or slices within a given variable according to `indices`. - -`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. - -`indices` must be integer tensor, containing indices into `ref`. -It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. - -The innermost dimension of `indices` (with length `K`) corresponds to -indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th -dimension of `ref`. - -`updates` is `Tensor` of rank `Q-1+P-K` with shape: - -``` -[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. -``` - -For example, say we want to divide a rank-1 tensor with 8 elements by 4 scattered elements. In Python, that division would look like this: - - ref = tf.Variable([10, 20, 30, 40, 50, 60, 70, 80]) - indices = tf.constant([[4], [3], [1], [7]]) - updates = tf.constant([2, 3, 4, 5]) - sub = tf.scatter_nd_div(ref, indices, updates) - with tf.Session() as sess: - print sess.run(sub) - -The resulting update to ref would look like this: - - [10, 5, 30, 13, 25, 60, 70, 16] - -See [tf.scatter_nd](#scatter_nd) for more details about how to make updates to slices. +indices: A Tensor. Must be one of the following types: int32, int64. + A tensor of indices into ref. +updates: A Tensor. Must have the same type as ref. A tensor of updated values + to subtract from ref. +use_locking: An optional bool. Defaults to True. If True, the assignment will + be protected by a lock; otherwise the behavior is undefined, + but may exhibit less contention. +output_ref: Same as ref. Returned as a convenience for operations that want + to use the updated values after the update is done. +)doc"); -ref: A mutable Tensor. Should be from a Variable node. -indices: A Tensor. Must be one of the following types: int32, int64. A tensor of indices into ref. -updates: A Tensor. Must have the same type as ref. A tensor of updated values to subtract from ref. -use_locking: An optional bool. Defaults to True. If True, the assignment will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. -output_ref: Same as ref. Returned as a convenience for operations that want to use the updated values after the update is done.)doc"); +// TODO(simister): Re-enable once these additional ops do not dramatically +// increase binary size. + +// REGISTER_OP("ScatterNdMul") +// .Input("ref: Ref(T)") +// .Input("indices: Tindices") +// .Input("updates: T") +// .Output("output_ref: Ref(T)") +// .Attr("T: numbertype") +// .Attr("Tindices: {int32, int64}") +// .Attr("use_locking: bool = false") +// .Doc( +// R"doc(Applies sparse subtraction between `updates` and individual +// values or slices within a given variable according to `indices`. + +// `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. + +// `indices` must be integer tensor, containing indices into `ref`. +// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. + +// The innermost dimension of `indices` (with length `K`) corresponds to +// indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th +// dimension of `ref`. + +// `updates` is `Tensor` of rank `Q-1+P-K` with shape: + +// ``` +// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. +// ``` + +// For example, say we want to multiply 4 scattered elements with a rank-1 +// tensor with 8 elements. In Python, that multiplication would look like this: + +// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8]) +// indices = tf.constant([[4], [3], [1], [7]]) +// updates = tf.constant([9, 10, 11, 12]) +// sub = tf.scatter_nd_mul(ref, indices, updates) +// with tf.Session() as sess: +// print sess.run(sub) + +// The resulting update to ref would look like this: + +// [1, 22, 3, 40, 45, 6, 7, 96] + +// See [tf.scatter_nd](#scatter_nd) for more details about how to make updates +// to slices. + +// ref: A mutable Tensor. Should be from a Variable node. +// indices: A Tensor. Must be one of the following types: int32, int64. A tensor +// of indices into ref. +// updates: A Tensor. Must have the same type as ref. A tensor of updated values +// to subtract from ref. +// use_locking: An optional bool. Defaults to True. If True, the assignment will +// be protected by a lock; otherwise the behavior is undefined, but may exhibit +// less contention. +// output_ref: Same as ref. Returned as a convenience for operations that want +// to use the updated values after the update is done.)doc"); + +// REGISTER_OP("ScatterNdDiv") +// .Input("ref: Ref(T)") +// .Input("indices: Tindices") +// .Input("updates: T") +// .Output("output_ref: Ref(T)") +// .Attr("T: numbertype") +// .Attr("Tindices: {int32, int64}") +// .Attr("use_locking: bool = false") +// .Doc( +// R"doc(Applies sparse subtraction between `updates` and individual +// values or slices within a given variable according to `indices`. + +// `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`. + +// `indices` must be integer tensor, containing indices into `ref`. +// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`. + +// The innermost dimension of `indices` (with length `K`) corresponds to +// indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th +// dimension of `ref`. + +// `updates` is `Tensor` of rank `Q-1+P-K` with shape: + +// ``` +// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]. +// ``` + +// For example, say we want to divide a rank-1 tensor with 8 elements by 4 +// scattered elements. In Python, that division would look like this: + +// ref = tf.Variable([10, 20, 30, 40, 50, 60, 70, 80]) +// indices = tf.constant([[4], [3], [1], [7]]) +// updates = tf.constant([2, 3, 4, 5]) +// sub = tf.scatter_nd_div(ref, indices, updates) +// with tf.Session() as sess: +// print sess.run(sub) + +// The resulting update to ref would look like this: + +// [10, 5, 30, 13, 25, 60, 70, 16] + +// See [tf.scatter_nd](#scatter_nd) for more details about how to make updates +// to slices. + +// ref: A mutable Tensor. Should be from a Variable node. +// indices: A Tensor. Must be one of the following types: int32, int64. A tensor +// of indices into ref. +// updates: A Tensor. Must have the same type as ref. A tensor of updated values +// to subtract from ref. +// use_locking: An optional bool. Defaults to True. If True, the assignment will +// be protected by a lock; otherwise the behavior is undefined, but may exhibit +// less contention. +// output_ref: Same as ref. Returned as a convenience for operations that want +// to use the updated values after the update is done.)doc"); REGISTER_OP("CountUpTo") .Input("ref: Ref(T)") diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py index 0461758d27e70d..3d2ac798cd29f2 100644 --- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py +++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py @@ -78,7 +78,7 @@ def _NumpyDiv(ref, indices, updates): return _NumpyScatterNd(ref, indices, updates, lambda p, u: p / u) -class ScatterTest(tf.test.TestCase): +class ScatterNdTest(tf.test.TestCase): def _VariableRankTest(self, np_scatter, @@ -145,11 +145,13 @@ def testVariableRankAdd(self): def testVariableRankSub(self): self._VariableRankTests(_NumpySub, tf.scatter_nd_sub) - def testVariableRankMul(self): - self._VariableRankTests(_NumpyMul, tf.scatter_nd_mul) + # TODO(simister): Re-enable once binary size increase due to + # scatter_nd ops is under control. + # def testVariableRankMul(self): + # self._VariableRankTests(_NumpyMul, tf.scatter_nd_mul) - def testVariableRankDiv(self): - self._VariableRankTests(_NumpyDiv, tf.scatter_nd_div) + # def testVariableRankDiv(self): + # self._VariableRankTests(_NumpyDiv, tf.scatter_nd_div) def _ScatterRepeatIndicesTest(self, np_scatter, tf_scatter): for vtype in (np.float32, np.float64): @@ -167,25 +169,29 @@ def testScatterRepeatIndices(self): """This tests scatter_add using indices that repeat.""" self._ScatterRepeatIndicesTest(_NumpyAdd, tf.scatter_nd_add) self._ScatterRepeatIndicesTest(_NumpySub, tf.scatter_nd_sub) - self._ScatterRepeatIndicesTest(_NumpyMul, tf.scatter_nd_mul) - self._ScatterRepeatIndicesTest(_NumpyDiv, tf.scatter_nd_div) - - def testBooleanScatterUpdate(self): - with self.test_session(use_gpu=False) as session: - var = tf.Variable([True, False]) - update0 = tf.scatter_nd_update(var, [[1]], [True]) - update1 = tf.scatter_nd_update( - var, tf.constant( - [[0]], dtype=tf.int64), [False]) - var.initializer.run() - - session.run([update0, update1]) - - self.assertAllEqual([False, True], var.eval()) + # TODO(simister): Re-enable once binary size increase due to + # extra templating is back under control. + # self._ScatterRepeatIndicesTest(_NumpyMul, tf.scatter_nd_mul) + # self._ScatterRepeatIndicesTest(_NumpyDiv, tf.scatter_nd_div) + + # TODO(simister): Re-enable once binary size increase due to + # extra templating is back under control and this op is re-enabled + # def testBooleanScatterUpdate(self): + # with self.test_session(use_gpu=False) as session: + # var = tf.Variable([True, False]) + # update0 = tf.scatter_nd_update(var, [[1]], [True]) + # update1 = tf.scatter_nd_update( + # var, tf.constant( + # [[0]], dtype=tf.int64), [False]) + # var.initializer.run() + # session.run([update0, update1]) + # self.assertAllEqual([False, True], var.eval()) def testScatterOutOfRangeCpu(self): - for op in (tf.scatter_nd_add, tf.scatter_nd_sub, tf.scatter_nd_mul, - tf.scatter_nd_div, tf.scatter_nd_update): + # TODO(simister): Re-enable once binary size increase due to + # scatter_nd ops is under control. + # tf.scatter_nd_mul, tf.scatter_nd_div, + for op in (tf.scatter_nd_add, tf.scatter_nd_sub, tf.scatter_nd_update): params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32) updates = np.array([-3, -4, -5]).astype(np.float32) with self.test_session(use_gpu=False): @@ -355,8 +361,10 @@ def testConcurrentUpdates(self): def _disabledTestScatterOutOfRangeGpu(self): if not tf.test.IsBuiltWithCuda(): return - for op in (tf.scatter_nd_add, tf.scatter_nd_sub, tf.scatter_nd_mul, - tf.scatter_nd_div, tf.scatter_nd_update): + # TODO(simister): Re-enable once binary size increase due to + # scatter_nd ops is under control. + # tf.scatter_nd_mul, tf.scatter_nd_div, + for op in (tf.scatter_nd_add, tf.scatter_nd_sub, tf.scatter_nd_update): params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32) updates = np.array([-3, -4, -5]).astype(np.float32) # With GPU, the code ignores indices that are out of range. @@ -375,6 +383,14 @@ def _disabledTestScatterOutOfRangeGpu(self): indices = np.array([2, 0, 6]) op(ref, indices, updates).eval() + def testScatterNdRepatedIndicesAdd(self): + indices = tf.zeros([100000, 1], tf.int32) + values = np.random.randn(100000) + shape = [1] + with self.test_session(): + val = tf.scatter_nd(indices, values, shape).eval() + self.assertAllClose([np.sum(values)], val) + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py index 847d1b99c83a2d..12811c54b44748 100644 --- a/tensorflow/python/ops/standard_ops.py +++ b/tensorflow/python/ops/standard_ops.py @@ -69,8 +69,10 @@ from tensorflow.python.ops.state_ops import scatter_update from tensorflow.python.ops.state_ops import scatter_nd_add from tensorflow.python.ops.state_ops import scatter_nd_sub -from tensorflow.python.ops.state_ops import scatter_nd_mul -from tensorflow.python.ops.state_ops import scatter_nd_div +# TODO(simister): Re-enable once binary size increase due to scatter_nd +# ops is under control. +# from tensorflow.python.ops.state_ops import scatter_nd_mul +# from tensorflow.python.ops.state_ops import scatter_nd_div from tensorflow.python.ops.state_ops import scatter_nd_update from tensorflow.python.ops.string_ops import * from tensorflow.python.ops.template import * diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py index 2c12865df06da7..e196bdd3ff9999 100644 --- a/tensorflow/python/ops/state_ops.py +++ b/tensorflow/python/ops/state_ops.py @@ -98,8 +98,6 @@ @@scatter_nd_update @@scatter_nd_add @@scatter_nd_sub -@@scatter_nd_mul -@@scatter_nd_div @@sparse_mask @@IndexedSlices