Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #1069, #1070 #1072

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/versions.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"git_url" : "https://github.com/nv-legate/legate.core.git",
"git_shallow": false,
"always_download": false,
"git_tag" : "a4b5430ebb2c52e3f8da8f27534bc0db8826b804"
"git_tag" : "6fa0acc9dcfa89be2702f1de6c045bc262f752b1"
}
}
}
13 changes: 8 additions & 5 deletions cunumeric/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4587,7 +4587,7 @@ def einsum(
out: Optional[ndarray] = None,
dtype: Optional[np.dtype[Any]] = None,
casting: CastingKind = "safe",
optimize: Union[bool, str] = False,
optimize: Union[bool, Literal["greedy", "optimal"]] = True,
) -> ndarray:
"""
Evaluates the Einstein summation convention on the operands.
Expand Down Expand Up @@ -4628,9 +4628,10 @@ def einsum(

Default is 'safe'.
optimize : ``{False, True, 'greedy', 'optimal'}``, optional
Controls if intermediate optimization should occur. No optimization
will occur if False. Uses opt_einsum to find an optimized contraction
plan if True.
Controls if intermediate optimization should occur. If False then
arrays will be contracted in input order, one at a time. True (the
default) will use the 'greedy' algorithm. See ``cunumeric.einsum_path``
for more information on the available optimization algorithms.

Returns
-------
Expand All @@ -4654,7 +4655,9 @@ def einsum(
if out is not None:
out = convert_to_cunumeric_ndarray(out, share=True)

if not optimize:
if optimize is True:
optimize = "greedy"
elif optimize is False:
optimize = NullOptimizer()

# This call normalizes the expression (adds the output part if it's
Expand Down
4 changes: 1 addition & 3 deletions tests/integration/test_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def test_cast(expr, dtype):
False,
"optimal",
"greedy",
pytest.param(True, marks=pytest.mark.xfail),
True,
],
)
def test_optimize(optimize):
Expand All @@ -282,8 +282,6 @@ def test_optimize(optimize):
np_res = np.einsum("ik,kj->ij", a, b, optimize=optimize)
num_res = num.einsum("ik,kj->ij", a, b, optimize=optimize)
assert allclose(np_res, num_res)
# when optimize=True, cunumeric raises
# TypeError: 'bool' object is not iterable


def test_expr_opposite():
Expand Down
Loading