Skip to content

Commit

Permalink
fusion still WIP, but getting closer
Browse files Browse the repository at this point in the history
  • Loading branch information
akhanf committed Dec 7, 2024
1 parent 674ef74 commit c972839
Show file tree
Hide file tree
Showing 4 changed files with 378 additions and 15 deletions.
16 changes: 13 additions & 3 deletions dask-stitch/Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ wildcard_constraints:

rule all:
input:
nifti=expand('results/tile-{tile}_optimized_SPIM.nii', tile=range(gridx*gridy))
'results/fused_SPIM.nii'

rule create_test_dataset_single_ome_zarr:
params:
Expand All @@ -34,7 +34,8 @@ rule compute_pairwise_correlation:
ome_zarr=expand('results/tile-{tile}_SPIM.ome.zarr', tile=range(gridx*gridy)),
pairs='results/overlapping_pairs.txt'
output:
offsets='results/pairwise_offsets.txt'
offsets='results/pairwise_offsets.txt',
diagnostics_dir=directory('results/pairwise_offsets_work')
script: 'scripts/compute_pairwise_correlation.py'

rule global_optimization:
Expand All @@ -56,5 +57,14 @@ rule assign_translations:
nifti=expand('results/tile-{tile}_optimized_SPIM.nii', tile=range(gridx*gridy))
script:
"scripts/assign_translation.py"


rule fuse_volume:
input:
ome_zarr=expand('results/tile-{tile}_SPIM.ome.zarr', tile=range(gridx*gridy)),
optimized_translations='results/optimized_translations.txt'
output:
ome_zarr = directory('results/fused_SPIM.ome.zarr'),
nifti = 'results/fused_SPIM.nii'
script:
'scripts/fuse_volume.py'

87 changes: 76 additions & 11 deletions dask-stitch/scripts/compute_pairwise_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@
from zarrnii import ZarrNii
from scipy.fft import fftn, ifftn
from scipy.ndimage import map_coordinates
import os


