Skip to content

Commit

Permalink
Add missing store overlap checks (#1111)
Browse files Browse the repository at this point in the history
* Add missing store overlap checks

* Add unit test for various cases of overlap

* Add put and putmask to testcase

* No need to test overlaps() in .setitem(), .copy() will do it

* Fix self-assignment skip for the case a[1:] = a[1:]

* Can't take storage of Stores indiscriminately

* Update legate.core hash
  • Loading branch information
manopapad authored Jan 26, 2024
1 parent b4e164b commit 8693a3d
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 13 deletions.
2 changes: 1 addition & 1 deletion cmake/versions.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"git_url" : "https://github.com/nv-legate/legate.core.git",
"git_shallow": false,
"always_download": false,
"git_tag" : "08da13fc544f3db26bf1ef7ce9bdb85e72a9d9fb"
"git_tag" : "85c2a247a6b2c8086e57568ef0056045c3e175e3"
}
}
}
30 changes: 18 additions & 12 deletions cunumeric/deferred.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def _copy_if_overlapping(self, other: DeferredArray) -> DeferredArray:
self.runtime.create_empty_thunk(
self.shape,
self.base.type,
inputs=[self],
),
)
copy.copy(self, deep=True)
Expand Down Expand Up @@ -1098,22 +1099,20 @@ def set_item(self, key: Any, rhs: Any) -> None:
# to set the result back. In cuNumeric, the object we
# return in step (1) is actually a subview to the array arr
# through which we make updates in place, so after step (2) is
# done, # the effect of inplace update is already reflected
# done, the effect of inplace update is already reflected
# to the arr. Therefore, we skip the copy to avoid redundant
# copies if we know that we hit such a scenario.
# TODO: We should make this work for the advanced indexing case
if view.base == rhs.base:
# NOTE: Neither Store nor Storage have an __eq__, so we can
# only check that the underlying RegionField/Future corresponds
# to the same Legion handle.
if (
view.base.has_storage
and rhs.base.has_storage
and view.base.storage.same_handle(rhs.base.storage)
):
return

if view.base.overlaps(rhs.base):
rhs_copy = self.runtime.create_empty_thunk(
rhs.shape,
rhs.base.type,
inputs=[rhs],
)
rhs_copy.copy(rhs, deep=False)
rhs = rhs_copy

view.copy(rhs, deep=False)

def broadcast_to(self, shape: NdShape) -> NumPyThunk:
Expand Down Expand Up @@ -1870,6 +1869,9 @@ def put(self, indices: Any, values: Any, check_bounds: bool) -> None:

assert indices.size == values.size

# Handle store overlap
values = values._copy_if_overlapping(self_tmp)

# first, we create indirect array with PointN type that
# (indices.size,) shape and is used to copy data from values
# to the target ND array (self)
Expand Down Expand Up @@ -1910,11 +1912,12 @@ def put(self, indices: Any, values: Any, check_bounds: bool) -> None:
@auto_convert("mask", "values")
def putmask(self, mask: Any, values: Any) -> None:
assert self.shape == mask.shape

values = values._copy_if_overlapping(self)
if values.shape != self.shape:
values_new = values._broadcast(self.shape)
else:
values_new = values.base

task = self.context.create_auto_task(CuNumericOpCode.PUTMASK)
task.add_input(self.base)
task.add_input(mask.base)
Expand Down Expand Up @@ -3142,6 +3145,7 @@ def unary_op(
multiout: Optional[Any] = None,
) -> None:
lhs = self.base
src = src._copy_if_overlapping(self)
rhs = src._broadcast(lhs.shape)

with Annotation({"OpCode": op.name}):
Expand Down Expand Up @@ -3304,7 +3308,9 @@ def binary_op(
args: Any,
) -> None:
lhs = self.base
src1 = src1._copy_if_overlapping(self)
rhs1 = src1._broadcast(lhs.shape)
src2 = src2._copy_if_overlapping(self)
rhs2 = src2._broadcast(lhs.shape)

with Annotation({"OpCode": op_code.name}):
Expand Down
80 changes: 80 additions & 0 deletions tests/integration/test_overlap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2024 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import numpy as np
import pytest
from utils.generators import mk_seq_array

import cunumeric as num


def setitem(lib, a, slice_lhs, slice_rhs):
a[slice_lhs] = a[slice_rhs]


def dot(lib, a, slice_lhs, slice_rhs):
modes = "".join([chr(ord("a") + m) for m in range(len(a.shape))])
expr = f"{modes},{modes}->{modes}"
lib.einsum(expr, a[slice_lhs], a[slice_rhs], out=a[slice_lhs])


def unary_arith(lib, a, slice_lhs, slice_rhs):
lib.sin(a[slice_rhs], out=a[slice_lhs])


def binary_arith(lib, a, slice_lhs, slice_rhs):
a[slice_lhs] += a[slice_rhs]


def put(lib, a, slice_lhs, slice_rhs):
indices = lib.flip(lib.arange(a[slice_rhs].size))
a[slice_lhs].put(indices, a[slice_rhs])


def putmask(lib, a, slice_lhs, slice_rhs):
mask = (mk_seq_array(lib, a[slice_rhs].shape) % 2).astype(bool)
lib.putmask(a[slice_lhs], mask, a[slice_rhs])


SHAPES = ((4,), (4, 5), (4, 5, 6))
OPERATIONS = (setitem, dot, unary_arith, binary_arith, put, putmask)


@pytest.mark.parametrize("partial", (True, False))
@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("operation", OPERATIONS)
def test_partial(partial, shape, operation):
if partial:
# e.g. for shape = (4,5) and setitem: lhs[1:,:] = rhs[:-1,:]
slice_lhs = (slice(1, None),) + (slice(None),) * (len(shape) - 1)
slice_rhs = (slice(None, -1),) + (slice(None),) * (len(shape) - 1)
else:
# e.g. for shape = (4,5) and setitem: lhs[:,:] = rhs[:,:]
slice_lhs = (slice(None),) * len(shape)
slice_rhs = (slice(None),) * len(shape)

a_np = mk_seq_array(np, shape).astype(np.float64)
a_num = mk_seq_array(num, shape).astype(np.float64)

operation(np, a_np, slice_lhs, slice_rhs)
operation(num, a_num, slice_lhs, slice_rhs)

assert np.array_equal(a_np, a_num)


if __name__ == "__main__":
import sys

sys.exit(pytest.main(sys.argv))

0 comments on commit 8693a3d

Please sign in to comment.