From a87969037daf274b2952f6e534e59260d1e32ed5 Mon Sep 17 00:00:00 2001 From: Thomas Moreau Date: Wed, 2 Aug 2023 14:15:02 +0200 Subject: [PATCH] [MTN] more permissive check_backend (#494) * MTN more permissive check_backend * TST simplify tests for get_backend and test more * FIX pep8 --- ot/backend.py | 73 ++++++++++++++++++++++++--------------- ot/partial.py | 2 +- ot/stochastic.py | 2 +- test/test_backend.py | 82 +++++++++++++------------------------------- 4 files changed, 70 insertions(+), 89 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index 33c323d92..1b6ca606b 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -131,23 +131,27 @@ str_type_error = "All array should be from the same type/backend. Current types are : {}" -def get_backend_list(): - """Returns the list of available backends""" - lst = [NumpyBackend(), ] +# Mapping between argument types and the existing backend +_BACKENDS = [] + - if torch: - lst.append(TorchBackend()) +def register_backend(backend): + _BACKENDS.append(backend) - if jax: - lst.append(JaxBackend()) - if cp: # pragma: no cover - lst.append(CupyBackend()) +def get_backend_list(): + """Returns the list of available backends""" + return _BACKENDS + - if tf: - lst.append(TensorflowBackend()) +def _check_args_backend(backend, args): + is_instance = set(isinstance(a, backend.__type__) for a in args) + # check that all arguments matched or not the type + if len(is_instance) == 1: + return is_instance.pop() - return lst + # Otherwise return an error + raise ValueError(str_type_error.format([type(a) for a in args])) def get_backend(*args): @@ -158,22 +162,12 @@ def get_backend(*args): # check that some arrays given if not len(args) > 0: raise ValueError(" The function takes at least one parameter") - # check all same type - if not len(set(type(a) for a in args)) == 1: - raise ValueError(str_type_error.format([type(a) for a in args])) - - if isinstance(args[0], np.ndarray): - return NumpyBackend() - elif isinstance(args[0], 
torch_type): - return TorchBackend() - elif isinstance(args[0], jax_type): - return JaxBackend() - elif isinstance(args[0], cp_type): # pragma: no cover - return CupyBackend() - elif isinstance(args[0], tf_type): - return TensorflowBackend() - else: - raise ValueError("Unknown type of non implemented backend.") + + for backend in _BACKENDS: + if _check_args_backend(backend, args): + return backend + + raise ValueError("Unknown type of non implemented backend.") def to_numpy(*args): @@ -1318,6 +1312,9 @@ def matmul(self, a, b): return np.matmul(a, b) +register_backend(NumpyBackend()) + + class JaxBackend(Backend): """ JAX implementation of the backend @@ -1676,6 +1673,11 @@ def matmul(self, a, b): return jnp.matmul(a, b) +if jax: + # Only register jax backend if it is installed + register_backend(JaxBackend()) + + class TorchBackend(Backend): """ PyTorch implementation of the backend @@ -2148,6 +2150,11 @@ def matmul(self, a, b): return torch.matmul(a, b) +if torch: + # Only register torch backend if it is installed + register_backend(TorchBackend()) + + class CupyBackend(Backend): # pragma: no cover """ CuPy implementation of the backend @@ -2530,6 +2537,11 @@ def matmul(self, a, b): return cp.matmul(a, b) +if cp: + # Only register cp backend if it is installed + register_backend(CupyBackend()) + + class TensorflowBackend(Backend): __name__ = "tf" @@ -2930,3 +2942,8 @@ def detach(self, *args): def matmul(self, a, b): return tnp.matmul(a, b) + + +if tf: + # Only register tensorflow backend if it is installed + register_backend(TensorflowBackend()) diff --git a/ot/partial.py b/ot/partial.py index 43f3362e7..85635c9ba 100755 --- a/ot/partial.py +++ b/ot/partial.py @@ -873,7 +873,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, log_e = {'err': []} - if type(a) == type(b) == type(M) == np.ndarray: + if nx.__name__ == "numpy": # Next 3 lines equivalent to K=nx.exp(-M/reg), but faster to compute K = np.empty(M.shape, dtype=M.dtype) np.divide(M, 
-reg, out=K) diff --git a/ot/stochastic.py b/ot/stochastic.py index 319d00670..79d971bfd 100644 --- a/ot/stochastic.py +++ b/ot/stochastic.py @@ -258,7 +258,7 @@ def c_transform_entropic(b, M, reg, beta): def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None, - log=False): + log=False): r''' Compute the transportation matrix to solve the regularized discrete measures optimal transport max problem diff --git a/test/test_backend.py b/test/test_backend.py index 799ac54d3..8f7cd9ec1 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -7,7 +7,7 @@ import ot import ot.backend -from ot.backend import torch, jax, cp, tf +from ot.backend import torch, jax, tf import pytest @@ -37,17 +37,7 @@ def test_to_numpy(nx): assert isinstance(M2, np.ndarray) -def test_get_backend(): - - A = np.zeros((3, 2)) - B = np.zeros((3, 1)) - - nx = get_backend(A) - assert nx.__name__ == 'numpy' - - nx = get_backend(A, B) - assert nx.__name__ == 'numpy' - +def test_get_backend_invalid(): # error if no parameters with pytest.raises(ValueError): get_backend() @@ -56,64 +46,38 @@ def test_get_backend(): with pytest.raises(ValueError): get_backend(1, 2.0) - # test torch - if torch: - A2 = torch.from_numpy(A) - B2 = torch.from_numpy(B) +def test_get_backend(nx): - nx = get_backend(A2) - assert nx.__name__ == 'torch' - - nx = get_backend(A2, B2) - assert nx.__name__ == 'torch' - - # test not unique types in input - with pytest.raises(ValueError): - get_backend(A, B2) - - if jax: - - A2 = jax.numpy.array(A) - B2 = jax.numpy.array(B) - - nx = get_backend(A2) - assert nx.__name__ == 'jax' - - nx = get_backend(A2, B2) - assert nx.__name__ == 'jax' + A = np.zeros((3, 2)) + B = np.zeros((3, 1)) - # test not unique types in input - with pytest.raises(ValueError): - get_backend(A, B2) + nx_np = get_backend(A) + assert nx_np.__name__ == 'numpy' - if cp: - A2 = cp.asarray(A) - B2 = cp.asarray(B) + A2, B2 = nx.from_numpy(A, B) - nx = get_backend(A2) - assert nx.__name__ == 
'cupy' + effective_nx = get_backend(A2) + assert effective_nx.__name__ == nx.__name__ - nx = get_backend(A2, B2) - assert nx.__name__ == 'cupy' + effective_nx = get_backend(A2, B2) + assert effective_nx.__name__ == nx.__name__ - # test not unique types in input + if nx.__name__ != "numpy": + # test that types matching different backends in input raise an error with pytest.raises(ValueError): get_backend(A, B2) + else: + # Check that subclassing a numpy array does not break get_backend + # note: This is only tested for numpy as this is hard to be consistent + # with other backends + class nx_subclass(nx.__type__): + pass - if tf: - A2 = tf.convert_to_tensor(A) - B2 = tf.convert_to_tensor(B) - - nx = get_backend(A2) - assert nx.__name__ == 'tf' + A3 = nx_subclass(0) - nx = get_backend(A2, B2) - assert nx.__name__ == 'tf' - - # test not unique types in input - with pytest.raises(ValueError): - get_backend(A, B2) + effective_nx = get_backend(A3, B2) + assert effective_nx.__name__ == nx.__name__ def test_convert_between_backends(nx):