Skip to content

Commit

Permalink
Merge pull request #117 from clEsperanto/add-array-indexing
Browse files Browse the repository at this point in the history
add setitem and getitem
  • Loading branch information
StRigaud authored Dec 1, 2023
2 parents e77dada + b8b7f60 commit fa2eaeb
Show file tree
Hide file tree
Showing 9 changed files with 799 additions and 106 deletions.
45 changes: 21 additions & 24 deletions pyclesperanto/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,22 @@ def __repr__(self) -> str:
return repr_str[:-1] + f", {extra_info})"


def set(self, array: np.ndarray) -> None:
def set(self, array: np.ndarray, origin: tuple = None, region: tuple = None) -> None:
# for cast array to numpy array
if not isinstance(array, np.ndarray):
array = np.array(array)
if array.dtype != self.dtype:
warnings.warn(
f"Array dtype mismatch. Casting array to '{self.dtype.__name__}' before set().",
RuntimeWarning,
)
array = array.astype(self.dtype)
if array.size != self.size:
raise ValueError(
f"Array size mismatch: {array.size} != {self.size} ({array.shape} != {self.shape})"
)
if array.ndim != self.ndim:

if region and array.size != np.prod(region):
raise ValueError(
f"Array dimension mismatch: {array.ndim} != {self.ndim} ({array.shape} != {self.shape})"
f"Value size mismatch the targeted region: {array.size} != {np.prod(region)} ({array.shape} != {tuple(np.squeeze(region))})"
)
if array.shape != self.shape:
raise ValueError(f"Array shape mismatch: {array.shape} != {self.shape}")
self._write(array)
self._write(array, origin, region)
return self


def get(self) -> np.ndarray:
def get(self, origin: tuple = None, region: tuple = None) -> np.ndarray:
caster = {
"float32": self._read_float32,
"int8": self._read_int8,
Expand All @@ -51,7 +45,7 @@ def get(self) -> np.ndarray:
"uint32": self._read_uint32,
"uint64": self._read_uint64,
}
return caster[self.dtype.name]()
return caster[self.dtype.name](origin, region)


def __array__(self, dtype=None) -> np.ndarray:
Expand All @@ -62,8 +56,6 @@ def __array__(self, dtype=None) -> np.ndarray:


# missing operators:
# __setitem__
# __getitem__
# __iter__
# __array_interface__

Expand Down Expand Up @@ -96,19 +88,24 @@ def __array__(self, dtype=None) -> np.ndarray:
setattr(Array, "_png_to_html", _operators._png_to_html)
setattr(Array, "_repr_html_", _operators._repr_html_)

setattr(Array, "__setitem__", _operators.__setitem__)
setattr(Array, "__getitem__", _operators.__getitem__)

Image = Union[np.ndarray, Array]


def is_image(any_array):
def is_image(object):
return (
isinstance(any_array, np.ndarray)
or isinstance(any_array, tuple)
or isinstance(any_array, list)
or isinstance(any_array, Array)
or str(type(any_array))
isinstance(object, np.ndarray)
or isinstance(object, tuple)
or isinstance(object, list)
or isinstance(object, Array)
or str(type(object))
in [
"<class 'cupy._core.core.ndarray'>",
"<class 'dask.array.core.Array'>",
"<class 'xarray.core.dataarray.DataArray'>",
"<class 'resource_backed_dask_array.ResourceBackedDaskArray'>",
"<class 'torch.Tensor'>",
]
)
189 changes: 162 additions & 27 deletions pyclesperanto/_operators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import numpy as np
from typing import Optional
from typing import Optional, Union

from ._array import Array

from ._utils import _compute_range, _clean_index

