diff --git a/setup.py b/setup.py index b2a020c..296a205 100644 --- a/setup.py +++ b/setup.py @@ -20,5 +20,6 @@ install_requires=[ 'numpy', 'scipy', + 'opt_einsum', ], ) diff --git a/tensorbackends/backends/numpy/numpy_backend.py b/tensorbackends/backends/numpy/numpy_backend.py index 229d5a2..47a6c1b 100644 --- a/tensorbackends/backends/numpy/numpy_backend.py +++ b/tensorbackends/backends/numpy/numpy_backend.py @@ -6,6 +6,7 @@ import numpy as np import numpy.linalg as la +from opt_einsum import contract from ...interface import Backend from ...utils import einstr @@ -175,7 +176,7 @@ def wrapped_result(*args, **kwargs): return result def _einsum(self, expr, operands): - result = np.einsum(expr.indices_string, *(operand.tsr for operand in operands), optimize='greedy') + result = contract(expr.indices_string, *(operand.tsr for operand in operands)) if isinstance(result, np.ndarray) and result.ndim != 0: newshape = expr.outputs[0].newshape(result.shape) result = result.reshape(*newshape)