Change division in TensorFlow to flooring semantics.
- tf.div changes to the new behavior, but it will be deprecated
- tf.divide is currently a synonym for tf.div and will remain
- tf.mod changes to the new behavior, but it will be deprecated;
  use % or tf.floormod in the future
- the FloorDiv op is now extended to work on reals
Change: 139922734
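
The difference between the two semantics shows up on negative operands: Python floors toward negative infinity, while C++ truncates toward zero. A minimal sketch of the new behavior (assuming a build that includes this commit; uses the 1.x-era session API):

import tensorflow as tf

# Sketch only: illustrates flooring vs. truncating division and modulus.
with tf.Session() as sess:
    a, b = tf.constant(-7), tf.constant(2)
    print(sess.run(tf.div(a, b)))          # -4: new flooring behavior
    print(sess.run(tf.truncatediv(a, b)))  # -3: previous C++ truncation
    print(sess.run(tf.mod(a, b)))          #  1: flooring mod, sign of the divisor
    print(sess.run(tf.truncatemod(a, b)))  # -1: truncating mod, sign of the dividend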
aselle authored and tensorflower-gardener committed Nov 22, 2016
1 parent 5591ca5 commit fcc3923
Showing 7 changed files with 120 additions and 8 deletions.
4 changes: 4 additions & 0 deletions RELEASE.md
@@ -2,6 +2,10 @@

## Breaking Changes to the API

* Division and modulus operators (`/`, `//`, `%`) now match Python (flooring)
semantics. `tf.div` is deprecated in favor of `tf.divide`. New operators
`tf.truncatediv` and `tf.truncatemod` are available for achieving the previous
C++ (truncation) division/modulus semantics.
* `BusAdjacency` enum replaced with a protocol buffer `DeviceLocality`. PCI bus
indexing now starts from 1 instead of 0, and bus_id==0 is used where previously
BUS_ANY was used.
4 changes: 4 additions & 0 deletions tensorflow/core/kernels/cwise_op_floor_div.cc
@@ -28,9 +28,13 @@ REGISTER5(BinaryOp, CPU, "FloorDiv", functor::safe_floor_div, uint8, uint16,
TF_CALL_INTEGRAL_TYPES(REGISTER_SYCL_KERNEL);
#undef REGISTER_SYCL_KERNEL
#endif // TENSORFLOW_USE_SYCL
REGISTER3(BinaryOp, CPU, "FloorDiv", functor::floor_div_real, float,
Eigen::half, double);
#if GOOGLE_CUDA
REGISTER4(BinaryOp, GPU, "FloorDiv", functor::floor_div, uint8, uint16, int16,
int64);
REGISTER3(BinaryOp, GPU, "FloorDiv", functor::floor_div_real, float,
Eigen::half, double);
#endif

#if GOOGLE_CUDA
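With the floor_div_real functor registered for float, Eigen::half, and double on both CPU and GPU, FloorDiv no longer rejects real inputs. Combined with the Python wrapper changes later in this commit, tf.floordiv on floats should behave like np.floor(x / y); a hedged sketch:

import numpy as np
import tensorflow as tf

# Sketch only: checks FloorDiv on real dtypes against NumPy's flooring.
with tf.Session() as sess:
    x = tf.constant([7.0, -7.0])
    y = tf.constant([2.0, 2.0])
    out = sess.run(tf.floordiv(x, y))
    np.testing.assert_allclose(out, np.floor(np.array([7.0, -7.0]) / 2.0))
    print(out)  # [ 3. -4.]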
1 change: 1 addition & 0 deletions tensorflow/core/kernels/cwise_op_gpu_floor_div.cu.cc
@@ -20,6 +20,7 @@ limitations under the License.
namespace tensorflow {
namespace functor {
DEFINE_BINARY5(floor_div, uint8, uint16, int16, int32, int64);
DEFINE_BINARY3(floor_div_real, Eigen::half, float, double);
} // namespace functor
} // namespace tensorflow

21 changes: 21 additions & 0 deletions tensorflow/core/kernels/cwise_ops.h
@@ -269,6 +269,24 @@ struct functor_traits<google_floor_div<Scalar>> {
};
};

// TODO(b/32239616): This kernel should be moved into Eigen and vectorized.
template <typename T, typename Enable = void>
struct google_floor_div_real {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x,
const T& y) const {
return Eigen::numext::floor(x / y);
}
};

template <typename Scalar>
struct functor_traits<google_floor_div_real<Scalar>> {
enum {
Cost = 2 * Eigen::internal::scalar_div_cost<Scalar, false>::value +
2 * NumTraits<Scalar>::AddCost,
PacketAccess = false
};
};

// TODO(b/32239616): This kernel should be moved into Eigen and vectorized.
template <typename T>
struct google_floor_fmod {
@@ -611,6 +629,9 @@ struct safe_floor_div : base<T, Eigen::internal::safe_div_or_mod_op<
static const bool has_errors = true;
};

template <typename T>
struct floor_div_real : base<T, Eigen::internal::google_floor_div_real<T>> {};

template <typename T>
struct pow : base<T, Eigen::internal::scalar_binary_pow_op_google<T, T>> {};

40 changes: 36 additions & 4 deletions tensorflow/python/ops/math_grad.py
@@ -613,16 +613,48 @@ def _MulGrad(op, grad):

@ops.RegisterGradient("Div")
def _DivGrad(op, grad):
"""The gradient for the Div operator."""
x = op.inputs[0]
y = op.inputs[1]
sx = array_ops.shape(x)
sy = array_ops.shape(y)
  # pylint: disable=protected-access
  rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
  # pylint: enable=protected-access
x = math_ops.conj(x)
y = math_ops.conj(y)
return (array_ops.reshape(math_ops.reduce_sum(math_ops.div(grad, y), rx), sx),
array_ops.reshape(math_ops.reduce_sum(
grad * math_ops.div(-x, math_ops.square(y)), ry), sy))


@ops.RegisterGradient("FloorDiv")
def _FloorDivGrad(_, unused_grad):
"""The gradient for the FloorDiv operator."""
return None, None


@ops.RegisterGradient("TruncateDiv")
def _TruncateDivGrad(_, unused_grad):
return None, None


@ops.RegisterGradient("RealDiv")
def _RealDivGrad(op, grad):
"""RealDiv op gradient."""
x = op.inputs[0]
y = op.inputs[1]
sx = array_ops.shape(x)
sy = array_ops.shape(y)
# pylint: disable=protected-access
rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
# pylint: enable=protected-access
x = math_ops.conj(x)
y = math_ops.conj(y)
  return (array_ops.reshape(math_ops.reduce_sum(
      math_ops.realdiv(grad, y), rx), sx),
          array_ops.reshape(math_ops.reduce_sum(
              grad * math_ops.realdiv(-x, math_ops.square(y)), ry), sy))


@ops.RegisterGradient("Pow")
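Both _DivGrad and _RealDivGrad implement the quotient rule, d(x/y)/dx = 1/y and d(x/y)/dy = -x/y**2, then reduce-sum each term over the axes that broadcasting replicated; that reduction is what _broadcast_gradient_args computes. A numeric sketch of the broadcast case (not part of the commit):

import tensorflow as tf

# Sketch only: y is broadcast against x, so dz/dy is summed over those axes.
with tf.Session() as sess:
    x = tf.constant([[2.0, 4.0], [6.0, 8.0]])  # shape (2, 2)
    y = tf.constant([2.0])                     # broadcast to (2, 2)
    z = tf.reduce_sum(x / y)                   # '/' on floats lowers to RealDiv
    gx, gy = sess.run(tf.gradients(z, [x, y]))
    print(gx)  # 1/y at every element: [[0.5 0.5] [0.5 0.5]]
    print(gy)  # sum of -x/y**2 over the broadcast axes: [-5.]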
37 changes: 33 additions & 4 deletions tensorflow/python/ops/math_ops.py
@@ -928,7 +928,32 @@ def truediv(x, y, name=None):
if dtype is not None:
x = cast(x, dtype)
y = cast(y, dtype)
    return gen_math_ops.real_div(x, y, name=name)


def div(x, y, name=None):
with ops.name_scope(name, "truediv", [x, y]) as name:
x = ops.convert_to_tensor(x, name="x")
y = ops.convert_to_tensor(y, name="y", dtype=x.dtype.base_dtype)
x_dtype = x.dtype.base_dtype
y_dtype = y.dtype.base_dtype
if x_dtype != y_dtype:
raise TypeError("x and y must have the same dtype, got %r != %r" %
(x_dtype, y_dtype))
if x_dtype.is_floating or x_dtype.is_complex:
return gen_math_ops.real_div(x, y, name=name)
else:
return gen_math_ops.floor_div(x, y, name=name)


def div_deprecated(x, y, name=None):
return gen_math_ops.div(x, y, name)

mod = gen_math_ops.floor_mod


def mod_deprecated(x, y, name=None):
return gen_math_ops.mod(x, y, name)


# TODO(aselle): Deprecate this once all internal functionality uses
@@ -959,6 +984,11 @@ def floordiv(x, y, name=None):
Raises:
TypeError: If the inputs are complex.
"""
with ops.name_scope(name, "floordiv", [x, y]) as name:
return gen_math_ops.floor_div(x, y, name=name)


def floordiv_deprecated(x, y, name=None):
with ops.name_scope(name, "floordiv", [x, y]) as name:
x = ops.convert_to_tensor(x, name="x")
dtype = x.dtype
@@ -971,7 +1001,6 @@
# return gen_math_ops.floor_div(x, y, name=name)
return gen_math_ops.div(x, y, name=name)


realdiv = gen_math_ops.real_div
truncatediv = gen_math_ops.truncate_div
# TODO(aselle): Rename this to floordiv when we can.
@@ -1002,12 +1031,12 @@ def _mul_dispatch(x, y, name=None):
_OverrideBinaryOperatorHelper(gen_math_ops.add, "add")
_OverrideBinaryOperatorHelper(gen_math_ops.sub, "sub")
_OverrideBinaryOperatorHelper(_mul_dispatch, "mul")
_OverrideBinaryOperatorHelper(div, "div")
_OverrideBinaryOperatorHelper(truediv, "truediv")
_OverrideBinaryOperatorHelper(floordiv, "floordiv")
_OverrideBinaryOperatorHelper(gen_math_ops.floor_mod, "mod")
_OverrideBinaryOperatorHelper(pow, "pow")


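The new div wrapper above dispatches on dtype: integer tensors lower to FloorDiv, floating-point and complex tensors lower to RealDiv, and the % and // operator overrides now map to floor_mod and floor_div. A hedged usage sketch (assumes a build with this commit):

import tensorflow as tf

# Sketch only: dtype-based dispatch of the new division semantics.
with tf.Session() as sess:
    i = tf.constant(-7)
    print(sess.run(tf.div(i, 2)))    # -4: integer dtypes dispatch to FloorDiv
    f = tf.constant(-7.0)
    print(sess.run(tf.div(f, 2.0)))  # -3.5: float dtypes dispatch to RealDiv
    print(sess.run(i % 2))           #  1: '%' is now floor_mod
    print(sess.run(i // 2))          # -4: '//' floors, as in Python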
21 changes: 21 additions & 0 deletions tensorflow/python/ops/math_ops_test.py
@@ -24,6 +24,7 @@
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
@@ -305,6 +306,26 @@ def testRealDiv(self):
np_result = np.divide(nums, divs)
self.assertAllEqual(tf_result, np_result)

def testComplexDiv(self):
foo = array_ops.constant([1.+3.j])
with self.test_session():
_ = math_ops.div_deprecated(foo, 1.).eval()
_ = math_ops.div(foo, 2.).eval()

def testFloorDivGrad(self):
with self.test_session():
a = variables.Variable(2.)
b = variables.Variable(4.)
with self.test_session() as sess:
sess.run(variables.initialize_all_variables())
c_grad = gradients.gradients(math_ops.div_deprecated(a, b), [a, b])
self.assertAllEqual([x.eval() for x in c_grad], [.25, -.125])
c_grad = gradients.gradients(math_ops.div(a, b), [a, b])
self.assertAllEqual([x.eval() for x in c_grad], [.25, -.125])
c_grad = gradients.gradients(math_ops.floordiv(a, b), [a, b])
self.assertAllEqual([None if x is None else x.eval() for x in c_grad],
[None, None])

def testConsistent(self):
nums, divs = self.intTestData()
with self.test_session():
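testFloorDivGrad above exercises a subtle point: floor(x / y) is piecewise constant, so its derivative is zero almost everywhere and undefined at the jumps, and _FloorDivGrad therefore registers None for both inputs rather than a zero tensor. A sketch of what callers observe (assumes this commit):

import tensorflow as tf

# Sketch only: None signals "no gradient", which is distinct from a zero gradient.
a = tf.constant(2.0)
b = tf.constant(4.0)
print(tf.gradients(a // b, [a, b]))  # [None, None]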
