From 731f1bf2f679f00e488a28cf06ac344cd8cff742 Mon Sep 17 00:00:00 2001 From: Yuchen Pang <32781495+yuchenpang@users.noreply.github.com> Date: Tue, 7 Jul 2020 18:51:27 -0500 Subject: [PATCH] Use opt_einsum in NumPy backend numpy.einsum is very slow in some cases. --- setup.py | 1 + tensorbackends/backends/numpy/numpy_backend.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index b2a020c..296a205 100644 --- a/setup.py +++ b/setup.py @@ -20,5 +20,6 @@ install_requires=[ 'numpy', 'scipy', + 'opt_einsum', ], ) diff --git a/tensorbackends/backends/numpy/numpy_backend.py b/tensorbackends/backends/numpy/numpy_backend.py index 229d5a2..47a6c1b 100644 --- a/tensorbackends/backends/numpy/numpy_backend.py +++ b/tensorbackends/backends/numpy/numpy_backend.py @@ -6,6 +6,7 @@ import numpy as np import numpy.linalg as la +from opt_einsum import contract from ...interface import Backend from ...utils import einstr @@ -175,7 +176,7 @@ def wrapped_result(*args, **kwargs): return result def _einsum(self, expr, operands): - result = np.einsum(expr.indices_string, *(operand.tsr for operand in operands), optimize='greedy') + result = contract(expr.indices_string, *(operand.tsr for operand in operands)) if isinstance(result, np.ndarray) and result.ndim != 0: newshape = expr.outputs[0].newshape(result.shape) result = result.reshape(*newshape)