diff --git a/dask-stitch/Snakefile b/dask-stitch/Snakefile
index a7651d4..92a9f0f 100644
--- a/dask-stitch/Snakefile
+++ b/dask-stitch/Snakefile
@@ -9,7 +9,7 @@ wildcard_constraints:
 
 rule all:
     input:
-        nifti=expand('results/tile-{tile}_optimized_SPIM.nii', tile=range(gridx*gridy))
+        'results/fused_SPIM.nii'
 
 rule create_test_dataset_single_ome_zarr:
     params:
@@ -34,7 +34,8 @@ rule compute_pairwise_correlation:
         ome_zarr=expand('results/tile-{tile}_SPIM.ome.zarr', tile=range(gridx*gridy)),
        pairs='results/overlapping_pairs.txt'
     output:
-        offsets='results/pairwise_offsets.txt'
+        offsets='results/pairwise_offsets.txt',
+        diagnostics_dir=directory('results/pairwise_offsets_work')
     script: 'scripts/compute_pairwise_correlation.py'
 
 rule global_optimization:
@@ -56,5 +57,14 @@ rule assign_translations:
         nifti=expand('results/tile-{tile}_optimized_SPIM.nii', tile=range(gridx*gridy))
     script: "scripts/assign_translation.py"
-
+
+rule fuse_volume:
+    input:
+        ome_zarr=expand('results/tile-{tile}_SPIM.ome.zarr', tile=range(gridx*gridy)),
+        optimized_translations='results/optimized_translations.txt'
+    output:
+        ome_zarr = directory('results/fused_SPIM.ome.zarr'),
+        nifti = 'results/fused_SPIM.nii'
+    script:
+        'scripts/fuse_volume.py'
diff --git a/dask-stitch/scripts/compute_pairwise_correlation.py b/dask-stitch/scripts/compute_pairwise_correlation.py
index 28c819d..4449110 100644
--- a/dask-stitch/scripts/compute_pairwise_correlation.py
+++ b/dask-stitch/scripts/compute_pairwise_correlation.py
@@ -2,15 +2,18 @@
 from zarrnii import ZarrNii
 from scipy.fft import fftn, ifftn
 from scipy.ndimage import map_coordinates
+import os
 
