Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Zarr sink Improvements #1713

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 74 additions & 28 deletions sources/zarr/large_image_source_zarr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ def _readFrameValues(self, found, baseArray):
if axes_values.get(a) is not None
]
self._frameUnits = {k: axes_units.get(k) for k in self.frameAxes if k in axes_units}
self._frameValues = None
frame_values_shape = [baseArray.shape[self._axes[a]] for a in self.frameAxes]
frame_values_shape.append(len(frame_values_shape))
frame_values = np.empty(frame_values_shape, dtype=object)
Expand All @@ -387,7 +388,8 @@ def _readFrameValues(self, found, baseArray):
if name:
slicing[self._frameAxes.index(name)] = j
frame_values[tuple(slicing)] = value
self._frameValues = frame_values
if frame_values.size > 0:
self._frameValues = frame_values

def _validateZarr(self):
"""
Expand Down Expand Up @@ -645,6 +647,50 @@ def _validateNewTile(self, tile, mask, placement, axes):

return tile, mask, placement, axes

def _updateFrameValues(self, frame_values, placement, axes, new_axes, new_dims):
    """
    Store the frame values supplied with a tile in ``self._frameValues``.

    The frame values array has one dimension per frame axis followed by a
    final dimension holding one entry per frame axis.  The array is created
    on first use and grown when the frame axes or their sizes change.

    :param frame_values: a dictionary of axis name to the value supplied
        for that axis with the current tile.
    :param placement: a dictionary of axis name to the index at which the
        current tile is placed.
    :param axes: the ordered axis names of the image.
    :param new_axes: a dictionary keyed by the names of the axes that did
        not exist before the current tile.  Only the keys that are frame
        axes affect the frame values array.
    :param new_dims: a dictionary of axis name to the size of that axis
        after the current tile is added.
    """
    self._frameAxes = [
        a for a in axes
        if a in frame_values or
        (self.frameAxes is not None and a in self.frameAxes)
    ]
    frame_axes = self.frameAxes
    # ndarray.shape is a tuple; a list never compares equal to it, so build
    # a tuple to avoid re-padding (and copying) the array on every call.
    frames_shape = (*(new_dims[a] for a in frame_axes), len(frame_axes))
    if self.frameValues is None:
        self._frameValues = np.empty(frames_shape, dtype=object)
    elif self.frameValues.shape != frames_shape:
        # The dimensions of the frame values array follow frameAxes, so a
        # new axis is located by its position in frameAxes.  New axes that
        # are not frame axes have no dimension in this array.
        # NOTE(review): an axis that already existed but only now receives
        # a frame value is not in new_axes and is not handled here.
        added_positions = sorted(
            frame_axes.index(k) for k in new_axes if k in frame_axes
        )
        values = self._frameValues
        for i in added_positions:
            values = np.expand_dims(values, axis=i)
        frame_padding = [
            (0, s - values.shape[i])
            for i, s in enumerate(frames_shape)
        ]
        # the last dimension is grown by the inserts below, not by padding
        frame_padding[-1] = (0, 0)
        values = np.pad(values, frame_padding)
        for i in added_positions:
            values = np.insert(values, i, 0, axis=len(frames_shape) - 1)
        self._frameValues = values
    current_frame_slice = tuple(placement.get(a) for a in frame_axes)
    for i, k in enumerate(frame_axes):
        self.frameValues[(*current_frame_slice, i)] = frame_values.get(k)

def _resizeImage(self, arr, new_shape, new_axes, chunking):
    """
    Return an array of shape ``new_shape`` holding the data of ``arr``.

    :param arr: the existing image array.
    :param new_shape: the required shape.
    :param new_axes: a dictionary of newly added axis names to the position
        at which each axis is inserted.
    :param chunking: the chunk sizes used if a replacement array has to be
        created.
    :returns: ``arr`` itself when no axes were added (resized in place if
        its shape differed), otherwise a newly created array.
    """
    if new_shape == arr.shape:
        # nothing to do
        return arr
    if not new_axes:
        # same axes, only the extents changed
        arr.resize(*new_shape)
        return arr
    # Axes were added: insert them, pad up to the requested shape, and copy
    # the result into a freshly created array with the requested chunking.
    expanded = arr
    for position in new_axes.values():
        expanded = np.expand_dims(expanded, axis=position)
    pad_widths = [
        (0, target - expanded.shape[dim])
        for dim, target in enumerate(new_shape)
    ]
    expanded = np.pad(expanded, pad_widths)
    resized = zarr.empty(new_shape, chunks=chunking, dtype=expanded.dtype)
    resized[:] = expanded[:]
    return resized

def addTile(self, tile, x=0, y=0, mask=None, axes=None, **kwargs):
"""
Add a numpy or image tile to the image, expanding the image as needed
Expand All @@ -665,6 +711,12 @@ def addTile(self, tile, x=0, y=0, mask=None, axes=None, **kwargs):
``level`` is a reserved word and not permitted for an axis name.
"""
self._checkEditable()
try:
# read any info written by other processes
self._validateZarr()
except TileSourceError:
pass
updateMetadata = False
store_path = str(kwargs.pop('level', 0))
placement = {
'x': x,
Expand All @@ -678,9 +730,12 @@ def addTile(self, tile, x=0, y=0, mask=None, axes=None, **kwargs):
tile, mask, placement, axes = self._validateNewTile(tile, mask, placement, axes)

with self._threadLock and self._processLock:
old_axes = self._axes if hasattr(self, '_axes') else {}
self._axes = {k: i for i, k in enumerate(axes)}
new_axes = {k: i for k, i in self._axes.items() if k not in old_axes}
new_dims = {
a: max(
self._axisCounts.get(a, 0) if hasattr(self, '_axisCounts') else 0,
self._dims.get(store_path, {}).get(a, 0),
placement.get(a, 0) + tile.shape[i],
)
Expand All @@ -694,23 +749,8 @@ def addTile(self, tile, x=0, y=0, mask=None, axes=None, **kwargs):

if len(frame_values.keys()) > 0:
# update self.frameValues
self.frameAxes = [
a for a in axes
if a in frame_values or
(self.frameAxes is not None and a in self.frameAxes)
]
frames_shape = [new_dims[a] for a in self.frameAxes]
frames_shape.append(len(frames_shape))
if self.frameValues is None:
self.frameValues = np.empty(frames_shape, dtype=object)
elif self.frameValues.shape != frames_shape:
self.frameValues = np.pad(
self.frameValues,
[(0, s - self.frameValues.shape[i]) for i, s in enumerate(frames_shape)],
)
current_frame_slice = tuple(placement.get(a) for a in self.frameAxes)
for i, k in enumerate(self.frameAxes):
self.frameValues[(*current_frame_slice, i)] = frame_values.get(k)
updateMetadata = True
self._updateFrameValues(frame_values, placement, axes, new_axes, new_dims)

current_arrays = dict(self._zarr.arrays())
if store_path == '0':
Expand All @@ -728,16 +768,18 @@ def addTile(self, tile, x=0, y=0, mask=None, axes=None, **kwargs):
])
else:
arr = current_arrays[store_path]
new_shape = tuple(max(v, arr.shape[i]) for i, v in enumerate(new_dims.values()))
if new_shape != arr.shape:
arr.resize(*new_shape)
if arr.chunks[-1] != new_dims.get('s'):
# rechunk if length of samples axis changes
chunking = tuple([
self._tileSize if a in ['x', 'y'] else
new_dims.get('s') if a == 's' else 1
for a in axes
])
new_shape = tuple(
max(v, arr.shape[old_axes[k]] if k in old_axes else 0)
for k, v in new_dims.items()
)
if arr.chunks[-1] != new_dims.get('s') or len(new_axes):
# rechunk if length of samples axis changed or any new axis added
chunking = tuple([
self._tileSize if a in ['x', 'y'] else
new_dims.get('s') if a == 's' else 1
for a in axes
])
arr = self._resizeImage(arr, new_shape, new_axes, chunking)

if mask is not None:
arr[placement_slices] = np.where(mask, tile, arr[placement_slices])
Expand Down Expand Up @@ -767,6 +809,8 @@ def addTile(self, tile, x=0, y=0, mask=None, axes=None, **kwargs):
self._levels = None
self.levels = int(max(1, math.ceil(math.log(max(
self.sizeX / self.tileWidth, self.sizeY / self.tileHeight)) / math.log(2)) + 1))
if updateMetadata:
self._writeInternalMetadata()

def addAssociatedImage(self, image, imageKey=None):
"""
Expand Down Expand Up @@ -1002,6 +1046,7 @@ def frameAxes(self):
def frameAxes(self, axes):
self._checkEditable()
self._frameAxes = axes
self._writeInternalMetadata()

