Skip to content

Commit

Permalink
Use tf.cond instead of tf.control_flow_ops.cond (#1460)
Browse files Browse the repository at this point in the history
  • Loading branch information
vieting authored Nov 9, 2023
1 parent a1a3228 commit 530582d
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions returnn/tf/util/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3697,7 +3697,7 @@ def debug_register_better_repr():

def cond(pred, true_fn, false_fn, name=None):
"""
This is a wrapper around tf.control_flow_ops.cond().
This is a wrapper around tf.cond().
This will be a branched execution, i.e. either fn1() or fn2() will be executed,
or at least the resulting graph will be evaluated.
If pred can is constant at the call, only the corresponding fn will be called.
Expand Down Expand Up @@ -3727,9 +3727,8 @@ def cond(pred, true_fn, false_fn, name=None):
return true_fn()
else:
return false_fn()
from tensorflow.python.ops import control_flow_ops

return control_flow_ops.cond(pred, true_fn, false_fn, name=name)
return tf.cond(pred, true_fn, false_fn, name=name)


def single_strided_slice(x, axis, begin=None, end=None, step=None):
Expand Down

0 comments on commit 530582d

Please sign in to comment.