-def phase_correlation(img1, img2):
+def phase_correlation(img1, img2, diagnostics_dir=None, pair_index=None):
     """
     Compute the phase correlation between two 3D images to find the translation offset.
 
     Parameters:
     - img1 (np.ndarray): First image (3D array).
     - img2 (np.ndarray): Second image (3D array).
+    - diagnostics_dir (str): Directory to save diagnostic outputs (optional).
+    - pair_index (int): Index of the pair being processed (for naming diagnostics).
 
     Returns:
     - np.ndarray: Offset vector [z_offset, y_offset, x_offset].
@@ -26,6 +29,10 @@ def phase_correlation(img1, img2):
     # Inverse Fourier transform to get correlation map
     correlation = np.abs(ifftn(cross_power))
 
+    # Save the correlation map for diagnostics
+    if diagnostics_dir and pair_index is not None:
+        np.save(os.path.join(diagnostics_dir, f"correlation_map_pair_{pair_index}.npy"), correlation)
+
     # Find the peak in the correlation map
     peak = np.unravel_index(np.argmax(correlation), correlation.shape)
 
@@ -69,7 +76,40 @@ def resample_to_bounding_box(img, affine, bbox_min, bbox_max, output_shape):
     return map_coordinates(img, [voxel_coords[..., dim] for dim in range(3)], order=1, mode="constant")
 
 
-def compute_pairwise_correlation(ome_zarr_paths, overlapping_pairs, output_shape=(64, 64, 64)):
+def compute_corrected_bounding_box(affine, img_shape):
+    """
+    Compute the corrected bounding box in physical space, accounting for negative physical dimensions.
+
+    Parameters:
+    - affine (np.ndarray): 4x4 affine matrix.
+    - img_shape (tuple): Shape of the image (Z, Y, X).
+
+    Returns:
+    - bbox_min (np.ndarray): Minimum physical coordinates.
+    - bbox_max (np.ndarray): Maximum physical coordinates.
+ """ + corners = [ + np.array([0, 0, 0, 1]), + np.array([img_shape[2], 0, 0, 1]), + np.array([0, img_shape[1], 0, 1]), + np.array([img_shape[2], img_shape[1], 0, 1]), + np.array([0, 0, img_shape[0], 1]), + np.array([img_shape[2], 0, img_shape[0], 1]), + np.array([0, img_shape[1], img_shape[0], 1]), + np.array([img_shape[2], img_shape[1], img_shape[0], 1]), + ] + + # Transform voxel corners to physical space + corners_physical = np.dot(affine, np.array(corners).T).T[:, :3] + + # Find corrected min and max + bbox_min = np.min(corners_physical, axis=0) + bbox_max = np.max(corners_physical, axis=0) + + return bbox_min, bbox_max + + +def compute_pairwise_correlation(ome_zarr_paths, overlapping_pairs, output_shape=(64, 64, 64), diagnostics_dir=None): """ Compute the optimal offset for each pair of overlapping tiles. @@ -77,13 +117,17 @@ def compute_pairwise_correlation(ome_zarr_paths, overlapping_pairs, output_shape - ome_zarr_paths (list of str): List of paths to OME-Zarr datasets. - overlapping_pairs (list of tuples): List of overlapping tile indices. - output_shape (tuple): Shape of the resampled bounding box. + - diagnostics_dir (str): Directory to save diagnostic outputs (optional). Returns: - np.ndarray: Array of offsets for each pair (N, 3) where N is the number of pairs. """ offsets = [] - for i, j in overlapping_pairs: + if diagnostics_dir: + os.makedirs(diagnostics_dir, exist_ok=True) + + for pair_index, (i, j) in enumerate(overlapping_pairs): # Load the two images and their affines znimg1 = ZarrNii.from_path(ome_zarr_paths[i]) znimg2 = ZarrNii.from_path(ome_zarr_paths[j]) @@ -93,26 +137,46 @@ def compute_pairwise_correlation(ome_zarr_paths, overlapping_pairs, output_shape affine1 = znimg1.vox2ras.affine affine2 = znimg2.vox2ras.affine - #HACK FIX affine1[:3,3] = -1 * np.flip(affine1[:3,3]) affine2[:3,3] = -1 * np.flip(affine2[:3,3]) - # Compute the overlapping bounding box in physical space - bbox1_min = affine1[:3, 3] - bbox1_max = bbox1_min + np.dot(affine1[:3, :3], img1.shape[::-1]) - bbox2_min = affine2[:3, 3] - bbox2_max = bbox2_min + np.dot(affine2[:3, :3], img2.shape[::-1]) + + + # Compute the corrected bounding boxes + bbox1_min, bbox1_max = compute_corrected_bounding_box(affine1, img1.shape) + bbox2_min, bbox2_max = compute_corrected_bounding_box(affine2, img2.shape) + + + + print(f'bbox1, from tile: {i}') + print(bbox1_min) + print(bbox1_max) + print(f'bbox2, from tile: {j}') + print(bbox2_min) + print(bbox2_max) bbox_min = np.maximum(bbox1_min, bbox2_min) bbox_max = np.minimum(bbox1_max, bbox2_max) + print('overlapping bbox') + print(bbox_min) + print(bbox_max) + # Save bounding box for diagnostics + if diagnostics_dir: + np.save(os.path.join(diagnostics_dir, f"bounding_box_pair_{pair_index}.npy"), np.array([bbox_min, bbox_max])) + # Resample images to the overlapping bounding box resampled_img1 = resample_to_bounding_box(img1, affine1, bbox_min, bbox_max, output_shape) resampled_img2 = resample_to_bounding_box(img2, affine2, bbox_min, bbox_max, output_shape) + # Save resampled images for diagnostics + if diagnostics_dir: + np.save(os.path.join(diagnostics_dir, f"resampled_img1_pair_{pair_index}.npy"), resampled_img1) + np.save(os.path.join(diagnostics_dir, f"resampled_img2_pair_{pair_index}.npy"), resampled_img2) + # Compute phase correlation on the resampled overlapping region - offset = phase_correlation(resampled_img1, resampled_img2) + offset = phase_correlation(resampled_img1, resampled_img2, diagnostics_dir=diagnostics_dir, pair_index=pair_index) 
         offsets.append(offset)
 
     return np.array(offsets)
 
@@ -121,9 +185,10 @@
 # Example usage
 overlapping_pairs = np.loadtxt(snakemake.input.pairs, dtype=int).tolist()  # Overlapping pairs
 ome_zarr_paths = snakemake.input.ome_zarr  # List of OME-Zarr paths
+diagnostics_dir = snakemake.output.diagnostics_dir  # Directory for diagnostics
 
 # Compute pairwise offsets
-offsets = compute_pairwise_correlation(ome_zarr_paths, overlapping_pairs)
+offsets = compute_pairwise_correlation(ome_zarr_paths, overlapping_pairs, diagnostics_dir=diagnostics_dir)
 
 # Save results
 np.savetxt(snakemake.output.offsets, offsets, fmt="%.6f")
diff --git a/dask-stitch/scripts/create_test_dataset_singletile.py b/dask-stitch/scripts/create_test_dataset_singletile.py
index 6db72ac..786bf8d 100644
--- a/dask-stitch/scripts/create_test_dataset_singletile.py
+++ b/dask-stitch/scripts/create_test_dataset_singletile.py
@@ -69,7 +69,7 @@ def create_test_dataset_single(tile_index, template="MNI152NLin2009cAsym", res=2
 
     # TODO: Simulate error by applying a transformation to the image before
     # initially lets just do a random jitter:
-    offset = np.random.uniform(-5, 5, size=(grid_shape[0],grid_shape[1],3)) # Random 3D offsets for each tile
+    offset = np.random.uniform(0, 0, size=(grid_shape[0],grid_shape[1],3)) # Zero offsets for each tile (jitter disabled for now)
 
     xfm_img_data = affine_transform(img_data,matrix=np.eye(3,3),offset=offset[x,y,:],order=1)
 
diff --git a/dask-stitch/scripts/fuse_volume.py b/dask-stitch/scripts/fuse_volume.py
new file mode 100644
index 0000000..9b5fcde
--- /dev/null
+++ b/dask-stitch/scripts/fuse_volume.py
@@ -0,0 +1,288 @@
+import numpy as np
+import dask.array as da
+from zarrnii import ZarrNii
+from scipy.ndimage import map_coordinates
+import os
+
+
+def compute_fused_volume_shape(ome_zarr_paths, optimized_translations):
+    """
+    Compute the shape of the fused volume in physical space.
+
+    Parameters:
+    - ome_zarr_paths (list of str): List of paths to OME-Zarr datasets.
+    - optimized_translations (np.ndarray): Optimized translations for each tile.
+
+    Returns:
+    - bbox_min (np.ndarray): Minimum coordinates of the fused volume.
+    - bbox_max (np.ndarray): Maximum coordinates of the fused volume.
+    - voxel_size (np.ndarray): Voxel size in physical units.
+ """ + bbox_min = np.inf * np.ones(3) + bbox_max = -np.inf * np.ones(3) + + for path, translation in zip(ome_zarr_paths, optimized_translations): + znimg = ZarrNii.from_path(path) + affine = znimg.vox2ras.affine + + #HACK fix: + affine[:3,3] = -1 * np.flip(affine[:3,3]) + + img_shape = znimg.darr.shape[1:] # Exclude the channel/time dimension + + # Compute corrected bounding box with translation + corners = [ + np.array([0, 0, 0, 1]), + np.array([img_shape[2], 0, 0, 1]), + np.array([0, img_shape[1], 0, 1]), + np.array([img_shape[2], img_shape[1], 0, 1]), + np.array([0, 0, img_shape[0], 1]), + np.array([img_shape[2], 0, img_shape[0], 1]), + np.array([0, img_shape[1], img_shape[0], 1]), + np.array([img_shape[2], img_shape[1], img_shape[0], 1]), + ] + corners_physical = np.dot(affine, np.array(corners).T).T[:, :3] + translation + + # Update bounding box + bbox_min = np.minimum(bbox_min, corners_physical.min(axis=0)) + bbox_max = np.maximum(bbox_max, corners_physical.max(axis=0)) + + voxel_size = np.array((-affine[0,2],-affine[1,1],-affine[2,0])) + return bbox_min, bbox_max, voxel_size + + +def resample_tile_to_chunk(tile, affine, translation, chunk_bbox_min, chunk_bbox_max, chunk_shape): + """ + Resample a tile to the chunk's bounding box in physical space. + + Parameters: + - tile (np.ndarray): Input tile. + - affine (np.ndarray): Affine transformation matrix (4x4). + - translation (np.ndarray): Optimized translation for the tile. + - chunk_bbox_min (np.ndarray): Minimum physical coordinates of the chunk. + - chunk_bbox_max (np.ndarray): Maximum physical coordinates of the chunk. + - chunk_shape (tuple): Shape of the output chunk. + + Returns: + - np.ndarray: Resampled tile for the chunk. + """ + coords = np.meshgrid( + np.linspace(chunk_bbox_min[0], chunk_bbox_max[0], chunk_shape[0]), + np.linspace(chunk_bbox_min[1], chunk_bbox_max[1], chunk_shape[1]), + np.linspace(chunk_bbox_min[2], chunk_bbox_max[2], chunk_shape[2]), + indexing="ij", + ) + coords = np.stack(coords, axis=-1) # Shape: (Z, Y, X, 3) + + # Convert physical coordinates to voxel indices + inverse_affine = np.linalg.inv(affine) + voxel_coords = np.dot(coords.reshape(-1, 3), inverse_affine[:3, :3].T) + inverse_affine[:3, 3] - translation + voxel_coords = voxel_coords.reshape(chunk_shape + (3,)) + + # Interpolate tile values at the transformed voxel coordinates + return map_coordinates(tile, [voxel_coords[..., dim] for dim in range(3)], order=1, mode="constant") + + +def process_chunk(block_info, ome_zarr_paths, optimized_translations, bbox_min, voxel_size, chunk_shape): + """ + Process a single chunk of the fused volume. + + Parameters: + - block_info (dict): Information about the block being processed. + - ome_zarr_paths (list of str): List of paths to OME-Zarr datasets. + - optimized_translations (np.ndarray): Optimized translations for each tile. + - bbox_min (np.ndarray): Minimum physical coordinates of the fused volume. + - voxel_size (np.ndarray): Voxel size in physical units. + - chunk_shape (tuple): Shape of the chunk being processed. + + Returns: + - np.ndarray: Fused chunk. 
+ """ + # Determine chunk bounding box in physical space + chunk_start = np.array(block_info[0]["array-location"][0]) * voxel_size + bbox_min + chunk_end = chunk_start + np.array(chunk_shape) * voxel_size + + chunk = np.zeros(chunk_shape, dtype=np.float32) + weight = np.zeros(chunk_shape, dtype=np.float32) + + for path, translation in zip(ome_zarr_paths, optimized_translations): + znimg = ZarrNii.from_path(path) + tile = znimg.darr.squeeze().compute() + affine = znimg.vox2ras.affine + affine[:3,3] = -1 * np.flip(affine[:3,3]) + + # Resample tile to chunk + resampled_tile = resample_tile_to_chunk(tile, affine, translation, chunk_start, chunk_end, chunk_shape) + + # Fuse by summing intensities and weights + mask = resampled_tile > 0 + chunk[mask] += resampled_tile[mask] + weight[mask] += 1 + + # Avoid division by zero + fused_chunk = np.divide(chunk, weight, out=np.zeros_like(chunk), where=weight > 0) + return fused_chunk + + +def fuse_volume(ome_zarr_paths, optimized_translations, fused_shape, chunk_shape, bbox_min, voxel_size): + """ + Fuse all tiles into a single volume. + + Parameters: + - ome_zarr_paths (list of str): List of paths to OME-Zarr datasets. + - optimized_translations (np.ndarray): Optimized translations for each tile. + - fused_shape (tuple): Shape of the final fused volume. + - chunk_shape (tuple): Shape of each chunk. + - bbox_min (np.ndarray): Minimum coordinates of the fused volume. + - voxel_size (np.ndarray): Voxel size in physical units. + + Returns: + - dask.array: Fused volume. + """ + # Wrap process_chunk for Dask + def wrapped_process_chunk(block, block_info=None): + return process_chunk( + block_info, + ome_zarr_paths, + optimized_translations, + bbox_min, + voxel_size, + chunk_shape, + ) + + # Define Dask array for the fused volume + fused_volume = da.map_blocks( + wrapped_process_chunk, + chunks=chunk_shape, + dtype=np.float32, + shape=fused_shape, + ) + + return fused_volume + + +# Example usage +ome_zarr_paths = snakemake.input.ome_zarr # List of input OME-Zarr paths +optimized_translations = np.loadtxt(snakemake.input.optimized_translations, dtype=float) # Optimized translations + +# Compute the fused volume shape +bbox_min, bbox_max, voxel_size = compute_fused_volume_shape(ome_zarr_paths, optimized_translations) +assert all(voxel_size > 0), "Voxel size must be greater than zero." + +fused_shape = tuple(np.ceil((bbox_max - bbox_min) / voxel_size).astype(int)) +chunk_shape = (64, 64, 64) # Example chunk size + +# Fuse the volume +fused_volume = fuse_volume(ome_zarr_paths, optimized_translations, fused_shape, chunk_shape, bbox_min, voxel_size) + +# Save the fused volume +fused_volume.to_zarr(snakemake.output.fused_volume) + + + + +def process_chunk(chunk_data, block_info, ome_zarr_paths, optimized_translations, bbox_min, voxel_size, chunk_shape): + """ + Process a single chunk of the fused volume. + + Parameters: + - chunk_data (np.ndarray): Placeholder data for the chunk. + - block_info (dict): Information about the block being processed. + - ome_zarr_paths (list of str): List of paths to OME-Zarr datasets. + - optimized_translations (np.ndarray): Optimized translations for each tile. + - bbox_min (np.ndarray): Minimum physical coordinates of the fused volume. + - voxel_size (np.ndarray): Voxel size in physical units. + - chunk_shape (tuple): Shape of the chunk being processed. + + Returns: + - np.ndarray: Fused chunk. 
+ """ + # Extract chunk location and physical bounding box + chunk_start = np.array(block_info[0]["chunk-location"]) * voxel_size + bbox_min + chunk_end = chunk_start + np.array(chunk_shape) * voxel_size + + chunk = np.zeros(chunk_shape, dtype=np.float32) + weight = np.zeros(chunk_shape, dtype=np.float32) + + for path, translation in zip(ome_zarr_paths, optimized_translations): + znimg = ZarrNii.from_path(path) + tile = znimg.darr.squeeze().compute() + affine = znimg.vox2ras.affine + + # Resample tile to chunk + resampled_tile = resample_tile_to_chunk(tile, affine, translation, chunk_start, chunk_end, chunk_shape) + + # Fuse by summing intensities and weights + mask = resampled_tile > 0 + chunk[mask] += resampled_tile[mask] + weight[mask] += 1 + + # Avoid division by zero + fused_chunk = np.divide(chunk, weight, out=np.zeros_like(chunk), where=weight > 0) + return fused_chunk + + +def fuse_volume(ome_zarr_paths, optimized_translations, fused_shape, chunk_shape, bbox_min, voxel_size): + """ + Fuse all tiles into a single volume. + + Parameters: + - ome_zarr_paths (list of str): List of paths to OME-Zarr datasets. + - optimized_translations (np.ndarray): Optimized translations for each tile. + - fused_shape (tuple): Shape of the final fused volume. + - chunk_shape (tuple): Shape of each chunk. + - bbox_min (np.ndarray): Minimum coordinates of the fused volume. + - voxel_size (np.ndarray): Voxel size in physical units. + + Returns: + - dask.array: Fused volume. + """ + # Wrap process_chunk for Dask + def wrapped_process_chunk(chunk_data, block_info=None): + return process_chunk( + chunk_data, + block_info, + ome_zarr_paths, + optimized_translations, + bbox_min, + voxel_size, + chunk_shape, + ) + + # Define Dask array for the fused volume + fused_volume = da.map_blocks( + wrapped_process_chunk, + chunks=chunk_shape, + dtype=np.float32, + shape=fused_shape, + ) + + return fused_volume + + +# Example usage +ome_zarr_paths = snakemake.input.ome_zarr # List of input OME-Zarr paths +optimized_translations = np.loadtxt(snakemake.input.optimized_translations, dtype=float) # Optimized translations + +# Compute the fused volume shape +bbox_min, bbox_max, voxel_size = compute_fused_volume_shape(ome_zarr_paths, optimized_translations) + + +print(bbox_min) +print(bbox_max) +print(voxel_size) +assert all(voxel_size > 0), "Voxel size must be greater than zero." + +fused_shape = tuple(np.ceil((bbox_max - bbox_min) / voxel_size).astype(int)) +chunk_shape = (64, 64, 64) # Example chunk size + +# Fuse the volume +fused_volume = fuse_volume(ome_zarr_paths, optimized_translations, fused_shape, chunk_shape, bbox_min, voxel_size) + +# Save the fused volume +fused_volume.to_zarr(snakemake.output.fused_volume) + + + + +