Add dynamic shape check to Bernoulli. This lets us avoid broadcasting in the case where logits and event are the same shape.

Change: 141110652
tensorflower-gardener committed Dec 6, 2016
1 parent 9423429 commit 8f332c0
Showing 4 changed files with 141 additions and 7 deletions.
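In short: sigmoid_cross_entropy_with_logits does not broadcast its arguments, so _log_prob previously always broadcast logits and event to a common shape with the ones_like trick. After this change the broadcast is skipped when the shapes already agree: checked statically when possible, otherwise at runtime via control_flow_ops.cond. A minimal sketch of the idea, using the TF 0.12-era API to match the diff below; the names and the equal-rank simplification here are illustrative, not from the commit:

import tensorflow as tf

# Sketch only: assumes logits and event have the same rank, so the shape
# vectors compared below always have the same length. The helper added in
# this commit, same_dynamic_shape, checks rank first inside a cond.
logits = tf.placeholder(tf.float32, shape=[None])
event = tf.placeholder(tf.float32, shape=[None])

def broadcast(logits, event):
  # Multiplying by ones_like of the other operand broadcasts both
  # tensors to their common shape without an explicit tile.
  return tf.ones_like(event) * logits, tf.ones_like(logits) * event

same_shape = tf.reduce_all(tf.equal(tf.shape(logits), tf.shape(event)))
# Only materialize broadcast copies when the runtime shapes differ.
logits, event = tf.cond(same_shape,
                        lambda: (logits, event),
                        lambda: broadcast(logits, event))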
@@ -127,6 +127,17 @@ def _testPmf(self, **kwargs):
      self.assertAllClose(dist.pmf(x).eval(), expected_pmf)
      self.assertAllClose(dist.log_pmf(x).eval(), np.log(expected_pmf))

  def testPmfCorrectBroadcastDynamicShape(self):
    with self.test_session():
      p = tf.placeholder(dtype=tf.float32)
      dist = tf.contrib.distributions.Bernoulli(p=p)
      event1 = [1, 0, 1]
      event2 = [[1, 0, 1]]
      self.assertAllClose(dist.pmf(event1).eval({p: [0.2, 0.3, 0.4]}),
                          [0.2, 0.7, 0.4])
      self.assertAllClose(dist.pmf(event2).eval({p: [0.2, 0.3, 0.4]}),
                          [[0.2, 0.7, 0.4]])

  def testPmfWithP(self):
    p = [[0.2, 0.4], [0.3, 0.6]]
    self._testPmf(p=p)
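For reference, the expected values in testPmfCorrectBroadcastDynamicShape above follow directly from the Bernoulli pmf, which is p where the event is 1 and 1 - p where it is 0. A quick NumPy check (illustrative, not part of the commit):

import numpy as np

p = np.array([0.2, 0.3, 0.4])
event = np.array([1, 0, 1])
# Bernoulli pmf: p for event == 1, (1 - p) for event == 0.
pmf = np.where(event == 1, p, 1.0 - p)
print(pmf)  # [0.2 0.7 0.4], matching the test's expected output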
@@ -261,6 +261,89 @@ def testLogCombinationsShape(self):
      self.assertEqual([2, 2], log_binom.get_shape())


class DynamicShapeTest(tf.test.TestCase):

  def testSameDynamicShape(self):
    with self.test_session():
      scalar = tf.constant(2.0)
      scalar1 = tf.placeholder(dtype=tf.float32)

      vector = [0.3, 0.4, 0.5]
      vector1 = tf.placeholder(dtype=tf.float32, shape=[None])
      vector2 = tf.placeholder(dtype=tf.float32, shape=[None])

      multidimensional = [[0.3, 0.4], [0.2, 0.6]]
      multidimensional1 = tf.placeholder(dtype=tf.float32, shape=[None, None])
      multidimensional2 = tf.placeholder(dtype=tf.float32, shape=[None, None])

      # Scalar
      self.assertTrue(distribution_util.same_dynamic_shape(
          scalar, scalar1).eval({
              scalar1: 2.0}))

      # Vector

      self.assertTrue(distribution_util.same_dynamic_shape(
          vector, vector1).eval({
              vector1: [2.0, 3.0, 4.0]}))
      self.assertTrue(distribution_util.same_dynamic_shape(
          vector1, vector2).eval({
              vector1: [2.0, 3.0, 4.0],
              vector2: [2.0, 3.5, 6.0]}))

      # Multidimensional
      self.assertTrue(distribution_util.same_dynamic_shape(
          multidimensional, multidimensional1).eval({
              multidimensional1: [[2.0, 3.0], [3.0, 4.0]]}))
      self.assertTrue(distribution_util.same_dynamic_shape(
          multidimensional1, multidimensional2).eval({
              multidimensional1: [[2.0, 3.0], [3.0, 4.0]],
              multidimensional2: [[1.0, 3.5], [6.3, 2.3]]}))


      # Scalar, X
      self.assertFalse(distribution_util.same_dynamic_shape(
          scalar, vector1).eval({
              vector1: [2.0, 3.0, 4.0]}))
      self.assertFalse(distribution_util.same_dynamic_shape(
          scalar1, vector1).eval({
              scalar1: 2.0,
              vector1: [2.0, 3.0, 4.0]}))
      self.assertFalse(distribution_util.same_dynamic_shape(
          scalar, multidimensional1).eval({
              multidimensional1: [[2.0, 3.0], [3.0, 4.0]]}))
      self.assertFalse(distribution_util.same_dynamic_shape(
          scalar1, multidimensional1).eval({
              scalar1: 2.0,
              multidimensional1: [[2.0, 3.0], [3.0, 4.0]]}))

      # Vector, X
      self.assertFalse(distribution_util.same_dynamic_shape(
          vector, vector1).eval({
              vector1: [2.0, 3.0]}))
      self.assertFalse(distribution_util.same_dynamic_shape(
          vector1, vector2).eval({
              vector1: [2.0, 3.0, 4.0],
              vector2: [6.0]}))
      self.assertFalse(distribution_util.same_dynamic_shape(
          vector, multidimensional1).eval({
              multidimensional1: [[2.0, 3.0], [3.0, 4.0]]}))
      self.assertFalse(distribution_util.same_dynamic_shape(
          vector1, multidimensional1).eval({
              vector1: [2.0, 3.0, 4.0],
              multidimensional1: [[2.0, 3.0], [3.0, 4.0]]}))

      # Multidimensional, X
      self.assertFalse(distribution_util.same_dynamic_shape(
          multidimensional, multidimensional1).eval({
              multidimensional1: [[1.0, 3.5, 5.0], [6.3, 2.3, 7.1]]}))
      self.assertFalse(distribution_util.same_dynamic_shape(
          multidimensional1, multidimensional2).eval({
              multidimensional1: [[2.0, 3.0], [3.0, 4.0]],
              multidimensional2: [[1.0, 3.5, 5.0], [6.3, 2.3, 7.1]]}))



class RotateTransposeTest(tf.test.TestCase):

  def _np_rotate_transpose(self, x, shift):
23 changes: 16 additions & 7 deletions tensorflow/contrib/distributions/python/ops/bernoulli.py
@@ -25,6 +25,7 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.ops import random_ops
@@ -131,13 +132,21 @@ def _log_prob(self, event):
     logits = self.logits
     # sigmoid_cross_entropy_with_logits doesn't broadcast shape,
     # so we do this here.
-    # TODO(b/30637701): Check dynamic shape, and don't broadcast if the
-    # dynamic shapes are the same.
-    if (not event.get_shape().is_fully_defined() or
-        not logits.get_shape().is_fully_defined() or
-        event.get_shape() != logits.get_shape()):
-      logits = array_ops.ones_like(event) * logits
-      event = array_ops.ones_like(logits) * event
+
+    broadcast = lambda logits, event: (
+        array_ops.ones_like(event) * logits,
+        array_ops.ones_like(logits) * event)
+
+    # First check static shape.
+    if (event.get_shape().is_fully_defined() and
+        logits.get_shape().is_fully_defined()):
+      if event.get_shape() != logits.get_shape():
+        logits, event = broadcast(logits, event)
+    else:
+      logits, event = control_flow_ops.cond(
+          distribution_util.same_dynamic_shape(logits, event),
+          lambda: (logits, event),
+          lambda: broadcast(logits, event))
     return -nn.sigmoid_cross_entropy_with_logits(logits, event)
 
   def _prob(self, event):
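With the rewrite above, _log_prob pays for broadcasting only when the shapes actually differ. A hedged usage sketch against the contrib API exercised by the tests (TF 0.12 era; the commented outputs follow from the Bernoulli pmf):

import tensorflow as tf

dist = tf.contrib.distributions.Bernoulli(p=[0.2, 0.3, 0.4])

# Static shapes match (both [3]): the broadcast branch is skipped entirely.
pmf_same = dist.pmf([1, 0, 1])

# Static shapes differ ([2, 3] vs. [3]): both operands are broadcast with
# the ones_like trick before sigmoid_cross_entropy_with_logits runs.
pmf_bcast = dist.pmf([[1, 0, 1], [0, 1, 0]])

with tf.Session() as sess:
  print(sess.run(pmf_same))   # [0.2, 0.7, 0.4]
  print(sess.run(pmf_bcast))  # [[0.2, 0.7, 0.4], [0.8, 0.3, 0.6]]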
31 changes: 31 additions & 0 deletions tensorflow/contrib/distributions/python/ops/distribution_util.py
@@ -32,6 +32,7 @@
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn

@@ -104,6 +105,36 @@ def assert_symmetric(matrix):
      [check_ops.assert_equal(matrix, matrix_t)], matrix)


def same_dynamic_shape(a, b):
"""Returns whether a and b have the same dynamic shape.
Args:
a: `Tensor`
b: `Tensor`
Returns:
`Boolean` `Tensor` representing if both tensors have the same shape.
"""
a = ops.convert_to_tensor(a, name="a")
b = ops.convert_to_tensor(b, name="b")

# One of the shapes isn't fully defined, so we need to use the dynamic
# shape.
return control_flow_ops.cond(
math_ops.equal(array_ops.rank(a), array_ops.rank(b)),
# Here we can't just do math_ops.equal(a.shape, b.shape), since
# static shape inference may break the equality comparison between
# shape(a) and shape(b) in math_ops.equal.
lambda: math_ops.reduce_all(math_ops.equal(
array_ops.concat(0, (
array_ops.shape(a),
array_ops.shape(b))),
array_ops.concat(0, (
array_ops.shape(b),
array_ops.shape(a))))),
lambda: constant_op.constant(False))


def get_logits_and_prob(
    logits=None, p=None,
    multidimensional=False, validate_args=False, name="GetLogitsAndProb"):
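Why the double concat in same_dynamic_shape works: once the ranks are known to match, the concatenated vectors [shape(a), shape(b)] and [shape(b), shape(a)] are element-wise equal exactly when shape(a) equals shape(b), and, per the inline comment, the indirection avoids static shape inference interfering with a direct equality on the shape tensors. A plain NumPy rendering of the same predicate, illustrative only:

import numpy as np

def same_dynamic_shape(a, b):
  # Mirror the TensorFlow helper in NumPy: compare ranks first, then
  # element-wise equality of the concatenated shape vectors.
  if np.ndim(a) != np.ndim(b):
    return False
  sa, sb = np.shape(a), np.shape(b)
  return bool(np.all(np.concatenate((sa, sb)) == np.concatenate((sb, sa))))

print(same_dynamic_shape(np.zeros(3), np.zeros(3)))       # True
print(same_dynamic_shape(np.zeros(3), np.zeros((1, 3))))  # False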
