WIP spark
akhanf committed Sep 25, 2024
1 parent c3586de commit d1fa55c
Showing 8 changed files with 169 additions and 57 deletions.
31 changes: 17 additions & 14 deletions config/config.yml
@@ -6,7 +6,7 @@ work: 'work'

remote_creds: '~/.config/gcloud/application_default_credentials.json' #this is needed so we can pass creds to container

write_ome_zarr_direct: True #use this to skip writing the final zarr output to work first and copying afterwards -- useful when work is not a fast local disk
write_ome_zarr_direct: False #use this to skip writing the final zarr output to work first and copying afterwards -- useful when work is not a fast local disk

cores_per_rule: 32

@@ -42,28 +42,29 @@ bigstitcher:
optical_flow: "Lucas-Kanade"
filter_pairwise_shifts:
enabled: 1
min_r: 0.7
min_r: 0.2
max_shift_total: 50
global_optimization:
enabled: 1
strategy: two_round
strategies:
one_round: "One-Round"
one_round_iterative: "One-Round with iterative dropping of bad links"
two_round: "Two-Round using metadata to align unconnected Tiles"
two_round_iterative: "Two-Round using Metadata to align unconnected Tiles and iterative dropping of bad links"
enabled: 1
method: ONE_ROUND_SIMPLE
methods:
ONE_ROUND_SIMPLE: "One-Round"
ONE_ROUND_ITERATIVE: "One-Round with iterative dropping of bad links"
TWO_ROUND_SIMPLE: "Two-Round using metadata to align unconnected Tiles"
TWO_ROUND_ITERATIVE: "Two-Round using Metadata to align unconnected Tiles and iterative dropping of bad links"


fuse_dataset:
downsampling: 1
block_size_x: 256 # for storage
block_size_y: 256
block_size_z: 1
block_size_z: 8
block_size_factor_x: 1 #e.g. 2 will use 2*block_size for computation
block_size_factor_y: 1
block_size_factor_z: 256
block_size_factor_z: 32

ome_zarr:
desc: stitchedflatcorr
desc: sparkstitchedflatcorr
max_downsampling_layers: 5 # e.g. 4 levels: { 0: orig, 1: ds2, 2: ds4, 3: ds8, 4: ds16}
rechunk_size: #z, y, x
- 1
@@ -99,7 +100,7 @@ ome_zarr:
id: 0
name: spim
version: "0.4"
use_zipstore: False #if True, produce SPIM.ome.zarr.zip instead of SPIM.ome.zarr
use_zipstore: True #if True, produce SPIM.ome.zarr.zip instead of SPIM.ome.zarr

nifti:
levels: #cannot be higher than max_downsampling_layers in ome_zarr
@@ -154,5 +155,7 @@ report:


containers:
spimprep: 'docker://khanlab/spimprep-deps:main'
#spimprep: 'docker://khanlab/spimprep-deps:main'
spimprep: '/local/scratch/spimprep-deps_spark.sif'
# updated.sif'
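
Side note on the fuse_dataset change above: the block_size_* values are the storage chunk size, and block_size_factor_* multiplies them for computation (per the inline comment "e.g. 2 will use 2*block_size for computation"). A minimal sketch, assuming the new values, of the effective computation block per axis:

    # Sketch only: storage chunk size vs. computation block size.
    block_size = {"x": 256, "y": 256, "z": 8}        # new storage chunk size
    block_size_factor = {"x": 1, "y": 1, "z": 32}    # new computation multipliers

    compute_block = {ax: block_size[ax] * block_size_factor[ax] for ax in block_size}
    print(compute_block)  # {'x': 256, 'y': 256, 'z': 256}

So the z change (block size 1 -> 8, factor 256 -> 32) keeps the same 256-slice computation block while writing thicker chunks to storage.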

127 changes: 121 additions & 6 deletions workflow/rules/bigstitcher.smk
@@ -93,7 +93,7 @@ rule zarr_to_bdv:
script:
"../scripts/zarr_to_n5_bdv.py"


"""
rule bigstitcher:
input:
dataset_n5=rules.zarr_to_bdv.output.bdv_n5,
@@ -159,6 +159,121 @@ rule bigstitcher:
" {params.fiji_launcher_cmd} && "
" echo ' -macro {input.ijm} \"{params.macro_args}\"' >> {output.launcher} "
" && {output.launcher} |& tee {log} && {params.rm_old_xml}"
"""

rule bigstitcher_spark_stitching:
input:
dataset_n5=rules.zarr_to_bdv.output.bdv_n5,
dataset_xml=rules.zarr_to_bdv.output.bdv_xml,
params:
downsampling='--downsampling={dsx},{dsy},{dsz}'.format(dsx=config['bigstitcher']['calc_pairwise_shifts']['downsample_in_x'],
dsy=config['bigstitcher']['calc_pairwise_shifts']['downsample_in_y'],
dsz=config['bigstitcher']['calc_pairwise_shifts']['downsample_in_z']),
min_r='--minR={min_r}'.format(min_r=config['bigstitcher']['filter_pairwise_shifts']['min_r']),
max_shift='--maxShiftTotal={max_shift}'.format(max_shift=config['bigstitcher']['filter_pairwise_shifts']['max_shift_total']),
mem_gb=lambda wildcards, resources: '{mem_gb}'.format(mem_gb=int(resources.mem_mb/1000))
output:
dataset_xml=temp(
bids(
root=work,
subject="{subject}",
datatype="micr",
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="bigstitcherstitching.xml",
)
),
benchmark:
bids(
root="benchmarks",
datatype="bigstitcherstitching",
subject="{subject}",
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="benchmark.tsv",
)
log:
bids(
root="logs",
datatype="bigstitcherproc",
subject="{subject}",
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="log.txt",
),
container:
config["containers"]["spimprep"]
resources:
runtime=30,
mem_mb=40000,
threads: config["cores_per_rule"]
group:
"preproc"
shell:
"cp {input.dataset_xml} {output.dataset_xml} && "
"stitching {params.mem_gb} {threads} -x {output.dataset_xml} "
" {params.min_r} {params.downsampling} "
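
For context (not part of the diff), a sketch of how the stitching params above expand. The downsample_in_* values are hypothetical since that part of config.yml is not shown in this hunk; minR and the memory/threads values follow the changes in this commit:

    # Sketch: how the params render into the container's `stitching` entrypoint call.
    calc = {"downsample_in_x": 4, "downsample_in_y": 4, "downsample_in_z": 2}  # hypothetical values
    filt = {"min_r": 0.2, "max_shift_total": 50}                               # from the config.yml diff above

    downsampling = "--downsampling={dsx},{dsy},{dsz}".format(
        dsx=calc["downsample_in_x"],
        dsy=calc["downsample_in_y"],
        dsz=calc["downsample_in_z"],
    )
    min_r = "--minR={min_r}".format(min_r=filt["min_r"])
    mem_gb = "{mem_gb}".format(mem_gb=int(40000 / 1000))  # resources.mem_mb=40000 -> "40"

    print(f"stitching {mem_gb} 32 -x <bigstitcherstitching.xml> {min_r} {downsampling}")
    # stitching 40 32 -x <bigstitcherstitching.xml> --minR=0.2 --downsampling=4,4,2

Note that params.max_shift is defined but not referenced in the shell line in this version of the rule.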


rule bigstitcher_spark_solver:
input:
dataset_n5=rules.zarr_to_bdv.output.bdv_n5,
dataset_xml=rules.bigstitcher_spark_stitching.output.dataset_xml,
params:
downsampling='--downsampling={dsx},{dsy},{dsz}'.format(dsx=config['bigstitcher']['calc_pairwise_shifts']['downsample_in_x'],
dsy=config['bigstitcher']['calc_pairwise_shifts']['downsample_in_y'],
dsz=config['bigstitcher']['calc_pairwise_shifts']['downsample_in_z']),
method='--method={method}'.format(method=config['bigstitcher']['global_optimization']['method']),
mem_gb=lambda wildcards, resources: '{mem_gb}'.format(mem_gb=int(resources.mem_mb/1000))
output:
dataset_xml=temp(
bids(
root=work,
subject="{subject}",
datatype="micr",
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="bigstitchersolver.xml",
)
),
benchmark:
bids(
root="benchmarks",
datatype="bigstitchersolver",
subject="{subject}",
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="benchmark.tsv",
)
log:
bids(
root="logs",
datatype="bigstitcherproc",
subject="{subject}",
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="log.txt",
),
container:
config["containers"]["spimprep"]
resources:
runtime=30,
mem_mb=40000,
threads: config["cores_per_rule"]
group:
"preproc"
shell:
"cp {input.dataset_xml} {output.dataset_xml} && "
"solver {params.mem_gb} {threads} -x {output.dataset_xml} "
" -s STITCHING --lambda 0.1 " #lambda 0.1 is default (can expose this if needed)
" {params.method} "
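
Similarly, a sketch (not part of the diff) of the solver invocation, using method: ONE_ROUND_SIMPLE from the updated config; the rule copies the stitching XML and runs the solver on that copy in place:

    # Sketch: solver call produced by the rule, given the config above.
    method = "--method={method}".format(method="ONE_ROUND_SIMPLE")  # config['bigstitcher']['global_optimization']['method']
    mem_gb, threads = int(40000 / 1000), 32  # resources.mem_mb and cores_per_rule

    print(f"solver {mem_gb} {threads} -x <bigstitchersolver.xml> -s STITCHING --lambda 0.1 {method}")
    # solver 40 32 -x <bigstitchersolver.xml> -s STITCHING --lambda 0.1 --method=ONE_ROUND_SIMPLE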



rule fuse_dataset:
@@ -179,7 +294,7 @@ rule fuse_dataset:
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="bigstitcher.xml",
suffix="bigstitcher{}.xml".format('solver' if config['bigstitcher']['global_optimization']['enabled'] else 'stitching'),
),
ijm=Path(workflow.basedir) / "macros" / "FuseImageMacroZarr.ijm",
params:
@@ -238,7 +353,7 @@ rule fuse_dataset:
config["containers"]["spimprep"]
resources:
runtime=30,
mem_mb=20000,
mem_mb=40000,
threads: config["cores_per_rule"]
group:
"preproc"
@@ -266,7 +381,7 @@ rule fuse_dataset_spark:
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="bigstitcher.xml",
suffix="bigstitcher{}.xml".format('solver' if config['bigstitcher']['global_optimization']['enabled'] else 'stitching'),
),
ijm=Path(workflow.basedir) / "macros" / "FuseImageMacroZarr.ijm",
params:
Expand All @@ -283,7 +398,7 @@ rule fuse_dataset_spark:
bsfy=config["bigstitcher"]["fuse_dataset"]["block_size_factor_y"],
bsfz=config["bigstitcher"]["fuse_dataset"]["block_size_factor_z"],
),
mem_gb=lambda wikdcards, resources: '{mem_gb}'.format(mem_gb=int(resources.mem_mb/1000))
mem_gb=lambda wildcards, resources: '{mem_gb}'.format(mem_gb=int(resources.mem_mb/1000))
output:
zarr=temp(
directory(
@@ -325,7 +440,7 @@ rule fuse_dataset_spark:
config["containers"]["spimprep"]
resources:
runtime=30,
mem_mb=20000,
mem_mb=40000,
threads: config["cores_per_rule"]
group:
"preproc"
38 changes: 7 additions & 31 deletions workflow/rules/common.smk
@@ -312,7 +312,7 @@ def get_output_ome_zarr_uri():
datatype="micr",
sample="{sample}",
acq="{acq}",
suffix="SPIM.ome.zarr",
suffix="SPIM.{ext}".format(ext=get_extension_ome_zarr()),
)
else:
return "local://" + _bids(
@@ -321,7 +321,7 @@
datatype="micr",
sample="{sample}",
acq="{acq}",
suffix="SPIM.ome.zarr",
suffix="SPIM.{ext}".format(ext=get_extension_ome_zarr()),
)


@@ -368,35 +368,11 @@ def get_output_ome_zarr(acq_type):
}


def get_input_ome_zarr_to_nii():
if is_remote(config["root"]):
return bids(
root=root,
subject="{subject}",
datatype="micr",
sample="{sample}",
acq="{acq}",
suffix="SPIM.{extension}".format(extension=get_extension_ome_zarr()),
)
else:
if config["write_ome_zarr_direct"]:
return bids(
root=root,
subject="{subject}",
datatype="micr",
sample="{sample}",
acq="{acq}",
suffix="SPIM.{extension}".format(extension=get_extension_ome_zarr()),
)
else:
return bids(
root=work,
subject="{subject}",
datatype="micr",
sample="{sample}",
acq="{acq}",
suffix=f"SPIM.ome.zarr",
)
def get_output_ome_zarr_as_input(wildcards):
if 'blaze' in wildcards.acq:
return get_output_ome_zarr('blaze')
elif 'prestitched' in wildcards.acq:
return get_output_ome_zarr('prestitched')
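
A usage sketch of the new helper (assumption: get_output_ome_zarr() returns a dict of named outputs, consistent with how this helper is consumed via unpack() in qc.smk below):

    # Sketch: a stand-in wildcards object to show the branch selection (illustrative only).
    from types import SimpleNamespace

    wildcards = SimpleNamespace(acq="blaze")           # hypothetical acquisition value
    inputs = get_output_ome_zarr_as_input(wildcards)   # -> dict of named outputs, e.g. {"zarr": ...}

    # In qc.smk the helper is spread into the rule with unpack(get_output_ome_zarr_as_input),
    # so scripts read the path as snakemake.input.zarr (see the QC script changes below).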


def get_storage_creds():
8 changes: 7 additions & 1 deletion workflow/rules/ome_zarr.smk
@@ -117,8 +117,14 @@ rule ome_zarr_to_zipstore:

rule ome_zarr_to_nii:
input:
zarr=bids(
root=root,
subject="{subject}",
datatype="micr",
sample="{sample}",
acq="{acq}",
suffix="SPIM.{ext}".format(ext=get_extension_ome_zarr())),
**get_storage_creds(),
zarr=get_input_ome_zarr_to_nii(),
params:
channel_index=lambda wildcards: get_stains(wildcards).index(wildcards.stain),
uri=get_output_ome_zarr_uri(),
4 changes: 2 additions & 2 deletions workflow/rules/qc.smk
@@ -73,8 +73,8 @@ rule generate_flatfield_qc:
rule generate_whole_slice_qc:
"Generates an html file to view whole slices from preprocessed images"
input:
unpack(get_output_ome_zarr_as_input),
**get_storage_creds(),
ome=get_input_ome_zarr_to_nii(),
ws_html=config["report"]["resources"]["ws_html"],
params:
ws_s_start=config["report"]["whole_slice_viewer"]["slice_start"],
@@ -112,8 +112,8 @@ rule generate_whole_slice_qc:
rule generate_volume_qc:
"Generates an html file to view the volume rendered image"
input:
unpack(get_output_ome_zarr_as_input),
**get_storage_creds(),
ome=get_input_ome_zarr_to_nii(),
vol_viewer_dir=config["report"]["resources"]["vol_viewer_dir"],
params:
uri=get_output_ome_zarr_uri(),
2 changes: 1 addition & 1 deletion workflow/scripts/generate_volume_qc.py
@@ -15,7 +15,7 @@
html_dest = snakemake.output.html

# inputted ome-zarr path
ome_data = snakemake.input.ome
ome_data = snakemake.input.zarr

# move volume renderer into the subjects directory
copy_tree(snakemake.input.vol_viewer_dir, resource_dir)
2 changes: 1 addition & 1 deletion workflow/scripts/generate_whole_slice_qc.py
@@ -20,7 +20,7 @@
ws_cmap=snakemake.params.ws_cmap

# input ome-zarr file
ome= snakemake.input.ome
ome= snakemake.input.zarr

# output paths
image_dir = snakemake.output.images_dir
14 changes: 13 additions & 1 deletion workflow/scripts/ome_zarr_to_nii.py
@@ -18,7 +18,19 @@
fs_args={}

fs = get_fsspec(uri,**fs_args)
store = zarr.storage.FSStore(Path(uri).path,fs=fs,dimension_separator='/',mode='r')

print(fs)
print(uri)
print(Path(uri))
print(Path(uri).path)

print(Path(uri).suffix)
if Path(uri).suffix == '.zip':
store = zarr.storage.ZipStore(Path(uri).path,dimension_separator='/',mode='r')
else:
store = zarr.storage.FSStore(Path(uri).path,fs=fs,dimension_separator='/',mode='r')

print(store)
zi = zarr.open(store=store,mode='r')
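
For reference, a minimal sketch (not part of the commit, filename hypothetical) of reading the zipped OME-Zarr that use_zipstore: True produces, matching the ZipStore branch added above:

    import zarr

    # Open the zipped OME-Zarr read-only and grab level 0 of the multiscale pyramid.
    store = zarr.storage.ZipStore("sub-01_sample-01_acq-blaze_SPIM.ome.zarr.zip", mode="r")  # hypothetical path
    root = zarr.open(store=store, mode="r")
    print(root["0"].shape)  # e.g. (channel, z, y, x)
    store.close()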


