TF: add onnx_export option and onnx_comp_floor_div #1453

Merged (4 commits) on Nov 7, 2023
4 changes: 4 additions & 0 deletions returnn/tensor/_dim_extra.py
@@ -1182,8 +1182,12 @@ def _bin_op_tf(a, b):
     elif kind == "mul":
         return a * b
     elif kind in ("floordiv", "truediv"):  # truediv assumes there is no remainder
+        if util.is_onnx_export_global():
+            return tf_util.onnx_compat_floor_div(a, b)
         return a // b
     elif kind == "ceildiv":
+        if util.is_onnx_export_global():
+            return -tf_util.onnx_compat_floor_div(-a, b)
         return -(-a // b)
     else:
         raise ValueError("unknown op kind %r" % op.kind)
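The `ceildiv` branch reuses the same helper through the identity ceil(a / b) == -floor(-a / b). A minimal plain-Python sketch of that identity (illustrative only, not code from this PR):

```python
# Ceil division expressed via floor division, the identity behind
# `-(-a // b)` and `-onnx_compat_floor_div(-a, b)` above.
def ceildiv(a: int, b: int) -> int:
    return -(-a // b)  # Python's // is floor division

assert ceildiv(7, 2) == 4    # ceil(3.5)
assert ceildiv(8, 2) == 4    # exact division stays exact
assert ceildiv(-7, 2) == -3  # ceil(-3.5)
```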
21 changes: 12 additions & 9 deletions returnn/tf/frontend_low_level/_backend.py
@@ -8,7 +8,7 @@
 import tensorflow as tf

 import returnn.tf.compat as tf_compat
-from returnn.util.basic import NotSpecified
+from returnn.util.basic import NotSpecified, is_onnx_export_global
 from returnn.tensor import Tensor, Dim
 from returnn.tf.util import basic as tf_util

@@ -132,14 +132,17 @@ def combine_raw(a: tf.Tensor, kind: str, b: tf.Tensor) -> tf.Tensor:
         :return: a `kind` b
         """
         assert a.shape.ndims == b.shape.ndims or a.shape.ndims == 0 or b.shape.ndims == 0
-        kind = {
-            "sub": "subtract",
-            "mul": "multiply",
-        }.get(kind, kind)
-        op = getattr(tf, kind, None)  # e.g. tf.add
-        # In tf v2, some ops like floordiv or mod exist in the tf.math namespace instead
-        if op is None:
-            op = getattr(tf.math, kind)
+        if kind == "floordiv" and is_onnx_export_global():
+            op = tf_util.onnx_compat_floor_div
+        else:
+            kind = {
+                "sub": "subtract",
+                "mul": "multiply",
+            }.get(kind, kind)
+            op = getattr(tf, kind, None)  # e.g. tf.add
+            # In tf v2, some ops like floordiv or mod exist in the tf.math namespace instead
+            if op is None:
+                op = getattr(tf.math, kind)
         with tf_util.same_control_flow_ctx([a, b]):
             return op(a, b)

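The non-ONNX fallback path resolves the op by name, first on the `tf` module and then on `tf.math`. A small standalone sketch of that lookup (illustrative names, not the RETURNN API itself):

```python
import tensorflow as tf

def resolve_tf_op(kind: str):
    """Map a RETURNN-style op name to a TF function, mirroring the lookup above."""
    kind = {"sub": "subtract", "mul": "multiply"}.get(kind, kind)
    op = getattr(tf, kind, None)  # e.g. tf.add, tf.subtract
    if op is None:
        # Some ops (e.g. floordiv, mod) only live under tf.math in TF v2.
        op = getattr(tf.math, kind)
    return op

add = resolve_tf_op("add")            # -> tf.add
floordiv = resolve_tf_op("floordiv")  # -> tf.math.floordiv under TF v2
print(floordiv(tf.constant(7), tf.constant(2)))  # tf.Tensor(3, ...)
```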
11 changes: 11 additions & 0 deletions returnn/tf/util/basic.py
@@ -7820,3 +7820,14 @@ def is_axis_from_description_recurrent(axis, network, data):
     if axis == single_step_dim:
         return True
     return False
+
+
+def onnx_compat_floor_div(a: tf.Tensor, b: tf.Tensor) -> tf.Tensor:
+    """
+    :param a:
+    :param b:
+    :return: for onnx export compatible floor_divide
+    """
+    # https://github.com/onnx/tensorflow-onnx/issues/2174
+    abs_a, abs_b = tf.abs(a), tf.abs(b)
+    return tf.where(a * b >= 0, a // b, -abs_a // abs_b - tf.cast(abs_a % abs_b != 0, dtype=a.dtype))
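The linked tf2onnx issue concerns exported floor division behaving like truncating division for operands of mixed sign. The workaround divides the absolute values (where floor and truncation agree) and then fixes up sign and remainder by hand. A plain-Python sketch of that arithmetic, with the negation parenthesized explicitly (illustrative only; `compat_floor_div` is not a name from this PR):

```python
import math

def compat_floor_div(a: int, b: int) -> int:
    """Floor division computed only from a division of absolute values (sketch)."""
    abs_a, abs_b = abs(a), abs(b)
    if a * b >= 0:
        # Same sign (or zero numerator): floor and truncation agree.
        return a // b
    # Mixed signs: negate the quotient of the absolute values,
    # and subtract 1 when the division is not exact.
    return -(abs_a // abs_b) - (1 if abs_a % abs_b else 0)

for a in range(-9, 10):
    for b in (-3, -2, 2, 3):
        assert compat_floor_div(a, b) == math.floor(a / b)

assert compat_floor_div(-7, 2) == -4  # floor(-3.5)
assert int(-7 / 2) == -3              # truncation toward zero, the mismatch the issue describes
```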
12 changes: 12 additions & 0 deletions returnn/util/basic.py
@@ -3627,6 +3627,18 @@ def get_global_inf_value() -> float:
 return config.float("inf_value", _default_global_inf_value)
+
+
+def is_onnx_export_global() -> bool:
+    """
+    :return: False by default. If 'onnx_export' is set in the config, that value is used.
+    """
+    from returnn.config import get_global_config
+
+    config = get_global_config(raise_exception=False)
+    if not config:
+        return False
+    return config.bool("onnx_export", False)
+
+
 # See :func:`maybe_restart_returnn_with_atfork_patch` below for why you might want to use this.
 _c_code_patch_atfork = """
 #define _GNU_SOURCE
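The flag itself is just a boolean config entry, so enabling it means adding `onnx_export = True` to the RETURNN config. A sketch of how the helper then behaves, assuming the usual `Config.set` / `set_global_config` API from `returnn.config` (an assumption here, not shown in this PR):

```python
# Sketch only (assumed API): Config.set and set_global_config from returnn.config.
from returnn.config import Config, set_global_config
from returnn.util.basic import is_onnx_export_global

config = Config()
config.set("onnx_export", True)  # same effect as `onnx_export = True` in the config file
set_global_config(config)
assert is_onnx_export_global()
```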