@property
def frameUnits(self):
Expand Down Expand Up @@ -1034,6 +1079,7 @@ def frameValues(self, a):
err = f'frameValues must have {len(self.frameAxes) + 1} dimensions.'
raise ValueError(err)
self._frameValues = a
self._writeInternalMetadata()

def _generateDownsampledLevels(self, resample_method):
self._checkEditable()
Expand Down
84 changes: 73 additions & 11 deletions test/test_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,10 +600,6 @@ def testFrameValuesSmall(use_add_tile_args, tmp_path):
))
expected_metadata = get_expected_metadata(axis_spec, frame_shape)

sink.frameAxes = list(axis_spec.keys())
sink.frameUnits = {
k: v['units'] for k, v in axis_spec.items()
}
frame_values_shape = [
*[len(v['values']) for v in axis_spec.values()],
len(axis_spec),
Expand Down Expand Up @@ -633,7 +629,11 @@ def testFrameValuesSmall(use_add_tile_args, tmp_path):
index += 1

if not use_add_tile_args:
sink.frameAxes = list(axis_spec.keys())
sink.frameValues = frame_values
sink.frameUnits = {
k: v['units'] for k, v in axis_spec.items()
}
compare_metadata(dict(sink.getMetadata()), expected_metadata)

sink.write(output_file)
Expand Down Expand Up @@ -686,10 +686,6 @@ def testFrameValues(use_add_tile_args, tmp_path):
)
expected_metadata = get_expected_metadata(axis_spec, frame_shape)

