diff --git a/returnn/tensor/_dim_extra.py b/returnn/tensor/_dim_extra.py
index 4635e79fc9..bf4da5ca77 100644
--- a/returnn/tensor/_dim_extra.py
+++ b/returnn/tensor/_dim_extra.py
@@ -1182,8 +1182,12 @@ def _bin_op_tf(a, b):
                 elif kind == "mul":
                     return a * b
                 elif kind in ("floordiv", "truediv"):  # truediv assumes there is no remainder
+                    if util.is_onnx_export_global():
+                        return tf_util.onnx_compat_floor_div(a, b)
                     return a // b
                 elif kind == "ceildiv":
+                    if util.is_onnx_export_global():
+                        return -tf_util.onnx_compat_floor_div(-a, b)
                     return -(-a // b)
                 else:
                     raise ValueError("unknown op kind %r" % op.kind)
diff --git a/returnn/tf/frontend_low_level/_backend.py b/returnn/tf/frontend_low_level/_backend.py
index 04e7a2817a..1e5825744c 100644
--- a/returnn/tf/frontend_low_level/_backend.py
+++ b/returnn/tf/frontend_low_level/_backend.py
@@ -8,7 +8,7 @@
 import tensorflow as tf
 
 import returnn.tf.compat as tf_compat
-from returnn.util.basic import NotSpecified
+from returnn.util.basic import NotSpecified, is_onnx_export_global
 from returnn.tensor import Tensor, Dim
 from returnn.tf.util import basic as tf_util
 
@@ -132,14 +132,17 @@ def combine_raw(a: tf.Tensor, kind: str, b: tf.Tensor) -> tf.Tensor:
         :return: a `kind` b
         """
         assert a.shape.ndims == b.shape.ndims or a.shape.ndims == 0 or b.shape.ndims == 0
-        kind = {
-            "sub": "subtract",
-            "mul": "multiply",
-        }.get(kind, kind)
-        op = getattr(tf, kind, None)  # e.g. tf.add
-        # In tf v2, some ops like floordiv or mod exist in the tf.math namespace instead
-        if op is None:
-            op = getattr(tf.math, kind)
+        if kind == "floordiv" and is_onnx_export_global():
+            op = tf_util.onnx_compat_floor_div
+        else:
+            kind = {
+                "sub": "subtract",
+                "mul": "multiply",
+            }.get(kind, kind)
+            op = getattr(tf, kind, None)  # e.g. tf.add
+            # In tf v2, some ops like floordiv or mod exist in the tf.math namespace instead
+            if op is None:
+                op = getattr(tf.math, kind)
         with tf_util.same_control_flow_ctx([a, b]):
             return op(a, b)
 
diff --git a/returnn/tf/util/basic.py b/returnn/tf/util/basic.py
index 7cb3e6584b..60dacc1a44 100644
--- a/returnn/tf/util/basic.py
+++ b/returnn/tf/util/basic.py
@@ -7820,3 +7820,23 @@ def is_axis_from_description_recurrent(axis, network, data):
     if axis == single_step_dim:
         return True
     return False
+
+
+def onnx_compat_floor_div(a: tf.Tensor, b: tf.Tensor) -> tf.Tensor:
+    """
+    Floor division (rounding towards negative infinity) which only ever divides non-negative operands.
+
+    :param a: dividend
+    :param b: divisor
+    :return: for onnx export compatible floor_divide
+    """
+    # https://github.com/onnx/tensorflow-onnx/issues/2174
+    abs_a, abs_b = tf.abs(a), tf.abs(b)
+    # For non-negative operands, floor and truncating division agree,
+    # so the result is the same no matter which of the two `//` ends up as.
+    quotient = abs_a // abs_b
+    # tf.not_equal instead of `!=`: the operator is only element-wise when tensor equality is enabled.
+    has_remainder = tf.cast(tf.not_equal(abs_a % abs_b, 0), dtype=quotient.dtype)
+    # Non-negative product: the quotient of the magnitudes is already the answer.
+    # Otherwise negate it, and step down by one more if the division was inexact.
+    return tf.where(a * b >= 0, quotient, -quotient - has_remainder)
diff --git a/returnn/util/basic.py b/returnn/util/basic.py
index 405a8edf85..7e4828248a 100644
--- a/returnn/util/basic.py
+++ b/returnn/util/basic.py
@@ -3627,6 +3627,18 @@ def get_global_inf_value() -> float:
     return config.float("inf_value", _default_global_inf_value)
 
 
+def is_onnx_export_global() -> bool:
+    """
+    :return: False by default. If 'onnx_export' is set in the config, that value is used.
+    """
+    from returnn.config import get_global_config
+
+    config = get_global_config(raise_exception=False)
+    if not config:
+        return False
+    return config.bool("onnx_export", False)
+
+
 # See :func:`maybe_restart_returnn_with_atfork_patch` below for why you might want to use this.
 _c_code_patch_atfork = """
 #define _GNU_SOURCE