From fcc3923ab50a98dcbe0f972231b3c4656ebb9228 Mon Sep 17 00:00:00 2001
From: Andrew Selle
Date: Tue, 22 Nov 2016 10:04:37 -0800
Subject: [PATCH] Change division in TensorFlow to flooring semantics.

- tf.div changes to the new behavior, but it will be deprecated
- tf.divide is currently a synonym for tf.div and will remain
- tf.mod changes to the new behavior, but it will be deprecated; use % or
  tf.floormod in the future
- the FloorDiv op is extended to work on reals

Change: 139922734
---
 RELEASE.md                                    |  4 ++
 tensorflow/core/kernels/cwise_op_floor_div.cc |  4 ++
 .../core/kernels/cwise_op_gpu_floor_div.cu.cc |  1 +
 tensorflow/core/kernels/cwise_ops.h           | 21 ++++++++++
 tensorflow/python/ops/math_grad.py            | 40 +++++++++++++++++--
 tensorflow/python/ops/math_ops.py             | 37 +++++++++++++++--
 tensorflow/python/ops/math_ops_test.py        | 21 ++++++++++
 7 files changed, 120 insertions(+), 8 deletions(-)

diff --git a/RELEASE.md b/RELEASE.md
index d618c865f53aea..939eee0f2d5eda 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -2,6 +2,10 @@
 
 ## Breaking Changes to the API
 
+* Division and modulus operators (/, //, %) now match Python (flooring)
+  semantics. tf.div will be deprecated in favor of tf.divide. New operators
+  tf.truncatediv and tf.truncatemod are available for achieving the previous
+  C++ (truncation) division/modulus semantics.
 * `BusAdjacency` enum replaced with a protocol buffer `DeviceLocality`. PCI bus
   indexing now starts from 1 instead of 0, and bus_id==0 is used where
   previously BUS_ANY was used.
diff --git a/tensorflow/core/kernels/cwise_op_floor_div.cc b/tensorflow/core/kernels/cwise_op_floor_div.cc
index 7930d83413eff5..a5767476c3fc96 100644
--- a/tensorflow/core/kernels/cwise_op_floor_div.cc
+++ b/tensorflow/core/kernels/cwise_op_floor_div.cc
@@ -28,9 +28,13 @@ REGISTER5(BinaryOp, CPU, "FloorDiv", functor::safe_floor_div, uint8, uint16,
 TF_CALL_INTEGRAL_TYPES(REGISTER_SYCL_KERNEL);
 #undef REGISTER_SYCL_KERNEL
 #endif  // TENSORFLOW_USE_SYCL
+REGISTER3(BinaryOp, CPU, "FloorDiv", functor::floor_div_real, float,
+          Eigen::half, double);
 #if GOOGLE_CUDA
 REGISTER4(BinaryOp, GPU, "FloorDiv", functor::floor_div, uint8, uint16, int16,
           int64);
+REGISTER3(BinaryOp, GPU, "FloorDiv", functor::floor_div_real, float,
+          Eigen::half, double);
 #endif
 
 #if GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_floor_div.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_floor_div.cu.cc
index 1300bf2232b348..0e4887eafd6a9d 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_floor_div.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_floor_div.cu.cc
@@ -20,6 +20,7 @@ limitations under the License.
 namespace tensorflow {
 namespace functor {
 DEFINE_BINARY5(floor_div, uint8, uint16, int16, int32, int64);
+DEFINE_BINARY3(floor_div_real, Eigen::half, float, double);
 }  // namespace functor
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index 7f35e03feb1987..34103347fb975e 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -269,6 +269,24 @@ struct functor_traits> {
   };
 };
 
+// TODO(b/32239616): This kernel should be moved into Eigen and vectorized.
+template <typename T>
+struct google_floor_div_real {
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
+                                                           const T& y) const {
+    return Eigen::numext::floor(x / y);
+  }
+};
+
+template <typename T>
+struct functor_traits<google_floor_div_real<T>> {
+  enum {
+    Cost = 2 * Eigen::internal::scalar_div_cost<T>::value +
+           2 * NumTraits<T>::AddCost,
+    PacketAccess = false
+  };
+};
+
 // TODO(b/32239616): This kernel should be moved into Eigen and vectorized.
 template <typename T>
 struct google_floor_fmod {
@@ -611,6 +629,9 @@ struct safe_floor_div : base
 
+template <typename T>
+struct floor_div_real : base<T, Eigen::internal::google_floor_div_real<T>> {};
+
 template <typename T>
 struct pow : base> {};
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 1fd69ae717cf05..3502f118921bdd 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -613,16 +613,48 @@ def _MulGrad(op, grad):
 
 @ops.RegisterGradient("Div")
 def _DivGrad(op, grad):
+  """The gradient for the Div operator."""
   x = op.inputs[0]
   y = op.inputs[1]
   sx = array_ops.shape(x)
   sy = array_ops.shape(y)
-  rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)  # pylint: disable=protected-access
+  # pylint: disable=protected-access
+  rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+  # pylint: enable=protected-access
+  x = math_ops.conj(x)
+  y = math_ops.conj(y)
+  return (array_ops.reshape(math_ops.reduce_sum(math_ops.div(grad, y), rx), sx),
+          array_ops.reshape(math_ops.reduce_sum(
+              grad * math_ops.div(-x, math_ops.square(y)), ry), sy))
+
+
+@ops.RegisterGradient("FloorDiv")
+def _FloorDivGrad(_, unused_grad):
+  """The gradient for the FloorDiv operator."""
+  return None, None
+
+
+@ops.RegisterGradient("TruncateDiv")
+def _TruncateDivGrad(_, unused_grad):
+  return None, None
+
+
+@ops.RegisterGradient("RealDiv")
+def _RealDivGrad(op, grad):
+  """RealDiv op gradient."""
+  x = op.inputs[0]
+  y = op.inputs[1]
+  sx = array_ops.shape(x)
+  sy = array_ops.shape(y)
+  # pylint: disable=protected-access
+  rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
+  # pylint: enable=protected-access
   x = math_ops.conj(x)
   y = math_ops.conj(y)
-  return (array_ops.reshape(math_ops.reduce_sum(grad / y, rx), sx),
-          array_ops.reshape(math_ops.reduce_sum(grad *
-                                                (-x / math_ops.square(y)), ry), sy))
+  return (array_ops.reshape(math_ops.reduce_sum(
+      math_ops.realdiv(grad, y), rx), sx),
+          array_ops.reshape(math_ops.reduce_sum(
+              grad * math_ops.realdiv(-x, math_ops.square(y)), ry), sy))
 
 
 @ops.RegisterGradient("Pow")
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index c2aab4c945389f..6fce264bd9262e 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -928,7 +928,32 @@ def truediv(x, y, name=None):
     if dtype is not None:
       x = cast(x, dtype)
       y = cast(y, dtype)
-    return gen_math_ops.div(x, y, name=name)
+    return gen_math_ops.real_div(x, y, name=name)
+
+
+def div(x, y, name=None):
+  with ops.name_scope(name, "div", [x, y]) as name:
+    x = ops.convert_to_tensor(x, name="x")
+    y = ops.convert_to_tensor(y, name="y", dtype=x.dtype.base_dtype)
+    x_dtype = x.dtype.base_dtype
+    y_dtype = y.dtype.base_dtype
+    if x_dtype != y_dtype:
+      raise TypeError("x and y must have the same dtype, got %r != %r" %
+                      (x_dtype, y_dtype))
+    if x_dtype.is_floating or x_dtype.is_complex:
+      return gen_math_ops.real_div(x, y, name=name)
+    else:
+      return gen_math_ops.floor_div(x, y, name=name)
+
+
+def div_deprecated(x, y, name=None):
+  return gen_math_ops.div(x, y, name)
+
+mod = gen_math_ops.floor_mod
+
+
+def mod_deprecated(x, y, name=None):
+  return gen_math_ops.mod(x, y, name)
 
 
 # TODO(aselle): Deprecate this once all internal functionality uses
@@ -959,6 +984,11 @@ def floordiv(x, y, name=None):
   Raises:
     TypeError: If the inputs are complex.
   """
+  with ops.name_scope(name, "floordiv", [x, y]) as name:
+    return gen_math_ops.floor_div(x, y, name=name)
+
+
+def floordiv_deprecated(x, y, name=None):
   with ops.name_scope(name, "floordiv", [x, y]) as name:
     x = ops.convert_to_tensor(x, name="x")
     dtype = x.dtype
@@ -971,7 +1001,6 @@ def floordiv(x, y, name=None):
     #   return gen_math_ops.floor_div(x, y, name=name)
     return gen_math_ops.div(x, y, name=name)
 
-
 realdiv = gen_math_ops.real_div
 truncatediv = gen_math_ops.truncate_div
 # TODO(aselle): Rename this to floordiv when we can.
@@ -1002,12 +1031,12 @@ def _mul_dispatch(x, y, name=None):
 _OverrideBinaryOperatorHelper(gen_math_ops.add, "add")
 _OverrideBinaryOperatorHelper(gen_math_ops.sub, "sub")
 _OverrideBinaryOperatorHelper(_mul_dispatch, "mul")
-_OverrideBinaryOperatorHelper(gen_math_ops.div, "div")
+_OverrideBinaryOperatorHelper(div, "div")
 _OverrideBinaryOperatorHelper(truediv, "truediv")
 _OverrideBinaryOperatorHelper(floordiv, "floordiv")
 # TODO(aselle): Switch mod to floor_mod when ready
 # _OverrideBinaryOperatorHelper(gen_math_ops.floor_mod, "mod")
-_OverrideBinaryOperatorHelper(gen_math_ops.mod, "mod")
+_OverrideBinaryOperatorHelper(gen_math_ops.floor_mod, "mod")
 _OverrideBinaryOperatorHelper(pow, "pow")
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index 4bbbc7b4f76c11..197ddb6a75d129 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -24,6 +24,7 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import googletest
@@ -305,6 +306,26 @@ def testRealDiv(self):
     np_result = np.divide(nums, divs)
     self.assertAllEqual(tf_result, np_result)
 
+  def testComplexDiv(self):
+    foo = array_ops.constant([1.+3.j])
+    with self.test_session():
+      _ = math_ops.div_deprecated(foo, 1.).eval()
+      _ = math_ops.div(foo, 2.).eval()
+
+  def testFloorDivGrad(self):
+    with self.test_session():
+      a = variables.Variable(2.)
+      b = variables.Variable(4.)
+      with self.test_session() as sess:
+        sess.run(variables.initialize_all_variables())
+        c_grad = gradients.gradients(math_ops.div_deprecated(a, b), [a, b])
+        self.assertAllEqual([x.eval() for x in c_grad], [.25, -.125])
+        c_grad = gradients.gradients(math_ops.div(a, b), [a, b])
+        self.assertAllEqual([x.eval() for x in c_grad], [.25, -.125])
+        c_grad = gradients.gradients(math_ops.floordiv(a, b), [a, b])
+        self.assertAllEqual([None if x is None else x.eval() for x in c_grad],
+                            [None, None])
+
   def testConsistent(self):
     nums, divs = self.intTestData()
     with self.test_session():
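
Background on the semantics change: Python's // and % floor the quotient
toward negative infinity, while C/C++ (and the previous TensorFlow behavior)
truncate it toward zero; the two conventions differ only when the operands
have opposite signs. A minimal plain-Python sketch of the distinction
(illustrative only, not part of the patch):

    import math

    x, y = -7, 2
    # Flooring (Python) semantics -- what /, //, and % follow after this patch:
    print(x // y, x % y)                       # -4 1   (quotient rounds toward -inf)
    # Truncating (C++) semantics -- still available via truncatediv/truncatemod:
    print(math.trunc(x / y), math.fmod(x, y))  # -3 -1.0 (quotient rounds toward 0)
    # Both conventions satisfy q * y + r == x for their own q and r.

The new real-valued FloorDiv kernel computes Eigen::numext::floor(x / y),
which agrees with the integer flooring convention: floor(-7.0 / 2.0) == -4.0.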
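The new math_ops.div dispatches on dtype: floating-point and complex inputs
go to RealDiv, everything else to FloorDiv. A rough plain-Python stand-in for
that branch (div_like is a hypothetical name; the real code tests the
DType.is_floating / DType.is_complex flags on tensors):

    def div_like(x, y):
        # RealDiv path for real/complex values, FloorDiv path otherwise.
        if isinstance(x, (float, complex)):
            return x / y
        return x // y

    print(div_like(7.0, 2.0))  # 3.5
    print(div_like(7, 2))      # 3
    print(div_like(-7, 2))     # -4  (flooring, matching Python)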
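The expected values in testFloorDivGrad follow from the registered gradients:
for z = x / y (RealDiv), dz/dx = 1/y and dz/dy = -x/y**2, so a = 2, b = 4
yields 0.25 and -0.125; FloorDiv is piecewise constant, so its gradients are
registered as None. A NumPy finite-difference sketch (independent of
TensorFlow) confirming those numbers:

    import numpy as np

    def fd_grads(f, a, b, eps=1e-6):
        # Central differences approximating df/da and df/db.
        da = (f(a + eps, b) - f(a - eps, b)) / (2 * eps)
        db = (f(a, b + eps) - f(a, b - eps)) / (2 * eps)
        return da, db

    da, db = fd_grads(lambda a, b: a / b, 2.0, 4.0)
    print(np.allclose([da, db], [0.25, -0.125]))        # True: 1/b and -a/b**2

    da, db = fd_grads(lambda a, b: np.floor(a / b), 2.0, 4.0)
    print(da, db)                                       # 0.0 0.0 (flat away from jumps)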