diff --git a/element_interface/caiman_loader.py b/element_interface/caiman_loader.py index e245abd..18199c2 100644 --- a/element_interface/caiman_loader.py +++ b/element_interface/caiman_loader.py @@ -532,7 +532,12 @@ def _process_scanimage_tiff(scan_filenames, output_dir="./", split_depths=False) imsave(save_fp.as_posix(), chn_vol) -def _save_mc(mc, caiman_fp: str, is3D: bool): +def _save_mc( + mc, + caiman_fp: str, + is3D: bool, + summary_images: dict = None, +): """Save motion correction to hdf5 output Run these commands after the CaImAn analysis has completed. @@ -545,21 +550,13 @@ def _save_mc(mc, caiman_fp: str, is3D: bool): shifts_rig : Rigid transformation x and y shifts per frame x_shifts_els : Non rigid transformation x shifts per frame per block y_shifts_els : Non rigid transformation y shifts per frame per block - caiman_fp (str): CaImAn output (*.hdf5) file path + caiman_fp (str): CaImAn output (*.hdf5) file path - append if exists, else create new one + is3D (bool): the data is 3D + summary_images(dict): dict of summary images (average_image, max_image, correlation_image) - if None, will be computed, if provided as empty dict, will not be computed """ - - # Load motion corrected mmap image - mc_image = cm.load(mc.mmap_file, is3D=is3D) - - # Compute motion corrected summary images - average_image = np.mean(mc_image, axis=0) - max_image = np.max(mc_image, axis=0) - - # Compute motion corrected correlation image - correlation_image = cm.local_correlations( - mc_image.transpose((1, 2, 3, 0) if is3D else (1, 2, 0)) - ) - correlation_image[np.isnan(correlation_image)] = 0 + Yr, dims, T = cm.mmapping.load_memmap(mc.mmap_file[0]) + # Load the first frame of the movie + mc_image = np.reshape(Yr[: np.product(dims), :1], [1] + list(dims), order="F") # Compute mc.coord_shifts_els grid = [] @@ -591,7 +588,8 @@ def _save_mc(mc, caiman_fp: str, is3D: bool): ) # Open hdf5 file and create 'motion_correction' group - h5f = h5py.File(caiman_fp, "r+") + caiman_fp = pathlib.Path(caiman_fp) + h5f = h5py.File(caiman_fp, "r+" if caiman_fp.exists() else "w") h5g = h5f.require_group("motion_correction") # Write motion correction shifts and motion corrected summary images to hdf5 file @@ -623,7 +621,7 @@ def _save_mc(mc, caiman_fp: str, is3D: bool): # For CaImAn, reference image is still a 2D array even for the case of 3D # Assume that the same ref image is used for all the planes reference_image = ( - np.tile(mc.total_template_els, (1, 1, correlation_image.shape[-1])) + np.tile(mc.total_template_els, (1, 1, dims[-1])) if is3D else mc.total_template_els ) @@ -638,32 +636,45 @@ def _save_mc(mc, caiman_fp: str, is3D: bool): "coord_shifts_rig", shape=np.shape(grid), data=grid, dtype=type(grid[0][0]) ) reference_image = ( - np.tile(mc.total_template_rig, (1, 1, correlation_image.shape[-1])) + np.tile(mc.total_template_rig, (1, 1, dims[-1])) if is3D else mc.total_template_rig ) + if summary_images is None: + # Load motion corrected mmap image + mc_image = cm.load(mc.mmap_file, is3D=is3D) + + # Compute motion corrected summary images + average_image = np.mean(mc_image, axis=0) + max_image = np.max(mc_image, axis=0) + + # Compute motion corrected correlation image + correlation_image = cm.local_correlations( + mc_image.transpose((1, 2, 3, 0) if is3D else (1, 2, 0)) + ) + correlation_image[np.isnan(correlation_image)] = 0 + + summary_images = { + "average_image": average_image, + "max_image": max_image, + "correlation_image": correlation_image, + } + + for img_type, img in summary_images.items(): + h5g.require_dataset( + img_type, + shape=np.shape(img), + data=img, + dtype=img.dtype, + ) + h5g.require_dataset( "reference_image", shape=np.shape(reference_image), data=reference_image, dtype=reference_image.dtype, ) - h5g.require_dataset( - "correlation_image", - shape=np.shape(correlation_image), - data=correlation_image, - dtype=correlation_image.dtype, - ) - h5g.require_dataset( - "average_image", - shape=np.shape(average_image), - data=average_image, - dtype=average_image.dtype, - ) - h5g.require_dataset( - "max_image", shape=np.shape(max_image), data=max_image, dtype=max_image.dtype - ) # Close hdf5 file h5f.close() diff --git a/element_interface/prairie_view_loader.py b/element_interface/prairie_view_loader.py index 6f701ae..83d56b2 100644 --- a/element_interface/prairie_view_loader.py +++ b/element_interface/prairie_view_loader.py @@ -104,6 +104,7 @@ def write_single_bigtiff( output_dir="./", caiman_compatible=False, # if True, save the movie as a single page (frame x height x width) overwrite=False, + gb_per_file=None, ): logger.warning("Deprecation warning: `caiman_compatible` argument will no longer have any effect and will be removed in the future. `write_single_bigtiff` will return multi-page tiff, which is compatible with CaImAn.") @@ -112,13 +113,14 @@ def write_single_bigtiff( ) if output_prefix is None: output_prefix = os.path.commonprefix(tiff_names) - output_tiff_fullpath = ( - Path(output_dir) - / f"{output_prefix}_pln{plane_idx}_chn{channel}.tif" - ) - if output_tiff_fullpath.exists() and not overwrite: - return output_tiff_fullpath + output_tiff_stem = f"{output_prefix}_pln{plane_idx}_chn{channel}" + + output_dir = Path(output_dir) + output_tiff_list = list(output_dir.glob(f"{output_tiff_stem}*.tif")) + if len(output_tiff_list) and not overwrite: + return output_tiff_list[0] if gb_per_file is None else output_tiff_list + output_tiff_list = [] if self.meta["is_multipage"]: # For multi-page tiff - the pages are organized as: # (channel x slice x frame) - each page is (height x width) @@ -156,38 +158,52 @@ def write_single_bigtiff( except Exception as e: raise Exception(f"Error in processing tiff file {input_file}: {e}") + output_tiff_fullpath = ( + output_dir + / f"{output_tiff_stem}.tif" + ) tifffile.imwrite( output_tiff_fullpath, combined_data, metadata={"axes": "TYX", "'fps'": self.meta["frame_rate"]}, bigtiff=True, ) + output_tiff_list.append(output_tiff_fullpath) else: - with tifffile.TiffWriter( - output_tiff_fullpath, - bigtiff=True, - ) as tiff_writer: - try: - for input_file in tiff_names: - with tifffile.TiffFile( - self.prairieview_dir / input_file - ) as tffl: - assert len(tffl.pages) == 1 - tiff_writer.write( - tffl.pages[0].asarray(), - metadata={ - "axes": "YX", - "'fps'": self.meta["frame_rate"], - }, - ) - # additional safeguard to close the file and delete the object - # in the attempt to prevent error: `not a TIFF file b''` - tffl.close() - del tffl - except Exception as e: - raise Exception(f"Error in processing tiff file {input_file}: {e}") - - return output_tiff_fullpath + while len(tiff_names): + output_tiff_fullpath = ( + output_dir + / f"{output_tiff_stem}_{len(output_tiff_list):04}.tif" + ) + with tifffile.TiffWriter( + output_tiff_fullpath, + bigtiff=True, + ) as tiff_writer: + while len(tiff_names): + input_file = tiff_names.pop(0) + try: + with tifffile.TiffFile( + self.prairieview_dir / input_file + ) as tffl: + assert len(tffl.pages) == 1 + tiff_writer.write( + tffl.pages[0].asarray(), + metadata={ + "axes": "YX", + "'fps'": self.meta["frame_rate"], + }, + ) + # additional safeguard to close the file and delete the object + # in the attempt to prevent error: `not a TIFF file b''` + tffl.close() + del tffl + except Exception as e: + raise Exception(f"Error in processing tiff file {input_file}: {e}") + if gb_per_file and output_tiff_fullpath.stat().st_size >= gb_per_file * 1024 ** 3: + break + output_tiff_list.append(output_tiff_fullpath) + + return output_tiff_list[0] if gb_per_file is None else output_tiff_list def _extract_prairieview_metadata(xml_filepath: str):