Merge pull request #117 from ttngu207/dev_write_multiple_bigtiff
feat: set `gb_per_file` and write multiple bigtiff
ttngu207 authored Jul 23, 2024
2 parents e2cfc6a + dc98ebc commit ebb6adc
Showing 2 changed files with 93 additions and 64 deletions.
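In short: `write_single_bigtiff` can now split its output into multiple bigtiff files of roughly `gb_per_file` gigabytes each, and `_save_mc` accepts precomputed (or deliberately skipped) summary images. A minimal usage sketch of the new splitting behavior, assuming the `PrairieViewMeta` loader class defined in `prairie_view_loader.py` (the import path and constructor argument here are illustrative):

    from element_interface.prairie_view_loader import PrairieViewMeta  # assumed import path

    loader = PrairieViewMeta("/data/session01")  # hypothetical session directory

    # gb_per_file=None (default): a single bigtiff; one Path is returned.
    single_tiff = loader.write_single_bigtiff(plane_idx=0, channel=0, output_dir="./scratch")

    # gb_per_file set: output rolls over to a new file at ~4 GB; a list of Paths is returned.
    tiff_chunks = loader.write_single_bigtiff(
        plane_idx=0, channel=0, output_dir="./scratch", gb_per_file=4
    )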
77 changes: 44 additions & 33 deletions element_interface/caiman_loader.py
@@ -532,7 +532,12 @@ def _process_scanimage_tiff(scan_filenames, output_dir="./", split_depths=False)
imsave(save_fp.as_posix(), chn_vol)


def _save_mc(mc, caiman_fp: str, is3D: bool):
def _save_mc(
mc,
caiman_fp: str,
is3D: bool,
summary_images: dict = None,
):
"""Save motion correction to hdf5 output
Run these commands after the CaImAn analysis has completed.
@@ -545,21 +550,13 @@ def _save_mc(mc, caiman_fp: str, is3D: bool):
shifts_rig : Rigid transformation x and y shifts per frame
x_shifts_els : Non rigid transformation x shifts per frame per block
y_shifts_els : Non rigid transformation y shifts per frame per block
caiman_fp (str): CaImAn output (*.hdf5) file path
caiman_fp (str): CaImAn output (*.hdf5) file path - appended to if it exists, otherwise created anew
is3D (bool): the data is 3D
summary_images (dict): dict of summary images (average_image, max_image, correlation_image) - if None, they are computed from the motion-corrected movie; if provided as an empty dict, they are skipped
"""

# Load motion corrected mmap image
mc_image = cm.load(mc.mmap_file, is3D=is3D)

# Compute motion corrected summary images
average_image = np.mean(mc_image, axis=0)
max_image = np.max(mc_image, axis=0)

# Compute motion corrected correlation image
correlation_image = cm.local_correlations(
mc_image.transpose((1, 2, 3, 0) if is3D else (1, 2, 0))
)
correlation_image[np.isnan(correlation_image)] = 0
Yr, dims, T = cm.mmapping.load_memmap(mc.mmap_file[0])
# Load the first frame of the movie
mc_image = np.reshape(Yr[: np.product(dims), :1], [1] + list(dims), order="F")
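Note the change above: the full movie is no longer loaded here; only the first frame is materialized from the memmap to recover the movie dimensions, with the expensive summary images deferred to the `summary_images is None` branch further down. A standalone sketch of the same pattern (`np.product` above is a deprecated alias of `np.prod`, removed in NumPy 2.0):

    import caiman as cm
    import numpy as np

    # load_memmap returns Yr (pixels x frames, Fortran-ordered), the spatial dims, and frame count T
    Yr, dims, T = cm.mmapping.load_memmap(mc.mmap_file[0])
    first_frame = np.reshape(Yr[: np.prod(dims), :1], [1] + list(dims), order="F")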

# Compute mc.coord_shifts_els
grid = []
@@ -591,7 +588,8 @@ def _save_mc(mc, caiman_fp: str, is3D: bool):
)

# Open hdf5 file and create 'motion_correction' group
h5f = h5py.File(caiman_fp, "r+")
caiman_fp = pathlib.Path(caiman_fp)
h5f = h5py.File(caiman_fp, "r+" if caiman_fp.exists() else "w")
h5g = h5f.require_group("motion_correction")

# Write motion correction shifts and motion corrected summary images to hdf5 file
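The hard-coded "r+" mode above failed when the output file did not yet exist; the new mode selection appends if the file is present and creates it otherwise. A minimal standalone sketch of this append-or-create pattern (only `h5py` and `pathlib` assumed):

    import pathlib
    import h5py

    def open_motion_correction_group(caiman_fp: str):
        # Append to an existing CaImAn results file, otherwise create a new one.
        fp = pathlib.Path(caiman_fp)
        h5f = h5py.File(fp, "r+" if fp.exists() else "w")
        # require_group is idempotent: it returns the group if present, else creates it.
        return h5f, h5f.require_group("motion_correction")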
@@ -623,7 +621,7 @@ def _save_mc(mc, caiman_fp: str, is3D: bool):
# For CaImAn, reference image is still a 2D array even for the case of 3D
# Assume that the same ref image is used for all the planes
reference_image = (
np.tile(mc.total_template_els, (1, 1, correlation_image.shape[-1]))
np.tile(mc.total_template_els, (1, 1, dims[-1]))
if is3D
else mc.total_template_els
)
@@ -638,32 +636,45 @@ def _save_mc(mc, caiman_fp: str, is3D: bool):
"coord_shifts_rig", shape=np.shape(grid), data=grid, dtype=type(grid[0][0])
)
reference_image = (
np.tile(mc.total_template_rig, (1, 1, correlation_image.shape[-1]))
np.tile(mc.total_template_rig, (1, 1, dims[-1]))
if is3D
else mc.total_template_rig
)

if summary_images is None:
# Load motion corrected mmap image
mc_image = cm.load(mc.mmap_file, is3D=is3D)

# Compute motion corrected summary images
average_image = np.mean(mc_image, axis=0)
max_image = np.max(mc_image, axis=0)

# Compute motion corrected correlation image
correlation_image = cm.local_correlations(
mc_image.transpose((1, 2, 3, 0) if is3D else (1, 2, 0))
)
correlation_image[np.isnan(correlation_image)] = 0

summary_images = {
"average_image": average_image,
"max_image": max_image,
"correlation_image": correlation_image,
}

for img_type, img in summary_images.items():
h5g.require_dataset(
img_type,
shape=np.shape(img),
data=img,
dtype=img.dtype,
)

h5g.require_dataset(
"reference_image",
shape=np.shape(reference_image),
data=reference_image,
dtype=reference_image.dtype,
)
h5g.require_dataset(
"correlation_image",
shape=np.shape(correlation_image),
data=correlation_image,
dtype=correlation_image.dtype,
)
h5g.require_dataset(
"average_image",
shape=np.shape(average_image),
data=average_image,
dtype=average_image.dtype,
)
h5g.require_dataset(
"max_image", shape=np.shape(max_image), data=max_image, dtype=max_image.dtype
)

# Close hdf5 file
h5f.close()
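The new `summary_images` argument thus supports three call patterns; hedged sketches below (variable names illustrative):

    # None (default): summary images are computed inside _save_mc from the motion-corrected mmap.
    _save_mc(mc, "caiman_results.hdf5", is3D=False)

    # Empty dict: skip summary images entirely (the write loop iterates zero times).
    _save_mc(mc, "caiman_results.hdf5", is3D=False, summary_images={})

    # Precomputed dict: written as-is, avoiding a second pass over the movie.
    _save_mc(
        mc,
        "caiman_results.hdf5",
        is3D=False,
        summary_images={
            "average_image": avg_img,
            "max_image": max_img,
            "correlation_image": corr_img,
        },
    )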
80 changes: 49 additions & 31 deletions element_interface/prairie_view_loader.py
@@ -104,6 +104,7 @@ def write_single_bigtiff(
output_dir="./",
caiman_compatible=False, # if True, save the movie as a single page (frame x height x width)
overwrite=False,
gb_per_file=None,
):
logger.warning("Deprecation warning: `caiman_compatible` argument will no longer have any effect and will be removed in the future. `write_single_bigtiff` will return multi-page tiff, which is compatible with CaImAn.")

@@ -112,14 +113,17 @@ def write_single_bigtiff(
)
if output_prefix is None:
output_prefix = os.path.commonprefix(tiff_names)
output_tiff_fullpath = (
Path(output_dir)
/ f"{output_prefix}_pln{plane_idx}_chn{channel}.tif"
)
if output_tiff_fullpath.exists() and not overwrite:
return output_tiff_fullpath
output_tiff_stem = f"{output_prefix}_pln{plane_idx}_chn{channel}"

output_dir = Path(output_dir)
output_tiff_list = list(output_dir.glob(f"{output_tiff_stem}*.tif"))
if len(output_tiff_list) and not overwrite:
return output_tiff_list[0] if gb_per_file is None else output_tiff_list

output_tiff_list = []
if self.meta["is_multipage"]:
if gb_per_file is not None:
logger.warning("Ignoring `gb_per_file` argument for multi-page tiff (NotYetImplemented)")
# For multi-page tiff - the pages are organized as:
# (channel x slice x frame) - each page is (height x width)
# - TODO: verify this is the case for Bruker multi-page tiff
@@ -156,38 +160,52 @@ def write_single_bigtiff(
except Exception as e:
raise Exception(f"Error in processing tiff file {input_file}: {e}")

output_tiff_fullpath = (
output_dir
/ f"{output_tiff_stem}.tif"
)
tifffile.imwrite(
output_tiff_fullpath,
combined_data,
metadata={"axes": "TYX", "'fps'": self.meta["frame_rate"]},
bigtiff=True,
)
output_tiff_list.append(output_tiff_fullpath)
else:
with tifffile.TiffWriter(
output_tiff_fullpath,
bigtiff=True,
) as tiff_writer:
try:
for input_file in tiff_names:
with tifffile.TiffFile(
self.prairieview_dir / input_file
) as tffl:
assert len(tffl.pages) == 1
tiff_writer.write(
tffl.pages[0].asarray(),
metadata={
"axes": "YX",
"'fps'": self.meta["frame_rate"],
},
)
# additional safeguard to close the file and delete the object
# in the attempt to prevent error: `not a TIFF file b''`
tffl.close()
del tffl
except Exception as e:
raise Exception(f"Error in processing tiff file {input_file}: {e}")

return output_tiff_fullpath
while len(tiff_names):
output_tiff_fullpath = (
output_dir
/ f"{output_tiff_stem}_{len(output_tiff_list):04}.tif"
)
with tifffile.TiffWriter(
output_tiff_fullpath,
bigtiff=True,
) as tiff_writer:
while len(tiff_names):
input_file = tiff_names.pop(0)
try:
with tifffile.TiffFile(
self.prairieview_dir / input_file
) as tffl:
assert len(tffl.pages) == 1
tiff_writer.write(
tffl.pages[0].asarray(),
metadata={
"axes": "YX",
"'fps'": self.meta["frame_rate"],
},
)
# additional safeguard to close the file and delete the object
# in the attempt to prevent error: `not a TIFF file b''`
tffl.close()
del tffl
except Exception as e:
raise Exception(f"Error in processing tiff file {input_file}: {e}")
if gb_per_file and output_tiff_fullpath.stat().st_size >= gb_per_file * 1024 ** 3:
break
output_tiff_list.append(output_tiff_fullpath)

return output_tiff_list[0] if gb_per_file is None else output_tiff_list
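The core of the splitting logic is the nested loop above: the outer loop opens numbered output files, and the inner loop appends one page per source tiff until the size threshold is crossed. A self-contained sketch of the same rollover pattern (assuming `tifffile`; `frames` is a stand-in list of 2-D arrays):

    from pathlib import Path

    import tifffile

    def write_chunked_bigtiff(frames, output_dir, stem, gb_per_file=None):
        """Write frames to sequential bigtiffs, starting a new file past gb_per_file GB."""
        output_dir = Path(output_dir)
        outputs, frames = [], list(frames)
        while frames:
            fp = output_dir / f"{stem}_{len(outputs):04}.tif"
            with tifffile.TiffWriter(fp, bigtiff=True) as writer:
                while frames:
                    writer.write(frames.pop(0), metadata={"axes": "YX"})
                    # Size is checked on the still-open file, mirroring the diff above,
                    # so buffering can make the split point slightly approximate.
                    if gb_per_file and fp.stat().st_size >= gb_per_file * 1024**3:
                        break
            outputs.append(fp)
        return outputs

    # e.g. write_chunked_bigtiff(list_of_2d_arrays, ".", "demo", gb_per_file=4)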


def _extract_prairieview_metadata(xml_filepath: str):
