diff --git a/returnn/tf/util/basic.py b/returnn/tf/util/basic.py index 60dacc1a44..183520f572 100644 --- a/returnn/tf/util/basic.py +++ b/returnn/tf/util/basic.py @@ -3697,7 +3697,7 @@ def debug_register_better_repr(): def cond(pred, true_fn, false_fn, name=None): """ - This is a wrapper around tf.control_flow_ops.cond(). + This is a wrapper around tf.cond(). This will be a branched execution, i.e. either fn1() or fn2() will be executed, or at least the resulting graph will be evaluated. If pred can is constant at the call, only the corresponding fn will be called. @@ -3727,9 +3727,8 @@ def cond(pred, true_fn, false_fn, name=None): return true_fn() else: return false_fn() - from tensorflow.python.ops import control_flow_ops - return control_flow_ops.cond(pred, true_fn, false_fn, name=name) + return tf.cond(pred, true_fn, false_fn, name=name) def single_strided_slice(x, axis, begin=None, end=None, step=None):