def phase_correlation(img1, img2):
def phase_correlation(img1, img2, diagnostics_dir=None, pair_index=None):
"""
Compute the phase correlation between two 3D images to find the translation offset.
Parameters:
- img1 (np.ndarray): First image (3D array).
- img2 (np.ndarray): Second image (3D array).
- diagnostics_dir (str): Directory to save diagnostic outputs (optional).
- pair_index (int): Index of the pair being processed (for naming diagnostics).
Returns:
- np.ndarray: Offset vector [z_offset, y_offset, x_offset].
Expand All @@ -26,6 +29,10 @@ def phase_correlation(img1, img2):
# Inverse Fourier transform to get correlation map
correlation = np.abs(ifftn(cross_power))

# Save the correlation map for diagnostics
if diagnostics_dir and pair_index is not None:
np.save(os.path.join(diagnostics_dir, f"correlation_map_pair_{pair_index}.npy"), correlation)

# Find the peak in the correlation map
peak = np.unravel_index(np.argmax(correlation), correlation.shape)

Expand Down Expand Up @@ -69,21 +76,58 @@ def resample_to_bounding_box(img, affine, bbox_min, bbox_max, output_shape):
return map_coordinates(img, [voxel_coords[..., dim] for dim in range(3)], order=1, mode="constant")


def compute_pairwise_correlation(ome_zarr_paths, overlapping_pairs, output_shape=(64, 64, 64)):
def compute_corrected_bounding_box(affine, img_shape):
"""
Compute the corrected bounding box in physical space, accounting for negative physical dimensions.
Parameters:
- affine (np.ndarray): 4x4 affine matrix.
- img_shape (tuple): Shape of the image (Z, Y, X).
Returns:
- bbox_min (np.ndarray): Minimum physical coordinates.
- bbox_max (np.ndarray): Maximum physical coordinates.
"""
corners = [
np.array([0, 0, 0, 1]),
np.array([img_shape[2], 0, 0, 1]),
np.array([0, img_shape[1], 0, 1]),
np.array([img_shape[2], img_shape[1], 0, 1]),
np.array([0, 0, img_shape[0], 1]),
np.array([img_shape[2], 0, img_shape[0], 1]),
np.array([0, img_shape[1], img_shape[0], 1]),
np.array([img_shape[2], img_shape[1], img_shape[0], 1]),
]

# Transform voxel corners to physical space
corners_physical = np.dot(affine, np.array(corners).T).T[:, :3]

# Find corrected min and max
bbox_min = np.min(corners_physical, axis=0)
bbox_max = np.max(corners_physical, axis=0)

return bbox_min, bbox_max


def compute_pairwise_correlation(ome_zarr_paths, overlapping_pairs, output_shape=(64, 64, 64), diagnostics_dir=None):
"""
Compute the optimal offset for each pair of overlapping tiles.
Parameters:
- ome_zarr_paths (list of str): List of paths to OME-Zarr datasets.
- overlapping_pairs (list of tuples): List of overlapping tile indices.
- output_shape (tuple): Shape of the resampled bounding box.
- diagnostics_dir (str): Directory to save diagnostic outputs (optional).
Returns:
- np.ndarray: Array of offsets for each pair (N, 3) where N is the number of pairs.
"""
offsets = []

for i, j in overlapping_pairs:
if diagnostics_dir:
os.makedirs(diagnostics_dir, exist_ok=True)

for pair_index, (i, j) in enumerate(overlapping_pairs):
# Load the two images and their affines
znimg1 = ZarrNii.from_path(ome_zarr_paths[i])
znimg2 = ZarrNii.from_path(ome_zarr_paths[j])
Expand All @@ -93,26 +137,46 @@ def compute_pairwise_correlation(ome_zarr_paths, overlapping_pairs, output_shape

affine1 = znimg1.vox2ras.affine
affine2 = znimg2.vox2ras.affine

#HACK FIX
affine1[:3,3] = -1 * np.flip(affine1[:3,3])
affine2[:3,3] = -1 * np.flip(affine2[:3,3])

# Compute the overlapping bounding box in physical space
bbox1_min = affine1[:3, 3]
bbox1_max = bbox1_min + np.dot(affine1[:3, :3], img1.shape[::-1])
bbox2_min = affine2[:3, 3]
bbox2_max = bbox2_min + np.dot(affine2[:3, :3], img2.shape[::-1])


# Compute the corrected bounding boxes
bbox1_min, bbox1_max = compute_corrected_bounding_box(affine1, img1.shape)
bbox2_min, bbox2_max = compute_corrected_bounding_box(affine2, img2.shape)



print(f'bbox1, from tile: {i}')
print(bbox1_min)
print(bbox1_max)
print(f'bbox2, from tile: {j}')
print(bbox2_min)
print(bbox2_max)

bbox_min = np.maximum(bbox1_min, bbox2_min)
bbox_max = np.minimum(bbox1_max, bbox2_max)

print('overlapping bbox')
print(bbox_min)
print(bbox_max)
# Save bounding box for diagnostics
if diagnostics_dir:
np.save(os.path.join(diagnostics_dir, f"bounding_box_pair_{pair_index}.npy"), np.array([bbox_min, bbox_max]))

# Resample images to the overlapping bounding box
resampled_img1 = resample_to_bounding_box(img1, affine1, bbox_min, bbox_max, output_shape)
resampled_img2 = resample_to_bounding_box(img2, affine2, bbox_min, bbox_max, output_shape)

# Save resampled images for diagnostics
if diagnostics_dir:
np.save(os.path.join(diagnostics_dir, f"resampled_img1_pair_{pair_index}.npy"), resampled_img1)
np.save(os.path.join(diagnostics_dir, f"resampled_img2_pair_{pair_index}.npy"), resampled_img2)

# Compute phase correlation on the resampled overlapping region
offset = phase_correlation(resampled_img1, resampled_img2)
offset = phase_correlation(resampled_img1, resampled_img2, diagnostics_dir=diagnostics_dir, pair_index=pair_index)
offsets.append(offset)

return np.array(offsets)
Expand All @@ -121,9 +185,10 @@ def compute_pairwise_correlation(ome_zarr_paths, overlapping_pairs, output_shape
# Example usage
overlapping_pairs = np.loadtxt(snakemake.input.pairs, dtype=int).tolist() # Overlapping pairs
ome_zarr_paths = snakemake.input.ome_zarr # List of OME-Zarr paths
diagnostics_dir = snakemake.output.diagnostics_dir # Directory for diagnostics

# Compute pairwise offsets
offsets = compute_pairwise_correlation(ome_zarr_paths, overlapping_pairs)
offsets = compute_pairwise_correlation(ome_zarr_paths, overlapping_pairs, diagnostics_dir=diagnostics_dir)

# Save results
np.savetxt(snakemake.output.offsets, offsets, fmt="%.6f")
Expand Down
2 changes: 1 addition & 1 deletion dask-stitch/scripts/create_test_dataset_singletile.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def create_test_dataset_single(tile_index, template="MNI152NLin2009cAsym", res=2
# TODO: Simulate error by applying a transformation to the image before

# initially lets just do a random jitter:
offset = np.random.uniform(-5, 5, size=(grid_shape[0],grid_shape[1],3)) # Random 3D offsets for each tile
offset = np.random.uniform(0, 0, size=(grid_shape[0],grid_shape[1],3)) # Random 3D offsets for each tile

xfm_img_data = affine_transform(img_data,matrix=np.eye(3,3),offset=offset[x,y,:],order=1)

Expand Down
Loading

0 comments on commit c972839

Please sign in to comment.