Skip to content

Commit

Permalink
Refactor to use the same code for handling control dependency inside …
Browse files Browse the repository at this point in the history
…cond.

Change: 142214059
  • Loading branch information
yuanbyu authored and tensorflower-gardener committed Dec 16, 2016
1 parent af99086 commit c38776d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 18 deletions.
13 changes: 13 additions & 0 deletions tensorflow/python/kernel_tests/control_flow_ops_py_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,19 @@ def testCondRef(self):
r = control_flow_ops.cond(constant_op.constant(False), true_fn, false_fn)
self.assertAllEqual([2.0], r.eval())

def testCondWithControl(self):
with self.test_session() as sess:
control_holder = array_ops.placeholder(dtypes.float32, shape=())
a = constant_op.constant(3)
def true_branch():
with ops.control_dependencies([control_holder]):
_ = a + 1
return a + 2
r = control_flow_ops.cond(constant_op.constant(True),
true_branch,
lambda: constant_op.constant(1))
self.assertEqual(5, r.eval())

def testUninitializedRefIdentity(self):
with self.test_session() as sess:
v = gen_state_ops._variable(
Expand Down
27 changes: 9 additions & 18 deletions tensorflow/python/ops/control_flow_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1621,19 +1621,11 @@ def _AddOpInternal(self, op):
else:
for index in range(len(op.inputs)):
x = op.inputs[index]
if x.name not in self._values:
self._values.add(x.name)
# Add this value to the parent contexts up to the context that
# creates this value.
real_x = x
if self._outer_context:
real_x = self._outer_context.AddValue(x)
self._values.add(real_x.name)
real_x = _SwitchRefOrTensor(real_x, self._pred)[self._branch]
self._external_values[x.name] = real_x
x = self._external_values.get(x.name)
if x is not None:
op._update_input(index, x)
real_x = self.AddValue(x)
if real_x != x:
# pylint: disable=protected-access
op._update_input(index, real_x)
# pylint: enable=protected-access
for x in op.outputs:
self._values.add(x.name)
if self._outer_context or not IsLoopExit(op):
Expand Down Expand Up @@ -2060,9 +2052,8 @@ def _AddOpInternal(self, op):
else:
for index in range(len(op.inputs)):
x = op.inputs[index]
self.AddValue(x)
real_x = self._external_values.get(x.name)
if real_x is not None:
real_x = self.AddValue(x)
if real_x != x:
op._update_input(index, real_x)
# Remove any external control dependency on this op.
self._RemoveExternalControlEdges(op)
Expand Down Expand Up @@ -2161,8 +2152,8 @@ def AddBackPropLoopCounter(self, count, outer_grad_state):
merge_count = merge([enter_count, enter_count])[0]
self._pivot_for_pred = merge_count

cond = math_ops.greater_equal(merge_count, one)
self._pivot = loop_cond(cond, name="b_count")
pred = math_ops.greater_equal(merge_count, one)
self._pivot = loop_cond(pred, name="b_count")
switch_count = switch(merge_count, self._pivot)

index = math_ops.sub(switch_count[1], one)
Expand Down

0 comments on commit c38776d

Please sign in to comment.