Skip to content

Commit

Permalink
Always check for int and np.integer (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 authored Sep 24, 2020
1 parent 3585536 commit 127a885
Showing 1 changed file with 52 additions and 39 deletions.
91 changes: 52 additions & 39 deletions lazy_ops/lazy_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@
from abc import ABCMeta, abstractmethod
from typing import Union
import h5py

installed_dataset_types = h5py.Dataset


class DatasetView(metaclass=ABCMeta):

def __new__(cls, dataset: installed_dataset_types = None, slice_index=(np.index_exp[:],()), axis_order=None):
def __new__(cls, dataset: installed_dataset_types = None, slice_index=(np.index_exp[:], ()), axis_order=None):
"""
Args:
dataset: the underlying dataset
Expand All @@ -42,18 +44,18 @@ def __new__(cls, dataset: installed_dataset_types = None, slice_index=(np.index_
lazy object
"""
if cls == DatasetView:
if isinstance(dataset,h5py.Dataset):
if isinstance(dataset, h5py.Dataset):
return DatasetViewh5py(dataset=dataset)
elif HAVE_ZARR:
if isinstance(dataset,zarr.core.Array):
if isinstance(dataset, zarr.core.Array):
return DatasetViewzarr(dataset=dataset)
elif str(z1).find("zarr") != -1:
raise TypeError("To use DatasetView with a zarr array install zarr: \n pip install zarr\n")
raise TypeError("DatasetView requires either an h5py dataset or a zarr array as first argument")
else:
return super().__new__(cls)

def __init__(self, dataset: installed_dataset_types = None, slice_index=(np.index_exp[:],()), axis_order=None):
def __init__(self, dataset: installed_dataset_types = None, slice_index=(np.index_exp[:], ()), axis_order=None):
"""
Args:
dataset: the underlying dataset
Expand Down Expand Up @@ -105,7 +107,7 @@ def _slice_tuple(self, key):
Returns:
The slice object tuple
"""
if isinstance(key, (slice,int,np.ndarray)):
if isinstance(key, (slice, int, np.integer, np.ndarray)):
key = key,
else:
key = *key,
Expand All @@ -131,33 +133,35 @@ def _slice_shape(self, slice_):
int_ind = slice_[1]
slice_ = self._slice_tuple(slice_[0])
# converting the slice to regular slices that only contain integers
slice_regindices = [slice(*slice_[i].indices(self.dataset.shape[self.axis_order[i]])) if isinstance(slice_[i],slice)
else slice_[i]
for i in range(len(slice_))]
slice_regindices = [
slice(*slice_[i].indices(self.dataset.shape[self.axis_order[i]])) if isinstance(slice_[i], slice)
else slice_[i]
for i in range(len(slice_))]

slice_shape = ()
int_index = ()
axis_order = ()
for i in range(len(slice_)):
if isinstance(slice_[i],slice):
slice_start, slice_stop, slice_step = slice_regindices[i].start, slice_regindices[i].stop, slice_regindices[i].step
if isinstance(slice_[i], slice):
slice_start, slice_stop, slice_step = slice_regindices[i].start, slice_regindices[i].stop, \
slice_regindices[i].step
if slice_step < 1:
raise ValueError("Slice step parameter must be positive")
if slice_stop < slice_start:
slice_start = slice_stop
slice_regindices[i] = slice(slice_start, slice_stop, slice_step)
slice_shape += (1 + (slice_stop - slice_start -1 )//slice_step if slice_stop != slice_start else 0,)
slice_shape += (1 + (slice_stop - slice_start - 1) // slice_step if slice_stop != slice_start else 0,)
axis_order += (self.axis_order[i],)
elif isinstance(slice_[i],int):
int_index += ((i,slice_[i],self.axis_order[i]),)
elif isinstance(slice_[i], (int, np.integer)):
int_index += ((i, slice_[i], self.axis_order[i]),)
else:
# slice_[i] is an iterator of integers
slice_shape += (len(slice_[i]),)
axis_order += (self.axis_order[i],)
slice_regindices = tuple(el for el in slice_regindices if not isinstance(el,int))
axis_order += tuple(self.axis_order[len(axis_order)+len(int_index)::])
slice_regindices = tuple(el for el in slice_regindices if not isinstance(el, (int, np.integer)))
axis_order += tuple(self.axis_order[len(axis_order) + len(int_index)::])
int_index += int_ind
slice_shape += self.dataset.shape[len(slice_shape)+len(int_index)::]
slice_shape += self.dataset.shape[len(slice_shape) + len(int_index)::]

return slice_shape, slice_regindices, int_index, axis_order

Expand All @@ -179,7 +183,7 @@ def lazy_iter(self, axis=0):
Modifications to the items are not stored
"""
for i in range(self._shape[axis]):
yield self.lazy_slice[(*np.index_exp[:]*axis,i)]
yield self.lazy_slice[(*np.index_exp[:] * axis, i)]

def __call__(self, new_slice):
""" allows lazy_slice function calls with slice objects as input"""
Expand Down Expand Up @@ -220,7 +224,7 @@ def _slice_composition(self, new_slice):
slice_result = ()
# Iterating over the new slicing tuple to change the merged dataset slice.
for i in range(len(new_slice)):
if isinstance(new_slice[i],slice):
if isinstance(new_slice[i], slice):
if i < len(self.key):
# converting new_slice slice to regular slices,
# newkey_start, newkey_stop, newkey_step only contains positive or zero integers
Expand All @@ -230,21 +234,22 @@ def _slice_composition(self, new_slice):
raise ValueError("Slice step parameter must be positive")
if newkey_stop < newkey_start:
newkey_start = newkey_stop
if isinstance(self.key[i],slice):
slice_result += (slice(min(self.key[i].start + self.key[i].step * newkey_start, self.key[i].stop),
min(self.key[i].start + self.key[i].step * newkey_stop, self.key[i].stop),
newkey_step * self.key[i].step),)
if isinstance(self.key[i], slice):
slice_result += (
slice(min(self.key[i].start + self.key[i].step * newkey_start, self.key[i].stop),
min(self.key[i].start + self.key[i].step * newkey_stop, self.key[i].stop),
newkey_step * self.key[i].step),)
else:
# self.key[i] is an iterator of integers
slice_result += (self.key[i][new_slice[i]],)
else:
slice_result += (slice(*new_slice[i].indices(self.dataset.shape[self.axis_order[i]])),)
elif isinstance(new_slice[i],int):
elif isinstance(new_slice[i], (int, np.integer)):
if i < len(self.key):
if new_slice[i] >= self._shape[i] or new_slice[i] <= ~self._shape[i]:
raise IndexError("Index %d out of range, dim %d of size %d" % (new_slice[i],i,self._shape[i]))
if isinstance(self.key[i],slice):
int_index = self.key[i].start + self.key[i].step*(new_slice[i]%self._shape[i])
raise IndexError("Index %d out of range, dim %d of size %d" % (new_slice[i], i, self._shape[i]))
if isinstance(self.key[i], slice):
int_index = self.key[i].start + self.key[i].step * (new_slice[i] % self._shape[i])
slice_result += (int_index,)
else:
# self.key[i] is an iterator of integers
Expand All @@ -253,28 +258,32 @@ def _slice_composition(self, new_slice):
slice_result += (new_slice[i],)
else:
try:
if not all(isinstance(el,int) for el in new_slice[i]):
if not all([isinstance(el, (int, np.integer)) for el in new_slice[i]]):
if new_slice[i].dtype.kind != 'b':
raise ValueError("Indices must be either integers or booleans")
else:
# boolean indexing
if len(new_slice[i]) != self.shape[i]:
raise IndexError("Length of boolean index $d must be equal to size %d in dim %d" % (len(new_slice[i]),self.shape[i],i))
raise IndexError("Length of boolean index $d must be equal to size %d in dim %d" % (
len(new_slice[i]), self.shape[i], i))
new_slice_i = new_slice[i].nonzero()[0]
else:
new_slice_i = new_slice[i]
if i < len(self.key):
if any(el >= self._shape[i] or el <= ~self._shape[i] for el in new_slice_i):
raise IndexError("Index %s out of range, dim %d of size %d" % (str(new_slice_i),i,self._shape[i]))
if isinstance(self.key[i],slice):
slice_result += (tuple(self.key[i].start + self.key[i].step*(ind%self._shape[i]) for ind in new_slice_i),)
raise IndexError(
"Index %s out of range, dim %d of size %d" % (str(new_slice_i), i, self._shape[i]))
if isinstance(self.key[i], slice):
slice_result += (tuple(
self.key[i].start + self.key[i].step * (ind % self._shape[i]) for ind in new_slice_i),)
else:
# self.key[i] is an iterator of integers
slice_result += (tuple(self.key[i][ind] for ind in new_slice_i),)
else:
slice_result += (new_slice_i,)
except:
raise IndexError("Indices must be either integers, iterators of integers, slice objects, or numpy boolean arrays")
raise IndexError(
"Indices must be either integers, iterators of integers, slice objects, or numpy boolean arrays")
slice_result += self.key[len(new_slice):]

return slice_result
Expand Down Expand Up @@ -314,14 +323,15 @@ def _ellipsis_slices(self, new_slice):
Returns:
equivalent slices with Ellipsis expanded
"""
ellipsis_count = sum(s==Ellipsis for s in new_slice if not isinstance(s,np.ndarray))
ellipsis_count = sum(s == Ellipsis for s in new_slice if not isinstance(s, np.ndarray))
if ellipsis_count == 1:
ellipsis_index = new_slice.index(Ellipsis)
if ellipsis_index == len(new_slice)-1:
if ellipsis_index == len(new_slice) - 1:
new_slice = new_slice[:-1]
else:
num_ellipsis_dims = len(self.shape) - (len(new_slice) - 1)
new_slice = new_slice[:ellipsis_index] + np.index_exp[:]*num_ellipsis_dims + new_slice[ellipsis_index+1:]
new_slice = new_slice[:ellipsis_index] + np.index_exp[:] * num_ellipsis_dims + new_slice[
ellipsis_index + 1:]
elif ellipsis_count > 0:
raise IndexError("Only a single Ellipsis is allowed")
return new_slice
Expand Down Expand Up @@ -365,6 +375,7 @@ def read_direct(self, dest, source_sel=None, dest_sel=None):
self.dataset.read_direct(reversed_dest, source_sel=reversed_slice_key, dest_sel=reversed_dest_sel)
np.copyto(dest, reversed_dest.transpose(axis_order_read))


def lazy_transpose(dset: installed_dataset_types, axes=None):
""" Array lazy transposition, not passing axis argument reverses the order of dimensions
Args:
Expand All @@ -378,19 +389,21 @@ def lazy_transpose(dset: installed_dataset_types, axes=None):

return DatasetView(dset).lazy_transpose(axis_order=axes)

class DatasetViewh5py(DatasetView, h5py.Dataset):

def __new__(cls,dataset):
class DatasetViewh5py(DatasetView, h5py.Dataset):

def __new__(cls, dataset):
_self = super().__new__(cls)
h5py.Dataset.__init__(_self, dataset.id)
return _self


try:
import zarr
from .lazy_loading_zarr import DatasetViewzarr
installed_dataset_types = Union[installed_dataset_types,zarr.core.Array]

installed_dataset_types = Union[installed_dataset_types, zarr.core.Array]
HAVE_ZARR = True
except ImportError:
HAVE_ZARR = False
DatasetViewzarr = None
DatasetViewzarr = None

0 comments on commit 127a885

Please sign in to comment.