Skip to content

Commit

Permalink
TF add onnx_export option, onnx_comp_floor_div
Browse files Browse the repository at this point in the history
  • Loading branch information
Gerstenberger committed Nov 6, 2023
1 parent 52abd3a commit 09f1048
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 9 deletions.
21 changes: 12 additions & 9 deletions returnn/tf/frontend_low_level/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import tensorflow as tf

import returnn.tf.compat as tf_compat
from returnn.util.basic import NotSpecified
from returnn.util.basic import NotSpecified, is_onnx_export_global
from returnn.tensor import Tensor, Dim
from returnn.tf.util import basic as tf_util

Expand Down Expand Up @@ -132,14 +132,17 @@ def combine_raw(a: tf.Tensor, kind: str, b: tf.Tensor) -> tf.Tensor:
:return: a `kind` b
"""
assert a.shape.ndims == b.shape.ndims or a.shape.ndims == 0 or b.shape.ndims == 0
kind = {
"sub": "subtract",
"mul": "multiply",
}.get(kind, kind)
op = getattr(tf, kind, None) # e.g. tf.add
# In tf v2, some ops like floordiv or mod exist in the tf.math namespace instead
if op is None:
op = getattr(tf.math, kind)
if kind == "floordiv" and is_onnx_export_global():
op = tf_util.onnx_compat_floor_div
else:
kind = {
"sub": "subtract",
"mul": "multiply",
}.get(kind, kind)
op = getattr(tf, kind, None) # e.g. tf.add
# In tf v2, some ops like floordiv or mod exist in the tf.math namespace instead
if op is None:
op = getattr(tf.math, kind)
with tf_util.same_control_flow_ctx([a, b]):
return op(a, b)

Expand Down
15 changes: 15 additions & 0 deletions returnn/tf/util/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7820,3 +7820,18 @@ def is_axis_from_description_recurrent(axis, network, data):
if axis == single_step_dim:
return True
return False


def onnx_compat_floor_div(a: tf.Tensor, b: tf.Tensor) -> tf.Tensor:
"""
:param a:
:param b:
:return: for onnx export compatible floor_divide
"""
# https://github.com/onnx/tensorflow-onnx/issues/2174
abs_a, abs_b = tf.abs(a), tf.abs(b)
return tf.where(
a * b >= 0,
a // b,
-abs_a // abs_b - tf.cast(abs_a % abs_b != 0, dtype=a.dtype)
)
12 changes: 12 additions & 0 deletions returnn/util/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3627,6 +3627,18 @@ def get_global_inf_value() -> float:
return config.float("inf_value", _default_global_inf_value)


def is_onnx_export_global() -> bool:
"""
:return: False by default. If 'onnx_export' is set in the config, that value is used.
"""
from returnn.config import get_global_config

config = get_global_config(raise_exception=False)
if not config:
return False
return config.bool("onnx_export", False)


# See :func:`maybe_restart_returnn_with_atfork_patch` below for why you might want to use this.
_c_code_patch_atfork = """
#define _GNU_SOURCE
Expand Down

0 comments on commit 09f1048

Please sign in to comment.