diff --git a/dask-stitch/Snakefile b/dask-stitch/Snakefile index 0ee335a..7dc9ac3 100644 --- a/dask-stitch/Snakefile +++ b/dask-stitch/Snakefile @@ -21,8 +21,31 @@ rule create_test_dataset_single_ome_zarr: script: 'scripts/create_test_dataset_singletile.py' +rule find_overlapping_pairs: + input: + ome_zarr=get_tile_targets() + output: + txt='test_grid-{gridx}by{gridy}/overlapping_pairs.txt' + script: 'scripts/find_overlapping_pairs.py' + +rule compute_pairwise_correlation: + input: + ome_zarr=get_tile_targets(), + pairs='test_grid-{gridx}by{gridy}/overlapping_pairs.txt' + output: + offsets='test_grid-{gridx}by{gridy}/pairwise_offsets.txt' + script: 'scripts/compute_pairwise_correlation.py' - +rule global_optimization: + input: + ome_zarr=get_tile_targets(), + pairs='test_grid-{gridx}by{gridy}/overlapping_pairs.txt', + offsets='test_grid-{gridx}by{gridy}/pairwise_offsets.txt' + output: + optimized_translations='test_grid-{gridx}by{gridy}/optimized_translations.txt' + script: + 'scripts/global_optimization.py' + #-- unused below: diff --git a/dask-stitch/scripts/compute_pairwise_correlation.py b/dask-stitch/scripts/compute_pairwise_correlation.py new file mode 100644 index 0000000..cd04431 --- /dev/null +++ b/dask-stitch/scripts/compute_pairwise_correlation.py @@ -0,0 +1,78 @@ +import numpy as np +from zarrnii import ZarrNii +from scipy.fft import fftn, ifftn +from scipy.ndimage import center_of_mass + + +def phase_correlation(img1, img2): + """ + Compute the phase correlation between two 3D images to find the translation offset. + + Parameters: + - img1 (np.ndarray): First image (3D array). + - img2 (np.ndarray): Second image (3D array). + + Returns: + - np.ndarray: Offset vector [z_offset, y_offset, x_offset]. + """ + # Compute the Fourier transforms + fft1 = fftn(img1) + fft2 = fftn(img2) + + # Compute the cross-power spectrum + cross_power = fft1 * np.conj(fft2) + cross_power /= np.abs(cross_power) # Normalize + + # Inverse Fourier transform to get correlation map + correlation = np.abs(ifftn(cross_power)) + + # Find the peak in the correlation map + peak = np.unravel_index(np.argmax(correlation), correlation.shape) + + # Convert peak index to an offset + shifts = np.array(peak, dtype=float) + for dim, size in enumerate(correlation.shape): + if shifts[dim] > size // 2: + shifts[dim] -= size + + return shifts + + +def compute_pairwise_correlation(ome_zarr_paths, overlapping_pairs): + """ + Compute the optimal offset for each pair of overlapping tiles. + + Parameters: + - ome_zarr_paths (list of str): List of paths to OME-Zarr datasets. + - overlapping_pairs (list of tuples): List of overlapping tile indices. + + Returns: + - np.ndarray: Array of offsets for each pair (N, 3) where N is the number of pairs. + """ + offsets = [] + + for i, j in overlapping_pairs: + # Load the two images + znimg1 = ZarrNii.from_path(ome_zarr_paths[i]) + znimg2 = ZarrNii.from_path(ome_zarr_paths[j]) + + img1 = znimg1.darr.squeeze().compute() + img2 = znimg2.darr.squeeze().compute() + + # Compute phase correlation + offset = phase_correlation(img1, img2) + offsets.append(offset) + + return np.array(offsets) + + +# Example usage +overlapping_pairs = np.loadtxt(snakemake.input.pairs, dtype=int).tolist() # Overlapping pairs +ome_zarr_paths = snakemake.input.ome_zarr # List of OME-Zarr paths + +# Compute pairwise offsets +offsets = compute_pairwise_correlation(ome_zarr_paths, overlapping_pairs) + +# Save results +np.savetxt(snakemake.output.offsets, offsets, fmt="%.6f") + diff --git a/dask-stitch/scripts/create_test_dataset_singletile.py b/dask-stitch/scripts/create_test_dataset_singletile.py index b8d1d8e..2a887b0 100644 --- a/dask-stitch/scripts/create_test_dataset_singletile.py +++ b/dask-stitch/scripts/create_test_dataset_singletile.py @@ -1,5 +1,6 @@ import nibabel as nib import numpy as np +from scipy.ndimage import affine_transform from templateflow import api as tflow import nibabel as nib import dask.array as da @@ -64,21 +65,27 @@ def create_test_dataset_single(tile_index, template="MNI152NLin2009cAsym", res=2 x_start = x * (x_tile_size - overlap) y_start = y * (y_tile_size - overlap) + + # TODO: Simulate error by applying a transformation to the image before + + # initially lets just do a random jitter: + offset = np.random.uniform(-10, 10, size=(grid_shape[0],grid_shape[1],3)) # Random 3D offsets for each tile + + xfm_img_data = affine_transform(img_data,matrix=np.eye(3,3),offset=offset[x,y,:],order=1) + # Extract tile - tile = img_data[x_start:x_start + x_tile_size, y_start:y_start + y_tile_size, :] + tile = xfm_img_data[x_start:x_start + x_tile_size, y_start:y_start + y_tile_size, :] - # Add random offset - ensure that the random offset generated for the same tile is the same - # do this by gneerating - offset = np.random.uniform(-5, 5, size=(grid_shape[0],grid_shape[1],3)) # Random 3D offsets - translation = ((x_start, y_start, 0) + offset[x,y,:]) - print((x_start, y_start, 0)) - print(translation) + + translation = ((x_start, y_start, 0)) + + tile_shape = (1,x_tile_size, y_tile_size, z_dim) - #save translation into vox2ras + #save tiling coordinate translation into vox2ras (not the random jitter) vox2ras = np.eye(4) vox2ras[:3,3] = np.array(translation) diff --git a/dask-stitch/scripts/find_overlapping_pairs.py b/dask-stitch/scripts/find_overlapping_pairs.py new file mode 100644 index 0000000..b70c86f --- /dev/null +++ b/dask-stitch/scripts/find_overlapping_pairs.py @@ -0,0 +1,59 @@ +import numpy as np + +def find_overlapping_pairs(ome_zarr_paths): + """ + Identify overlapping tile pairs based on their physical offsets. + + Parameters: + - ome_zarr_paths (list of str): List of paths to OME-Zarr datasets. + + Returns: + - List of tuples: Each tuple is a pair of overlapping tile indices ((i, j)). + """ + from zarrnii import ZarrNii + + # Read physical transformations and calculate bounding boxes + bounding_boxes = [] + for path in ome_zarr_paths: + znimg = ZarrNii.from_path(path) + affine = znimg.vox2ras.affine # 4x4 matrix + tile_shape = znimg.darr.shape[1:] + + # Compute physical bounding box using affine + corners = [ + np.array([0, 0, 0, 1]), + np.array([tile_shape[2], 0, 0, 1]), + np.array([0, tile_shape[1], 0, 1]), + np.array([tile_shape[2], tile_shape[1], 0, 1]), + np.array([0, 0, tile_shape[0], 1]), + np.array([tile_shape[2], 0, tile_shape[0], 1]), + np.array([0, tile_shape[1], tile_shape[0], 1]), + np.array([tile_shape[2], tile_shape[1], tile_shape[0], 1]), + ] + corners_physical = np.dot(affine, np.array(corners).T).T[:, :3] # Drop homogeneous coordinate + bbox_min = corners_physical.min(axis=0) + bbox_max = corners_physical.max(axis=0) + + bounding_boxes.append((bbox_min, bbox_max)) + + # Find overlapping pairs + overlapping_pairs = [] + for i, (bbox1_min, bbox1_max) in enumerate(bounding_boxes): + for j, (bbox2_min, bbox2_max) in enumerate(bounding_boxes): + if i >= j: + continue # Avoid duplicate pairs and self-comparison + + # Check for overlap in all dimensions + overlap = all( + bbox1_min[d] < bbox2_max[d] and bbox1_max[d] > bbox2_min[d] + for d in range(3) + ) + if overlap: + overlapping_pairs.append((i, j)) + + return overlapping_pairs + + +overlapping_pairs = find_overlapping_pairs(snakemake.input) +np.savetxt(snakemake.output.txt,np.array(overlapping_pairs),fmt='%d') + diff --git a/dask-stitch/scripts/global_optimization.py b/dask-stitch/scripts/global_optimization.py new file mode 100644 index 0000000..9ce994f --- /dev/null +++ b/dask-stitch/scripts/global_optimization.py @@ -0,0 +1,66 @@ +import numpy as np +from scipy.optimize import least_squares +from zarrnii import ZarrNii + + +def global_optimization(ome_zarr_paths, overlapping_pairs, pairwise_offsets): + """ + Perform global optimization to adjust translations for all tiles. + + Parameters: + - ome_zarr_paths (list of str): List of paths to OME-Zarr datasets. + - overlapping_pairs (list of tuples): List of overlapping tile indices ((i, j)). + - pairwise_offsets (np.ndarray): Array of pairwise offsets (N, 3), where N is the number of pairs. + + Returns: + - np.ndarray: Optimized global translations of shape (T, 3). + """ + # Number of tiles is the number of OME-Zarr paths + num_tiles = len(ome_zarr_paths) + + # Initial translations (start with identity translation: no offsets) + initial_translations = np.zeros((num_tiles, 3)) + + # Flatten initial translations for optimization + x0 = initial_translations.flatten() + + def objective(x): + """ + Compute the residuals for global optimization. + + Parameters: + - x (np.ndarray): Flattened translations array (T * 3,). + + Returns: + - np.ndarray: Residuals for least-squares optimization. + """ + translations = x.reshape((num_tiles, 3)) + residuals = [] + + for (i, j), offset in zip(overlapping_pairs, pairwise_offsets): + # Residual is the difference between the predicted and actual offset + predicted_offset = translations[j] - translations[i] + residuals.append(predicted_offset - offset) + + return np.concatenate(residuals) + + # Perform least-squares optimization + result = least_squares(objective, x0) + + # Reshape result back to (T, 3) + optimized_translations = result.x.reshape((num_tiles, 3)) + + return optimized_translations + + +# Example usage +overlapping_pairs = np.loadtxt(snakemake.input.pairs, dtype=int).tolist() # Overlapping pairs +pairwise_offsets = np.loadtxt(snakemake.input.offsets, dtype=float) # Pairwise offsets +ome_zarr_paths = snakemake.input.ome_zarr # List of OME-Zarr paths + +# Perform global optimization +optimized_translations = global_optimization(ome_zarr_paths, overlapping_pairs, pairwise_offsets) + +# Save results +np.savetxt(snakemake.output.optimized_translations, optimized_translations, fmt="%.6f") +