Skip to content

Commit

Permalink
make multiple ome zarr
Browse files Browse the repository at this point in the history
  • Loading branch information
akhanf committed Dec 6, 2024
1 parent 5cad3f6 commit abe66df
Show file tree
Hide file tree
Showing 6 changed files with 237 additions and 95 deletions.
44 changes: 38 additions & 6 deletions dask-stitch/Snakefile
Original file line number Diff line number Diff line change
@@ -1,15 +1,47 @@
configfile: 'config.yml'

rule create_test_dataset_ome_zarr:
def get_tile_targets():
gridx=config['gridx']
gridy=config['gridy']
tile_path=f'test_grid-{gridx}by{gridy}/tile-{{tile}}_SPIM.ome.zarr'

return expand(tile_path, tile=range(gridx*gridy))

rule all:
input:
ome_zarr=get_tile_targets()

rule create_test_dataset_single_ome_zarr:
params:
grid_shape=lambda wildcards: (int(wildcards.gridx),int(wildcards.gridy)),
tile_index=lambda wildcards: int(wildcards.tile)
output:
ome_zarr=directory('test_tiled.ome.zarr'),
translations_npy='test_translations.npy'
script: 'scripts/create_test_dataset.py'
ome_zarr=directory('test_grid-{gridx}by{gridy}/tile-{tile}_SPIM.ome.zarr'),
nifti='test_grid-{gridx}by{gridy}/tile-{tile}_SPIM.nii',
script: 'scripts/create_test_dataset_singletile.py'





#-- unused below:

rule create_test_dataset_combined_ome_zarr:
params:
grid_shape=lambda wildcards: (int(wildcards.gridx),int(wildcards.gridy))
output:
ome_zarr=directory('testcombined_grid-{gridx}by{gridy}_SPIM.ome.zarr'),
translations_npy='testcombined_grid-{gridx}by{gridy}_translations.npy'
script: 'scripts/create_test_dataset_combined.py'


rule get_tiles_as_nifti:
input:
ome_zarr='test_tiled.ome.zarr',
ome_zarr='testcombined_grid-{gridx}by{gridy}_SPIM.ome.zarr'
params:
n_tiles = lambda wildcards: int(wildcards.gridx) * int(wildcards.gridy)
output:
tiles_dir=directory('test_tiled_niftis')
tiles_dir=directory('testcombine_grid-{gridx}by{gridy}_SPIM.niftis')
script:
'scripts/get_tiles_as_nifti.py'

Expand Down
3 changes: 3 additions & 0 deletions dask-stitch/config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
gridx: 3
gridy: 4

85 changes: 0 additions & 85 deletions dask-stitch/scripts/create_test_dataset.py

This file was deleted.

94 changes: 94 additions & 0 deletions dask-stitch/scripts/create_test_dataset_combined.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import numpy as np
from templateflow import api as tflow
import dask.array as da
import math
from zarrnii import ZarrNii

def create_test_dataset(template="MNI152NLin2009cAsym", res=2, grid_shape=(3, 4), overlap=8, random_seed=42, final_chunks=(32, 32, 1)):
"""
Create a low-resolution test dataset for tile-based stitching.
Parameters:
- template (str): TemplateFlow template name (default: MNI152NLin2009cAsym).
- res (int): Desired resolution in mm (default: 2mm).
- grid_shape (tuple): Shape of the tiling grid (e.g., (3, 4) for 3x4 grid in X-Y).
- overlap (int): Overlap between tiles in X-Y plane in voxels (default: 8).
- random_seed (int): Seed for reproducible random offsets.
- final_chunks (tuple): Desired chunks for final tiles.
Returns:
- znimg (ZarrNii): ZarrNii object containing the tiles.
- translations (np.ndarray): Array of random offsets for each tile.
"""
import math

# Seed the random number generator
np.random.seed(random_seed)

# Download template and load as a ZarrNii object
template_path = tflow.get(template, resolution=res, desc=None, suffix="T1w")
znimg = ZarrNii.from_path(template_path)
img_data = znimg.darr.squeeze()

# Original image shape
img_shape = np.array(img_data.shape) # (Z, Y, X)

# Keep Z dimension intact, calculate X and Y tile sizes
z_dim, y_dim, x_dim = img_shape
x_tile_size = math.ceil((x_dim + overlap * (grid_shape[1] - 1)) / grid_shape[1])
y_tile_size = math.ceil((y_dim + overlap * (grid_shape[0] - 1)) / grid_shape[0])
tile_shape = (z_dim, y_tile_size, x_tile_size)

# Calculate required padding to ensure X and Y dimensions are divisible by grid shape
padded_x = x_tile_size * grid_shape[1] - overlap * (grid_shape[1] - 1)
padded_y = y_tile_size * grid_shape[0] - overlap * (grid_shape[0] - 1)

padding = (
(0, 0), # No padding in Z
(0, int(max(padded_y - y_dim, 0))),
(0, int(max(padded_x - x_dim, 0))),
)

print('padding')
print(padding)
print(img_data.shape)

# Pad image if needed
if any(p[1] > 0 for p in padding):
img_data = da.pad(img_data, padding, mode="constant", constant_values=0)

# Create tiles
tiles = []
translations = []
for y in range(grid_shape[0]):
for x in range(grid_shape[1]):
# Calculate tile start indices
y_start = y * (y_tile_size - overlap)
x_start = x * (x_tile_size - overlap)

