Skip to content

Commit

Permalink
WIP: prototype workflow for developing stitching alg
Browse files Browse the repository at this point in the history
  • Loading branch information
akhanf committed Dec 6, 2024
1 parent 4562271 commit 5cad3f6
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 0 deletions.
16 changes: 16 additions & 0 deletions dask-stitch/Snakefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@

rule create_test_dataset_ome_zarr:
output:
ome_zarr=directory('test_tiled.ome.zarr'),
translations_npy='test_translations.npy'
script: 'scripts/create_test_dataset.py'

rule get_tiles_as_nifti:
input:
ome_zarr='test_tiled.ome.zarr',
output:
tiles_dir=directory('test_tiled_niftis')
script:
'scripts/get_tiles_as_nifti.py'


85 changes: 85 additions & 0 deletions dask-stitch/scripts/create_test_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import numpy as np
from templateflow import api as tflow
import dask.array as da
from zarrnii import ZarrNii


out_ome_zarr=snakemake.output.ome_zarr
out_translations_npy=snakemake.output.translations_npy

def create_test_dataset(template="MNI152NLin2009cAsym", res=2, tile_shape=(32,32, 32), final_chunks=(1,32,32,1),overlap=8, random_seed=42):
"""
Create a low-resolution test dataset for tile-based stitching.
Parameters:
- template (str): TemplateFlow template name (default: MNI152NLin2009cAsym).
- res (int): Desired resolution in mm (default: 2mm).
- tile_shape (tuple): Shape of each tile in voxels (default: (64, 64, 64)).
- overlap (int): Overlap between tiles in voxels (default: 16).
- random_seed (int): Seed for reproducible random offsets.
Returns:
- tiles (dask.array): TxZxYxX Dask array of tiles.
- translations (np.ndarray): Array of random offsets for each tile.
"""
# Seed the random number generator
np.random.seed(random_seed)

# Download template and load as a Numpy array
template_path = tflow.get(template, resolution=res, desc=None,suffix="T1w")
znimg = ZarrNii.from_path(template_path)
print(znimg)
print(znimg.darr)
img_data = znimg.darr


# Determine number of tiles in each dimension
img_shape = img_data.shape
step = tuple(s - overlap for s in tile_shape)

grid_shape = tuple(
max(((img_shape[dim] - tile_shape[dim]) // step[dim]) + 1, 1)
for dim in range(3)
)

print(f'img_shape {img_shape}, step: {step}, grid_shape: {grid_shape}')
# Create tiles
tiles = []
translations = []
for z in range(grid_shape[0]):
for y in range(grid_shape[1]):
for x in range(grid_shape[2]):
# Extract tile
z_start, y_start, x_start = z * step[0], y * step[1], x * step[2]
tile = img_data[z_start:z_start+tile_shape[0],
y_start:y_start+tile_shape[1],
x_start:x_start+tile_shape[2]]

# Add to list
tiles.append(tile)

# Add random offset
offset = np.random.uniform(-5, 5, size=3) # Random 3D offsets
translations.append((z_start, y_start, x_start) + offset)

print(tiles)
# Convert to a Dask array
print(tile_shape)
tiles = da.concatenate([tile.rechunk(chunks=final_chunks) for tile in tiles])
translations = np.array(translations)

znimg.darr = tiles

return znimg, translations




if __name__ == '__main__':
test_znimg, test_translations = create_test_dataset()
print(test_znimg)

print(test_translations.shape)
test_znimg.to_ome_zarr(out_ome_zarr)
np.save(out_translations_npy,test_translations)

17 changes: 17 additions & 0 deletions dask-stitch/scripts/get_tiles_as_nifti.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import nibabel as nib
from zarrnii import ZarrNii
from pathlib import Path


znimg_example_tile= ZarrNii.from_path(snakemake.input.ome_zarr)
print(znimg_full.darr.shape)

out_dir = Path(snakemake.output.tiles_dir)
out_dir.mkdir(exist_ok=True, parents=True)

for tile in range(znimg_full.darr.shape[0]):
print(f'reading tile {tile} and writing to nifti')
ZarrNii.from_path(snakemake.input.ome_zarr,channels=[tile]).to_nifti(out_dir / f'tile_{tile:02d}.nii')



0 comments on commit 5cad3f6

Please sign in to comment.