From eac14fea834514463f8305eb4204d2d9cde4877c Mon Sep 17 00:00:00 2001 From: Manolis Papadakis Date: Mon, 6 Nov 2023 17:56:18 -0800 Subject: [PATCH 1/3] Find handling of optimize=True in einsum --- cunumeric/module.py | 10 ++++++---- tests/integration/test_einsum.py | 4 +--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/cunumeric/module.py b/cunumeric/module.py index e2bbc78f7..f993a63cb 100644 --- a/cunumeric/module.py +++ b/cunumeric/module.py @@ -4587,7 +4587,7 @@ def einsum( out: Optional[ndarray] = None, dtype: Optional[np.dtype[Any]] = None, casting: CastingKind = "safe", - optimize: Union[bool, str] = False, + optimize: Union[bool, Literal["greedy", "optimal"]] = False, ) -> ndarray: """ Evaluates the Einstein summation convention on the operands. @@ -4629,8 +4629,8 @@ def einsum( Default is 'safe'. optimize : ``{False, True, 'greedy', 'optimal'}``, optional Controls if intermediate optimization should occur. No optimization - will occur if False. Uses opt_einsum to find an optimized contraction - plan if True. + will occur if False, and True will default to the 'greedy' algorithm. + Defaults to False. Returns ------- @@ -4654,7 +4654,9 @@ def einsum( if out is not None: out = convert_to_cunumeric_ndarray(out, share=True) - if not optimize: + if optimize is True: + optimize = "greedy" + elif optimize is False: optimize = NullOptimizer() # This call normalizes the expression (adds the output part if it's diff --git a/tests/integration/test_einsum.py b/tests/integration/test_einsum.py index e482e8cf0..4fcdd2402 100644 --- a/tests/integration/test_einsum.py +++ b/tests/integration/test_einsum.py @@ -272,7 +272,7 @@ def test_cast(expr, dtype): False, "optimal", "greedy", - pytest.param(True, marks=pytest.mark.xfail), + True, ], ) def test_optimize(optimize): @@ -282,8 +282,6 @@ def test_optimize(optimize): np_res = np.einsum("ik,kj->ij", a, b, optimize=optimize) num_res = num.einsum("ik,kj->ij", a, b, optimize=optimize) assert allclose(np_res, num_res) - # when optimize=True, cunumeric raises - # TypeError: 'bool' object is not iterable def test_expr_opposite(): From 8a05e0e1f7a883ef7c361b6ba6fed4bbe8e3fe62 Mon Sep 17 00:00:00 2001 From: Manolis Papadakis Date: Mon, 6 Nov 2023 17:56:43 -0800 Subject: [PATCH 2/3] Use einsum path optimizer by default cuNumeric can only contract two arrays at a time, so the naive input-order contraction path can easily result in huge intermediates. --- cunumeric/module.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/cunumeric/module.py b/cunumeric/module.py index f993a63cb..47c4dea90 100644 --- a/cunumeric/module.py +++ b/cunumeric/module.py @@ -4587,7 +4587,7 @@ def einsum( out: Optional[ndarray] = None, dtype: Optional[np.dtype[Any]] = None, casting: CastingKind = "safe", - optimize: Union[bool, Literal["greedy", "optimal"]] = False, + optimize: Union[bool, Literal["greedy", "optimal"]] = True, ) -> ndarray: """ Evaluates the Einstein summation convention on the operands. @@ -4628,9 +4628,10 @@ def einsum( Default is 'safe'. optimize : ``{False, True, 'greedy', 'optimal'}``, optional - Controls if intermediate optimization should occur. No optimization - will occur if False, and True will default to the 'greedy' algorithm. - Defaults to False. + Controls if intermediate optimization should occur. If False then + arrays will be contracted in input order, one at a time. True (the + default) will use the 'greedy' algorithm. See ``cunumeric.einsum_path`` + for more information on the available optimization algorithms. Returns ------- From 1e0218671117ae691774d3e65f116fde5bdfb29d Mon Sep 17 00:00:00 2001 From: Manolis Papadakis Date: Mon, 6 Nov 2023 21:18:04 -0800 Subject: [PATCH 3/3] Bump legate.core git hash --- cmake/versions.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/versions.json b/cmake/versions.json index 93a1d8001..43d60fa5e 100644 --- a/cmake/versions.json +++ b/cmake/versions.json @@ -5,7 +5,7 @@ "git_url" : "https://github.com/nv-legate/legate.core.git", "git_shallow": false, "always_download": false, - "git_tag" : "a4b5430ebb2c52e3f8da8f27534bc0db8826b804" + "git_tag" : "6fa0acc9dcfa89be2702f1de6c045bc262f752b1" } } }