diff --git a/cmake/versions.json b/cmake/versions.json index 93a1d8001..43d60fa5e 100644 --- a/cmake/versions.json +++ b/cmake/versions.json @@ -5,7 +5,7 @@ "git_url" : "https://github.com/nv-legate/legate.core.git", "git_shallow": false, "always_download": false, - "git_tag" : "a4b5430ebb2c52e3f8da8f27534bc0db8826b804" + "git_tag" : "6fa0acc9dcfa89be2702f1de6c045bc262f752b1" } } } diff --git a/cunumeric/module.py b/cunumeric/module.py index e2bbc78f7..47c4dea90 100644 --- a/cunumeric/module.py +++ b/cunumeric/module.py @@ -4587,7 +4587,7 @@ def einsum( out: Optional[ndarray] = None, dtype: Optional[np.dtype[Any]] = None, casting: CastingKind = "safe", - optimize: Union[bool, str] = False, + optimize: Union[bool, Literal["greedy", "optimal"]] = True, ) -> ndarray: """ Evaluates the Einstein summation convention on the operands. @@ -4628,9 +4628,10 @@ def einsum( Default is 'safe'. optimize : ``{False, True, 'greedy', 'optimal'}``, optional - Controls if intermediate optimization should occur. No optimization - will occur if False. Uses opt_einsum to find an optimized contraction - plan if True. + Controls if intermediate optimization should occur. If False then + arrays will be contracted in input order, one at a time. True (the + default) will use the 'greedy' algorithm. See ``cunumeric.einsum_path`` + for more information on the available optimization algorithms. Returns ------- @@ -4654,7 +4655,9 @@ def einsum( if out is not None: out = convert_to_cunumeric_ndarray(out, share=True) - if not optimize: + if optimize is True: + optimize = "greedy" + elif optimize is False: optimize = NullOptimizer() # This call normalizes the expression (adds the output part if it's diff --git a/tests/integration/test_einsum.py b/tests/integration/test_einsum.py index e482e8cf0..4fcdd2402 100644 --- a/tests/integration/test_einsum.py +++ b/tests/integration/test_einsum.py @@ -272,7 +272,7 @@ def test_cast(expr, dtype): False, "optimal", "greedy", - pytest.param(True, marks=pytest.mark.xfail), + True, ], ) def test_optimize(optimize): @@ -282,8 +282,6 @@ def test_optimize(optimize): np_res = np.einsum("ik,kj->ij", a, b, optimize=optimize) num_res = num.einsum("ik,kj->ij", a, b, optimize=optimize) assert allclose(np_res, num_res) - # when optimize=True, cunumeric raises - # TypeError: 'bool' object is not iterable def test_expr_opposite():