# Extract tile
tile = img_data[:, y_start:y_start + y_tile_size, x_start:x_start + x_tile_size]
tiles.append(tile)

# Add random offset -- NOT ACTUALLY BEING APPLIED TO SAMPLING HERE!
offset = np.random.uniform(-5, 5, size=3) # Random 3D offsets
translations.append((0, y_start, x_start) + offset)

# Convert tiles to a Dask array
tiles = da.stack([tile.rechunk(chunks=final_chunks) for tile in tiles])

# Save back into ZarrNii object
znimg.darr = tiles
translations = np.array(translations)

return znimg, translations




test_znimg, test_translations = create_test_dataset(grid_shape=snakemake.params.grid_shape)
print(test_znimg.darr.shape)
test_znimg.to_ome_zarr(snakemake.output.ome_zarr)
np.save(snakemake.output.translations_npy,test_translations)


101 changes: 101 additions & 0 deletions dask-stitch/scripts/create_test_dataset_singletile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import nibabel as nib
import numpy as np
from templateflow import api as tflow
import nibabel as nib
import dask.array as da
import math
from zarrnii import ZarrNii

def create_test_dataset_single(tile_index, template="MNI152NLin2009cAsym", res=2, grid_shape=(3, 4), overlap=8, random_seed=42, final_chunks=(1,32, 32, 1)):
"""
Create a low-resolution test dataset for tile-based stitching.
Parameters:
- tile_index: the index of the tile to create
- template (str): TemplateFlow template name (default: MNI152NLin2009cAsym).
- res (int): Desired resolution in mm (default: 2mm).
- grid_shape (tuple): Shape of the tiling grid (e.g., (3, 4) for 3x4 grid in X-Y).
- overlap (int): Overlap between tiles in X-Y plane in voxels (default: 8).
- random_seed (int): Seed for reproducible random offsets.
- final_chunks (tuple): Desired chunks for final tiles.
Returns:
- znimg (ZarrNii): ZarrNii object containing the tiles.
- translations (np.ndarray): Array of random offsets for each tile.
"""
import math

# Seed the random number generator
np.random.seed(random_seed)

# Download template and load as a ZarrNii object
template_path = tflow.get(template, resolution=res, desc=None, suffix="T1w")
img_data = nib.load(template_path).get_fdata()

# Original image shape
img_shape = np.array(img_data.shape) # (Z, Y, X)

# Keep Z dimension intact, calculate X and Y tile sizes
x_dim, y_dim, z_dim = img_shape
x_tile_size = math.ceil((x_dim + overlap * (grid_shape[0] - 1)) / grid_shape[0])
y_tile_size = math.ceil((y_dim + overlap * (grid_shape[1] - 1)) / grid_shape[1])

# Calculate required padding to ensure X and Y dimensions are divisible by grid shape
padded_x = x_tile_size * grid_shape[0] - overlap * (grid_shape[0] - 1)
padded_y = y_tile_size * grid_shape[1] - overlap * (grid_shape[1] - 1)

padding = (
(0, int(max(padded_x - x_dim, 0))),
(0, int(max(padded_y - y_dim, 0))),
(0, 0), # No padding in Z
)


# Pad image if needed
if any(p[1] > 0 for p in padding):
img_data = np.pad(img_data, padding, mode="constant", constant_values=0)

# Create tiles

x,y = np.unravel_index(tile_index,grid_shape)


# Calculate tile start indices
x_start = x * (x_tile_size - overlap)
y_start = y * (y_tile_size - overlap)

# Extract tile
tile = img_data[x_start:x_start + x_tile_size, y_start:y_start + y_tile_size, :]

# Add random offset - ensure that the random offset generated for the same tile is the same
# do this by gneerating
offset = np.random.uniform(-5, 5, size=(grid_shape[0],grid_shape[1],3)) # Random 3D offsets
translation = ((x_start, y_start, 0) + offset[x,y,:])

print((x_start, y_start, 0))
print(translation)

tile_shape = (1,x_tile_size, y_tile_size, z_dim)


#save translation into vox2ras
vox2ras = np.eye(4)
vox2ras[:3,3] = np.array(translation)


# Save back into ZarrNii object
darr = da.from_array(tile.reshape(tile_shape),chunks=final_chunks)

znimg = ZarrNii.from_darr(darr,vox2ras=vox2ras,axes_nifti=True)

return znimg




test_znimg = create_test_dataset_single(tile_index=snakemake.params.tile_index,
grid_shape=snakemake.params.grid_shape)
test_znimg.to_ome_zarr(snakemake.output.ome_zarr)
test_znimg.to_nifti(snakemake.output.nifti)


5 changes: 1 addition & 4 deletions dask-stitch/scripts/get_tiles_as_nifti.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,10 @@
from pathlib import Path


znimg_example_tile= ZarrNii.from_path(snakemake.input.ome_zarr)
print(znimg_full.darr.shape)

out_dir = Path(snakemake.output.tiles_dir)
out_dir.mkdir(exist_ok=True, parents=True)

for tile in range(znimg_full.darr.shape[0]):
for tile in range(snakemake.params.n_tiles):
print(f'reading tile {tile} and writing to nifti')
ZarrNii.from_path(snakemake.input.ome_zarr,channels=[tile]).to_nifti(out_dir / f'tile_{tile:02d}.nii')

Expand Down

0 comments on commit abe66df

Please sign in to comment.