diff --git a/tests/rf_utils.py b/tests/rf_utils.py index c288f81b09..40f90126ae 100644 --- a/tests/rf_utils.py +++ b/tests/rf_utils.py @@ -8,19 +8,27 @@ import re import numpy import numpy.testing -import tensorflow as tf + +try: + import tensorflow as tf +except ImportError: + tf = globals().get("tf", None) # type: ignore from returnn.config import Config, global_config_ctx from returnn.util.pprint import pprint import returnn.frontend as rf from returnn.tensor import Tensor, Dim, TensorDict from returnn.tensor.utils import tensor_dict_fill_random_numpy_ -import returnn.tf.compat as tf_compat import returnn.torch.frontend as rft -import returnn.tf.frontend_layers as rfl -from returnn.tf.network import TFNetwork from returnn.torch.data.tensor_utils import tensor_dict_numpy_to_torch_, tensor_dict_torch_to_numpy_ +if tf: + import returnn.tf.compat as tf_compat + import returnn.tf.frontend_layers as rfl + from returnn.tf.network import TFNetwork +else: + tf_compat = rfl = TFNetwork = None + # noinspection PyProtectedMember from returnn.frontend._random_journal import RandomJournal