Skip to content

Commit

Permalink
Merge branch 'demd' of github.com:x12hengyu/POT into demd
Browse files Browse the repository at this point in the history
  • Loading branch information
xzyu02 committed Aug 2, 2023
2 parents a7bde66 + 018313b commit b370202
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 89 deletions.
73 changes: 45 additions & 28 deletions ot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,23 +131,27 @@
str_type_error = "All array should be from the same type/backend. Current types are : {}"


def get_backend_list():
"""Returns the list of available backends"""
lst = [NumpyBackend(), ]
# Mapping between argument types and the existing backend
_BACKENDS = []


if torch:
lst.append(TorchBackend())
def register_backend(backend):
_BACKENDS.append(backend)

if jax:
lst.append(JaxBackend())

if cp: # pragma: no cover
lst.append(CupyBackend())
def get_backend_list():
"""Returns the list of available backends"""
return _BACKENDS


if tf:
lst.append(TensorflowBackend())
def _check_args_backend(backend, args):
is_instance = set(isinstance(a, backend.__type__) for a in args)
# check that all arguments matched or not the type
if len(is_instance) == 1:
return is_instance.pop()

return lst
# Oterwise return an error
raise ValueError(str_type_error.format([type(a) for a in args]))


def get_backend(*args):
Expand All @@ -158,22 +162,12 @@ def get_backend(*args):
# check that some arrays given
if not len(args) > 0:
raise ValueError(" The function takes at least one parameter")
# check all same type
if not len(set(type(a) for a in args)) == 1:
raise ValueError(str_type_error.format([type(a) for a in args]))

if isinstance(args[0], np.ndarray):
return NumpyBackend()
elif isinstance(args[0], torch_type):
return TorchBackend()
elif isinstance(args[0], jax_type):
return JaxBackend()
elif isinstance(args[0], cp_type): # pragma: no cover
return CupyBackend()
elif isinstance(args[0], tf_type):
return TensorflowBackend()
else:
raise ValueError("Unknown type of non implemented backend.")

for backend in _BACKENDS:
if _check_args_backend(backend, args):
return backend

raise ValueError("Unknown type of non implemented backend.")


def to_numpy(*args):
Expand Down Expand Up @@ -1318,6 +1312,9 @@ def matmul(self, a, b):
return np.matmul(a, b)


register_backend(NumpyBackend())


class JaxBackend(Backend):
"""
JAX implementation of the backend
Expand Down Expand Up @@ -1676,6 +1673,11 @@ def matmul(self, a, b):
return jnp.matmul(a, b)


if jax:
# Only register jax backend if it is installed
register_backend(JaxBackend())


class TorchBackend(Backend):
"""
PyTorch implementation of the backend
Expand Down Expand Up @@ -2148,6 +2150,11 @@ def matmul(self, a, b):
return torch.matmul(a, b)


if torch:
# Only register torch backend if it is installed
register_backend(TorchBackend())


class CupyBackend(Backend): # pragma: no cover
"""
CuPy implementation of the backend
Expand Down Expand Up @@ -2530,6 +2537,11 @@ def matmul(self, a, b):
return cp.matmul(a, b)


if cp:
# Only register cp backend if it is installed
register_backend(CupyBackend())


class TensorflowBackend(Backend):

__name__ = "tf"
Expand Down Expand Up @@ -2930,3 +2942,8 @@ def detach(self, *args):

def matmul(self, a, b):
return tnp.matmul(a, b)


if tf:
# Only register tensorflow backend if it is installed
register_backend(TensorflowBackend())
2 changes: 1 addition & 1 deletion ot/partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,7 +873,7 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000,

log_e = {'err': []}

if type(a) == type(b) == type(M) == np.ndarray:
if nx.__name__ == "numpy":
# Next 3 lines equivalent to K=nx.exp(-M/reg), but faster to compute
K = np.empty(M.shape, dtype=M.dtype)
np.divide(M, -reg, out=K)
Expand Down
2 changes: 1 addition & 1 deletion ot/stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def c_transform_entropic(b, M, reg, beta):


def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
log=False):
log=False):
r'''
Compute the transportation matrix to solve the regularized discrete measures optimal transport max problem
Expand Down
82 changes: 23 additions & 59 deletions test/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import ot
import ot.backend
from ot.backend import torch, jax, cp, tf
from ot.backend import torch, jax, tf

import pytest

Expand Down Expand Up @@ -37,17 +37,7 @@ def test_to_numpy(nx):
assert isinstance(M2, np.ndarray)


def test_get_backend():

A = np.zeros((3, 2))
B = np.zeros((3, 1))

nx = get_backend(A)
assert nx.__name__ == 'numpy'

nx = get_backend(A, B)
assert nx.__name__ == 'numpy'

def test_get_backend_invalid():
# error if no parameters
with pytest.raises(ValueError):
get_backend()
Expand All @@ -56,64 +46,38 @@ def test_get_backend():
with pytest.raises(ValueError):
get_backend(1, 2.0)

# test torch
if torch:

A2 = torch.from_numpy(A)
B2 = torch.from_numpy(B)
def test_get_backend(nx):

nx = get_backend(A2)
assert nx.__name__ == 'torch'

nx = get_backend(A2, B2)
assert nx.__name__ == 'torch'

# test not unique types in input
with pytest.raises(ValueError):
get_backend(A, B2)

if jax:

A2 = jax.numpy.array(A)
B2 = jax.numpy.array(B)

nx = get_backend(A2)
assert nx.__name__ == 'jax'

nx = get_backend(A2, B2)
assert nx.__name__ == 'jax'
A = np.zeros((3, 2))
B = np.zeros((3, 1))

# test not unique types in input
with pytest.raises(ValueError):
get_backend(A, B2)
nx_np = get_backend(A)
assert nx_np.__name__ == 'numpy'

if cp:
A2 = cp.asarray(A)
B2 = cp.asarray(B)
A2, B2 = nx.from_numpy(A, B)

nx = get_backend(A2)
assert nx.__name__ == 'cupy'
effective_nx = get_backend(A2)
assert effective_nx.__name__ == nx.__name__

nx = get_backend(A2, B2)
assert nx.__name__ == 'cupy'
effective_nx = get_backend(A2, B2)
assert effective_nx.__name__ == nx.__name__

# test not unique types in input
if nx.__name__ != "numpy":
# test that types mathcing different backends in input raise an error
with pytest.raises(ValueError):
get_backend(A, B2)
else:
# Check that subclassing a numpy array does not break get_backend
# note: This is only tested for numpy as this is hard to be consistent
# with other backends
class nx_subclass(nx.__type__):
pass

if tf:
A2 = tf.convert_to_tensor(A)
B2 = tf.convert_to_tensor(B)

nx = get_backend(A2)
assert nx.__name__ == 'tf'
A3 = nx_subclass(0)

nx = get_backend(A2, B2)
assert nx.__name__ == 'tf'

# test not unique types in input
with pytest.raises(ValueError):
get_backend(A, B2)
effective_nx = get_backend(A3, B2)
assert effective_nx.__name__ == nx.__name__


def test_convert_between_backends(nx):
Expand Down

0 comments on commit b370202

Please sign in to comment.