From 8c67416c95583ca29e9a712e1cb2060ec57fd565 Mon Sep 17 00:00:00 2001 From: Manolis Papadakis Date: Tue, 7 Nov 2023 10:08:05 -0800 Subject: [PATCH] Fixes #1069, #1070 (#1072) * Find handling of optimize=True in einsum * Use einsum path optimizer by default cuNumeric can only contract two arrays at a time, so the naive input-order contraction path can easily result in huge intermediates. * Bump legate.core git hash --- cmake/versions.json | 2 +- cunumeric/module.py | 13 ++++++++----- tests/integration/test_einsum.py | 4 +--- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/cmake/versions.json b/cmake/versions.json index 93a1d8001..43d60fa5e 100644 --- a/cmake/versions.json +++ b/cmake/versions.json @@ -5,7 +5,7 @@ "git_url" : "https://github.com/nv-legate/legate.core.git", "git_shallow": false, "always_download": false, - "git_tag" : "a4b5430ebb2c52e3f8da8f27534bc0db8826b804" + "git_tag" : "6fa0acc9dcfa89be2702f1de6c045bc262f752b1" } } } diff --git a/cunumeric/module.py b/cunumeric/module.py index e2bbc78f7..47c4dea90 100644 --- a/cunumeric/module.py +++ b/cunumeric/module.py @@ -4587,7 +4587,7 @@ def einsum( out: Optional[ndarray] = None, dtype: Optional[np.dtype[Any]] = None, casting: CastingKind = "safe", - optimize: Union[bool, str] = False, + optimize: Union[bool, Literal["greedy", "optimal"]] = True, ) -> ndarray: """ Evaluates the Einstein summation convention on the operands. @@ -4628,9 +4628,10 @@ def einsum( Default is 'safe'. optimize : ``{False, True, 'greedy', 'optimal'}``, optional - Controls if intermediate optimization should occur. No optimization - will occur if False. Uses opt_einsum to find an optimized contraction - plan if True. + Controls if intermediate optimization should occur. If False then + arrays will be contracted in input order, one at a time. True (the + default) will use the 'greedy' algorithm. See ``cunumeric.einsum_path`` + for more information on the available optimization algorithms. Returns ------- @@ -4654,7 +4655,9 @@ def einsum( if out is not None: out = convert_to_cunumeric_ndarray(out, share=True) - if not optimize: + if optimize is True: + optimize = "greedy" + elif optimize is False: optimize = NullOptimizer() # This call normalizes the expression (adds the output part if it's diff --git a/tests/integration/test_einsum.py b/tests/integration/test_einsum.py index e482e8cf0..4fcdd2402 100644 --- a/tests/integration/test_einsum.py +++ b/tests/integration/test_einsum.py @@ -272,7 +272,7 @@ def test_cast(expr, dtype): False, "optimal", "greedy", - pytest.param(True, marks=pytest.mark.xfail), + True, ], ) def test_optimize(optimize): @@ -282,8 +282,6 @@ def test_optimize(optimize): np_res = np.einsum("ik,kj->ij", a, b, optimize=optimize) num_res = num.einsum("ik,kj->ij", a, b, optimize=optimize) assert allclose(np_res, num_res) - # when optimize=True, cunumeric raises - # TypeError: 'bool' object is not iterable def test_expr_opposite():