Skip to content

Commit

Permalink
use tf.cond instead of tf.control_flow_ops.cond
Browse files Browse the repository at this point in the history
  • Loading branch information
vieting committed Nov 9, 2023
1 parent e5a5d90 commit 17cc348
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions returnn/tf/util/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3697,7 +3697,7 @@ def debug_register_better_repr():

def cond(pred, true_fn, false_fn, name=None):
"""
This is a wrapper around tf.control_flow_ops.cond().
This is a wrapper around tf.cond().
This will be a branched execution, i.e. either fn1() or fn2() will be executed,
or at least the resulting graph will be evaluated.
If pred can is constant at the call, only the corresponding fn will be called.
Expand Down Expand Up @@ -3727,9 +3727,8 @@ def cond(pred, true_fn, false_fn, name=None):
return true_fn()
else:
return false_fn()
from tensorflow.python.ops import control_flow_ops

return control_flow_ops.cond(pred, true_fn, false_fn, name=name)
return tf.cond(pred, true_fn, false_fn, name=name)


def single_strided_slice(x, axis, begin=None, end=None, step=None):
Expand Down

0 comments on commit 17cc348

Please sign in to comment.