sink.frameAxes = list(axis_spec.keys())
sink.frameUnits = {
k: v['units'] for k, v in axis_spec.items()
}
frame_values_shape = [
*[len(v['values']) for v in axis_spec.values()],
len(axis_spec),
Expand Down Expand Up @@ -727,7 +723,11 @@ def testFrameValues(use_add_tile_args, tmp_path):
index += 1

if not use_add_tile_args:
sink.frameAxes = list(axis_spec.keys())
sink.frameValues = frame_values
sink.frameUnits = {
k: v['units'] for k, v in axis_spec.items()
}
compare_metadata(dict(sink.getMetadata()), expected_metadata)

sink.write(output_file)
Expand Down Expand Up @@ -783,19 +783,81 @@ def testSubprocess(tmp_path):
subprocess.run([sys.executable, '-c', """import large_image_source_zarr
import numpy as np
sink = large_image_source_zarr.open('%s')
sink.addTile(np.ones((1, 1, 1)), x=2047, y=2047, t=5, z=2)
sink.addTile(np.ones((1, 1, 1)), x=2047, y=2047, t=5, z=2, t_value='thursday', z_value=0.2)
""" % path], capture_output=True, text=True, check=True)
sink.addTile(np.ones((1, 1, 1)), x=5000, y=4095, t=0, z=4)
sink.addTile(np.ones((1, 1, 1)), x=5000, y=4095, t=0, z=4, t_value='sunday', z_value=0.4)

assert sink.metadata['IndexRange']['IndexZ'] == 5
metadata = sink.getMetadata()
assert metadata['IndexRange']['IndexZ'] == 5
assert sink.getRegion(
region=dict(left=2047, top=2047, width=1, height=1),
format='numpy',
frame=17,
)[0] == 1
assert metadata['ValueT']['values'][17] == 'thursday'
assert metadata['ValueZ']['values'][17] == 0.2
assert sink.getRegion(
region=dict(left=5000, top=4095, width=1, height=1),
format='numpy',
frame=24,
)[0] == 1
assert metadata['ValueT']['values'][24] == 'sunday'
assert metadata['ValueZ']['values'][24] == 0.4
assert sink.sizeX == 5001


@pytest.mark.parametrize('axes_order', ['tzd', 'tdz', 'dzt', 'dtz', 'ztd', 'zdt'])
def testAddAxes(tmp_path, axes_order):
    """
    Add tiles that introduce new axes one call at a time, in each possible
    axis ordering, and check both the pixel data and the frame values of
    every frame.
    """
    sink = large_image_source_zarr.new()
    # Each successive tile adds one more axis than the previous one.
    tile_specs = [
        {'t': 0, 't_value': 'sunday'},
        {
            't': 5, 't_value': 'friday',
            'z': 1, 'z_value': 0.1,
            'axes': axes_order.replace('d', '') + 'yxs',
        },
        {
            't': 6, 't_value': 'saturday',
            'z': 2, 'z_value': 0.2,
            'd': 1, 'd_value': 100,
            'axes': axes_order + 'yxs',
        },
    ]
    for spec in tile_specs:
        sink.addTile(np.ones((4, 4, 4)), x=1020, y=1020, **spec)

    metadata = sink.getMetadata()
    values_by_key = {
        't_value': metadata['ValueT']['values'],
        'z_value': metadata['ValueZ']['values'],
        'd_value': metadata['ValueD']['values'],
    }
    t_stride = metadata['IndexStride']['IndexT']
    z_stride = metadata['IndexStride']['IndexZ']
    # first and last frame are known, middle frame depends on axis ordering
    filled_frames = [0, z_stride + t_stride * 5, 41]

    for frame in metadata.get('frames', []):
        index = frame.get('Frame')
        pixel = sink.getRegion(
            region=dict(left=1020, top=1020, width=1, height=1),
            format='numpy',
            frame=index,
        )[0]
        observed = {key: values[index] for key, values in values_by_key.items()}
        if index in filled_frames:
            # a tile was written to this frame
            expected = tile_specs[filled_frames.index(index)]
            assert (pixel == 1).all()
        else:
            # untouched frames hold zeros and zero frame values
            expected = {}
            assert (pixel == 0).all()

        for key, value in observed.items():
            assert value == expected.get(key, 0)