cl_buffer_datatype_dict = {
bool: "bool",
Expand Down Expand Up @@ -57,11 +61,8 @@ def max(self, axis: Optional[int] = None, out=None):
else:
raise ValueError("Axis " + axis + " not supported")
if out is not None:
from ._memory import pull
from ._array import Image

if isinstance(out, Image):
np.copyto(out, pull(result).astype(out.dtype))
if isinstance(out, Union[Array, np.ndarray]):
np.copyto(out, result.get().astype(out.dtype))
else:
out = result
return result
Expand All @@ -84,11 +85,8 @@ def min(self, axis: Optional[int] = None, out=None):
else:
raise ValueError("Axis " + axis + " not supported")
if out is not None:
from ._memory import pull
from ._array import Image

if isinstance(out, Image):
np.copyto(out, pull(result).astype(out.dtype))
if isinstance(out, Union[Array, np.ndarray]):
np.copyto(out, result.get().astype(out.dtype))
return result


Expand All @@ -109,11 +107,8 @@ def sum(self, axis: Optional[int] = None, out=None):
else:
raise ValueError("Axis " + axis + " not supported")
if out is not None:
from ._memory import pull
from ._array import Image

if isinstance(out, Image):
np.copyto(out, pull(result).astype(out.dtype))
if isinstance(out, Union[Array, np.ndarray]):
np.copyto(out, result.get().astype(out.dtype))
return result


Expand Down Expand Up @@ -320,12 +315,158 @@ def __next__(self):
return MyIterator(self)


def __getitem__(self, index):
raise NotImplementedError("Not implemented yet.")
def __getitem__(self, key):
result = None
key = _clean_index(key)
index = [[0, x, 1] for x in self.shape]
for x in range(len(key)):
if isinstance(key[x], slice):
index[x] = [key[x].start, key[x].stop, key[x].step]
elif np.issubdtype(type(key[x]), np.integer):
index[x] = [key[x], key[x] + 1 if key[x] > 0 else key[x] - 1, None]
key = index
# manage range for (x,y,z), with nothing that we deal with a z,y,x order
use_range, range_x, range_y, range_z = _compute_range(key, self.shape)
origin = [range_z[0], range_y[0], range_x[0]]
region = [
range_z[1] - range_z[0],
range_y[1] - range_y[0],
range_x[1] - range_x[0],
]
region = [abs(x) for x in region]
# we are dealing with a single pixel operation
if np.prod(region) == 1:
result = self.get(origin, region)
# a specific step was provided, we are dealing with a range operation
if use_range and result is None:
from ._tier1 import range as gpu_range

result = gpu_range(
self,
start_x=range_x[0],
stop_x=range_x[1],
step_x=range_x[2],
start_y=range_y[0],
stop_y=range_y[1],
step_y=range_y[2],
start_z=range_z[0],
stop_z=range_z[1],
step_z=range_z[2],
)
else: # we are dealing with a sub-region operation
from ._memory import create

try:
# we copy sub-region inside a new buffer to return
result = create(
region, dtype=self.dtype, mtype=self.mtype, device=self.device
)
self.copy(result, origin, [0] * len(region), region)
except Exception:
# if we fail to copy, we rely on numpy to do the job
result = self.set(self.get().__getitem__(key))
# if result is an Array, and one of the dimension is equal to 1
if isinstance(result, Array) and any([x == 1 for x in result.shape]):
from ._tier1 import transpose_xy, transpose_yz

transpose = [x == 1 for x in region]
if transpose[2]: # x is empty
result = transpose_xy(result)
result = transpose_yz(result)
if transpose[1]: # y is empty
result = transpose_yz(result)

def __setitem__(self, index, value):
raise NotImplementedError("Not implemented yet.")
return result


def __setitem__(self, key, value):
if not isinstance(value, Union[Array, np.ndarray]):
value = np.array(value)
key = _clean_index(key)
print(f"value type: {type(value)}")
print(f"value dtype: {value.dtype}")
# check if the dtype of the value is a numeric type such as float, int etc
if value.dtype not in _supported_numeric_types:
raise ValueError(
"dtype "
+ str(value.dtype)
+ " not supported. Use one of "
+ str(_supported_numeric_types)
)
# define default index as slices(0, shape, 1), and iterate over keys and replace when relevant
index = [[0, x, 1] for x in self.shape]
for x in range(len(key)):
if isinstance(key[x], slice):
index[x] = [key[x].start, key[x].stop, key[x].step]
elif np.issubdtype(type(key[x]), np.integer):
index[x] = [key[x], key[x] + 1 if key[x] > 0 else key[x] - 1, None]
key = index

print(f"key: {key}")

# manage range for (x,y,z), with nothing that we deal with a z,y,x order
use_range, range_x, range_y, range_z = _compute_range(key, self.shape)
origin = [range_z[0], range_y[0], range_x[0]]
region = [
range_z[1] - range_z[0],
range_y[1] - range_y[0],
range_x[1] - range_x[0],
]
region = [abs(x) for x in region]

print(f"origin: {origin}, region: {region}")
print(
f"use_range: {use_range}, range_x: {range_x}, range_y: {range_y}, range_z: {range_z}"
)

stride_region = [
abs(region[0] / range_z[2]),
abs(region[1] / range_y[2]),
abs(region[2] / range_x[2]),
]

if value.size == 1:
value = np.repeat(value, np.prod(region))
value = value.reshape(region)
self.set(value, origin, region)
return
if value.size != np.prod(stride_region):
raise IndexError(
f"Input value mismatch the indexed region: {value.size} != {np.prod(stride_region)} ({value.shape} != {region})"
)
if use_range:
from ._tier1 import range as gpu_range

gpu_range(
value,
self,
start_x=range_x[0],
stop_x=range_x[1],
step_x=range_x[2],
start_y=range_y[0],
stop_y=range_y[1],
step_y=range_y[2],
start_z=range_z[0],
stop_z=range_z[1],
step_z=range_z[2],
)
else:
if isinstance(value, Array):
if self.dtype == value.dtype:
self.copy(value, origin, (0, 0, 0), region)
else:
# otherwise we copy with cast using paste
from ._tier1 import paste

paste(
value,
self,
index_x=origin[-1] if len(origin) > 0 else 0,
index_y=origin[-2] if len(origin) > 1 else 0,
index_z=origin[-3] if len(origin) > 2 else 0,
)
else:
self.set(value, origin, region)


# adapted from https://github.com/napari/napari/blob/d6bc683b019c4a3a3c6e936526e29bbd59cca2f4/napari/utils/notebook_display.py#L54-L73
Expand Down Expand Up @@ -373,13 +514,7 @@ def _repr_html_(self):
imshow(self, labels=labels, continue_drawing=True, colorbar=not labels)
image = self._png_to_html(self._plt_to_png())
else:
return (
"<pre>cle.array("
+ str(np.asarray(self))
+ ", dtype="
+ str(self.dtype)
+ ")</pre>"
)
return "<pre>" + repr(self) + "</pre>"

units = ["B", "kB", "MB", "GB", "TB", "PB"]
unit_index = 0
Expand Down
13 changes: 11 additions & 2 deletions pyclesperanto/_tier5.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,22 @@
from ._decorators import plugin_function


@plugin_function
def array_equal(
input_image0: Image, input_image1: Image, device: Device = None
) -> list:
from ._pyclesperanto import _array_equal as op

return op(device=device, src0=input_image0, src1=input_image1)


@plugin_function
def combine_labels(
input_image0: Image,
input_image1: Image,
output_image: Image = None,
device: Device = None,
) -> Image:
) -> list:
from ._pyclesperanto import _combine_labels as op

return op(device=device, src0=input_image0, src1=input_image1, dst=output_image)
Expand All @@ -24,7 +33,7 @@ def connected_components_labeling(
output_image: Image = None,
connectivity: str = "",
device: Device = None,
) -> Image:
) -> list:
from ._pyclesperanto import _connected_components_labeling as op

return op(
Expand Down
Loading

0 comments on commit fa2eaeb

Please sign in to comment.