Bigstitcher-spark (#52)
Now using bigstitcher-spark exclusively, instead of Fiji, for stitching, global optimization, and fusion.
- also fixes an ome_zarr_to_nii bug when dealing with ome.zarr.zip (though the use of zarr.zip with GCS outputs is untested)
akhanf authored Sep 26, 2024
1 parent 22e218c commit 590fc4c
Showing 11 changed files with 238 additions and 81 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -63,7 +63,7 @@ or for snakemake<8.0, use:
snakemake -c all --use-singularity
```

Note: if you run the workflow on a system with large memory, you will need to set the heap size for the stitching and fusion rules. This can be done with e.g.: `--set-resources bigstitcher:mem_mb=60000 fuse_dataset:mem_mb=100000`
Note: if you run the workflow on a system with large memory, you will need to set the heap size for the stitching and fusion rules. This can be done with e.g.: `--set-resources bigstitcher_spark_stitching:mem_mb=60000 bigstitcher_spark_fusion:mem_mb=100000`

7. If you want to run the workflow using a batch job submission server, please see the executor plugins here: https://snakemake.github.io/snakemake-plugin-catalog/

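For reference, these overrides combine with the invocation above roughly as follows (a sketch using the snakemake<8.0 command form; the memory values are the examples from the note, and suitable numbers depend on your system):

```bash
# Example: raise the Java heap for the Spark-based stitching and fusion rules
# on a large-memory system (values are illustrative)
snakemake -c all --use-singularity \
    --set-resources bigstitcher_spark_stitching:mem_mb=60000 \
                    bigstitcher_spark_fusion:mem_mb=100000
```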
31 changes: 16 additions & 15 deletions config/config.yml
@@ -36,34 +36,35 @@ bigstitcher:
downsample_in_x: 4
downsample_in_y: 4
downsample_in_z: 1
method: "phase_corr"
methods:
phase_corr: "Phase Correlation"
optical_flow: "Lucas-Kanade"
method: "phase_corr" #unused
methods: #unused
phase_corr: "Phase Correlation"
optical_flow: "Lucas-Kanade"
filter_pairwise_shifts:
enabled: 1
enabled: 1 #unused
min_r: 0.7
max_shift_total: 50
global_optimization:
enabled: 1
strategy: two_round
strategies:
one_round: "One-Round"
one_round_iterative: "One-Round with iterative dropping of bad links"
two_round: "Two-Round using metadata to align unconnected Tiles"
two_round_iterative: "Two-Round using Metadata to align unconnected Tiles and iterative dropping of bad links"
enabled: 1
method: TWO_ROUND_ITERATIVE
methods: #unused, only for reference
ONE_ROUND_SIMPLE: "One-Round"
ONE_ROUND_ITERATIVE: "One-Round with iterative dropping of bad links"
TWO_ROUND_SIMPLE: "Two-Round using metadata to align unconnected Tiles"
TWO_ROUND_ITERATIVE: "Two-Round using Metadata to align unconnected Tiles and iterative dropping of bad links"


fuse_dataset:
downsampling: 1
block_size_x: 256 # for storage
block_size_y: 256
block_size_z: 1
block_size_z: 8
block_size_factor_x: 1 #e.g. 2 will use 2*block_size for computation
block_size_factor_y: 1
block_size_factor_z: 256
block_size_factor_z: 32

ome_zarr:
desc: stitchedflatcorr
desc: sparkstitchedflatcorr
max_downsampling_layers: 5 # e.g. 4 levels: { 0: orig, 1: ds2, 2: ds4, 3: ds8, 4: ds16}
rechunk_size: #z, y, x
- 1
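The `methods:` block is kept only for reference, so `method` must be set to one of its keys. To try a different solver behaviour without editing the tracked config, Snakemake's standard `--configfile` merging can be used; a hypothetical sketch (the override file name and chosen method are illustrative):

```bash
# Hypothetical: override the global-optimization method for one run;
# Snakemake merges --configfile values over config/config.yml
cat > override.yml <<'EOF'
bigstitcher:
  global_optimization:
    method: ONE_ROUND_ITERATIVE
EOF
snakemake -c all --use-singularity --configfile override.yml
```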
2 changes: 1 addition & 1 deletion workflow/Snakefile
@@ -33,7 +33,7 @@ rule all:
input:
get_all_targets(),
get_bids_toplevel_targets(),
# get_qc_targets(), #need to skip this if using prestitched
get_qc_targets(), #need to skip this if using prestitched
localrule: True


185 changes: 174 additions & 11 deletions workflow/rules/bigstitcher.smk
@@ -112,7 +112,7 @@ rule bigstitcher:
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="bigstitcherproc.sh",
suffix="bigstitcherfiji.sh",
)
),
dataset_xml=temp(
@@ -123,13 +123,13 @@
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="bigstitcher.xml",
suffix="bigstitcherfiji.xml",
)
),
benchmark:
bids(
root="benchmarks",
datatype="bigstitcherproc",
datatype="bigstitcherfiji",
subject="{subject}",
sample="{sample}",
acq="{acq}",
@@ -139,7 +139,7 @@
log:
bids(
root="logs",
datatype="bigstitcherproc",
datatype="bigstitcherfiji",
subject="{subject}",
sample="{sample}",
acq="{acq}",
@@ -161,7 +161,7 @@
" && {output.launcher} |& tee {log} && {params.rm_old_xml}"


rule fuse_dataset:
rule bigstitcher_spark_stitching:
input:
dataset_n5=rules.zarr_to_bdv.output.bdv_n5,
dataset_xml=rules.zarr_to_bdv.output.bdv_xml,
params:
downsampling="--downsampling={dsx},{dsy},{dsz}".format(
dsx=config["bigstitcher"]["calc_pairwise_shifts"]["downsample_in_x"],
dsy=config["bigstitcher"]["calc_pairwise_shifts"]["downsample_in_y"],
dsz=config["bigstitcher"]["calc_pairwise_shifts"]["downsample_in_z"],
),
min_r="--minR={min_r}".format(
min_r=config["bigstitcher"]["filter_pairwise_shifts"]["min_r"]
),
max_shift="--maxShiftTotal={max_shift}".format(
max_shift=config["bigstitcher"]["filter_pairwise_shifts"][
"max_shift_total"
]
),
mem_gb=lambda wildcards, resources: "{mem_gb}".format(
mem_gb=int(resources.mem_mb / 1000)
),
rm_old_xml=lambda wildcards, output: f"rm -f {output.dataset_xml}~?",
output:
dataset_xml=temp(
bids(
root=work,
subject="{subject}",
datatype="micr",
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="bigstitcherstitching.xml",
)
),
benchmark:
bids(
root="benchmarks",
datatype="bigstitcherstitching",
subject="{subject}",
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="benchmark.tsv",
)
log:
bids(
root="logs",
datatype="bigstitcherproc",
subject="{subject}",
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="log.txt",
),
container:
config["containers"]["spimprep"]
resources:
runtime=30,
mem_mb=40000,
threads: config["cores_per_rule"]
group:
"preproc"
shell:
"cp {input.dataset_xml} {output.dataset_xml} && "
"stitching {params.mem_gb} {threads} -x {output.dataset_xml} "
" {params.min_r} {params.downsampling} {params.max_shift} && "
"{params.rm_old_xml}"

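Each Spark step works on a copy of the upstream XML, updates it in place, and then deletes the numbered backups (`*.xml~1`, `*.xml~2`, and so on) that BigStitcher writes when re-saving the file. Filled in with the defaults above, the stitching step expands roughly as follows (a sketch: paths and thread count are placeholders, and 40 derives from `mem_mb=40000`):

```bash
# Sketch of the stitching rule's shell with default config values
cp <bdv_dataset>.xml <prefix>_bigstitcherstitching.xml   # work on a copy
stitching 40 16 -x <prefix>_bigstitcherstitching.xml \
    --minR=0.7 --downsampling=4,4,1 --maxShiftTotal=50
rm -f <prefix>_bigstitcherstitching.xml~?                # remove XML backups
```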

rule bigstitcher_spark_solver:
input:
dataset_n5=rules.zarr_to_bdv.output.bdv_n5,
dataset_xml=rules.bigstitcher_spark_stitching.output.dataset_xml,
params:
downsampling="--downsampling={dsx},{dsy},{dsz}".format(
dsx=config["bigstitcher"]["calc_pairwise_shifts"]["downsample_in_x"],
dsy=config["bigstitcher"]["calc_pairwise_shifts"]["downsample_in_y"],
dsz=config["bigstitcher"]["calc_pairwise_shifts"]["downsample_in_z"],
),
method="--method={method}".format(
method=config["bigstitcher"]["global_optimization"]["method"]
),
mem_gb=lambda wildcards, resources: "{mem_gb}".format(
mem_gb=int(resources.mem_mb / 1000)
),
rm_old_xml=lambda wildcards, output: f"rm -f {output.dataset_xml}~?",
output:
dataset_xml=temp(
bids(
root=work,
subject="{subject}",
datatype="micr",
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="bigstitchersolver.xml",
)
),
benchmark:
bids(
root="benchmarks",
datatype="bigstitchersolver",
subject="{subject}",
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="benchmark.tsv",
)
log:
bids(
root="logs",
datatype="bigstitcherproc",
subject="{subject}",
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="log.txt",
),
container:
config["containers"]["spimprep"]
resources:
runtime=30,
mem_mb=40000,
threads: config["cores_per_rule"]
group:
"preproc"
shell:
"cp {input.dataset_xml} {output.dataset_xml} && "
"solver {params.mem_gb} {threads} -x {output.dataset_xml} "
" -s STITCHING --lambda 0.1 "
" {params.method} && "
"{params.rm_old_xml}"
#lambda 0.1 is default (can expose this if needed)
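The solver follows the same copy, run, clean pattern: it reads the pairwise links computed by the stitching step (`-s STITCHING`) and runs the configured global optimization with the default regularization. A rough expansion with the defaults (a sketch; placeholders as before, 40 derives from `mem_mb=40000`):

```bash
# Sketch of the solver rule's shell with default config values
cp <prefix>_bigstitcherstitching.xml <prefix>_bigstitchersolver.xml
solver 40 16 -x <prefix>_bigstitchersolver.xml \
    -s STITCHING --lambda 0.1 --method=TWO_ROUND_ITERATIVE
rm -f <prefix>_bigstitchersolver.xml~?
```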


rule bigstitcher_fusion:
input:
dataset_n5=bids(
root=work,
@@ -179,7 +179,7 @@
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="bigstitcher.xml",
suffix="bigstitcher{}.xml".format(
"solver"
if config["bigstitcher"]["global_optimization"]["enabled"]
else "stitching"
),
),
ijm=Path(workflow.basedir) / "macros" / "FuseImageMacroZarr.ijm",
params:
@@ -238,7 +238,7 @@
config["containers"]["spimprep"]
resources:
runtime=30,
mem_mb=20000,
mem_mb=40000,
threads: config["cores_per_rule"]
group:
"preproc"
@@ -248,7 +248,7 @@
" && {output.launcher} |& tee {log}"


rule fuse_dataset_spark:
rule bigstitcher_spark_fusion:
input:
dataset_n5=bids(
root=work,
@@ -266,9 +266,9 @@ rule fuse_dataset_spark:
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="bigstitcher.xml",
suffix="bigstitcher{}.xml".format(
"solver"
if config["bigstitcher"]["global_optimization"]["enabled"]
else "stitching"
),
),
ijm=Path(workflow.basedir) / "macros" / "FuseImageMacroZarr.ijm",
params:
channel=lambda wildcards: "--channelId={channel}".format(
channel=get_stains(wildcards).index(wildcards.stain)
),
block_size="--blockSize={bsx},{bsy},{bsz}".format(
bsx=config["bigstitcher"]["fuse_dataset"]["block_size_x"],
bsy=config["bigstitcher"]["fuse_dataset"]["block_size_y"],
bsz=config["bigstitcher"]["fuse_dataset"]["block_size_z"],
),
block_size_factor="--blockScale={bsfx},{bsfy},{bsfz}".format(
bsfx=config["bigstitcher"]["fuse_dataset"]["block_size_factor_x"],
bsfy=config["bigstitcher"]["fuse_dataset"]["block_size_factor_y"],
bsfz=config["bigstitcher"]["fuse_dataset"]["block_size_factor_z"],
),
mem_gb=lambda wildcards, resources: "{mem_gb}".format(
mem_gb=int(resources.mem_mb / 1000)
),
output:
zarr=temp(
directory(
@@ -310,9 +310,12 @@ rule fuse_dataset_spark:
config["containers"]["spimprep"]
resources:
runtime=30,
mem_mb=20000,
mem_mb=30000,
threads: config["cores_per_rule"]
group:
"preproc"
shell:
"affine-fusion ..."
"affine-fusion {params.mem_gb} {threads} --preserveAnisotropy -x {input.dataset_xml} "
" -o {output.zarr} -d /fused/s0 -s ZARR "
" --UINT16 --minIntensity 0 --maxIntensity 65535 "
"{params.block_size} {params.block_size_factor} {params.channel}"
35 changes: 2 additions & 33 deletions workflow/rules/common.smk
@@ -312,7 +312,7 @@ def get_output_ome_zarr_uri():
datatype="micr",
sample="{sample}",
acq="{acq}",
suffix="SPIM.ome.zarr",
suffix="SPIM.{ext}".format(ext=get_extension_ome_zarr()),
)
else:
return "local://" + _bids(
Expand All @@ -321,7 +321,7 @@ def get_output_ome_zarr_uri():
datatype="micr",
sample="{sample}",
acq="{acq}",
suffix="SPIM.ome.zarr",
suffix="SPIM.{ext}".format(ext=get_extension_ome_zarr()),
)


@@ -368,37 +368,6 @@ def get_output_ome_zarr(acq_type):
}


def get_input_ome_zarr_to_nii():
if is_remote(config["root"]):
return bids(
root=root,
subject="{subject}",
datatype="micr",
sample="{sample}",
acq="{acq}",
suffix="SPIM.{extension}".format(extension=get_extension_ome_zarr()),
)
else:
if config["write_ome_zarr_direct"]:
return bids(
root=root,
subject="{subject}",
datatype="micr",
sample="{sample}",
acq="{acq}",
suffix="SPIM.{extension}".format(extension=get_extension_ome_zarr()),
)
else:
return bids(
root=work,
subject="{subject}",
datatype="micr",
sample="{sample}",
acq="{acq}",
suffix=f"SPIM.ome.zarr",
)


def get_storage_creds():
"""for rules that deal with remote storage directly"""
protocol = Path(config["root"]).protocol
9 changes: 8 additions & 1 deletion workflow/rules/ome_zarr.smk
@@ -118,7 +118,14 @@ rule ome_zarr_to_zipstore:
rule ome_zarr_to_nii:
input:
**get_storage_creds(),
zarr=get_input_ome_zarr_to_nii(),
zarr=bids(
root=root,
subject="{subject}",
datatype="micr",
sample="{sample}",
acq="{acq}",
suffix="SPIM.{ext}".format(ext=get_extension_ome_zarr()),
),
params:
channel_index=lambda wildcards: get_stains(wildcards).index(wildcards.stain),
uri=get_output_ome_zarr_uri(),
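This is the `ome_zarr_to_nii` fix mentioned in the commit message: the rule now resolves its input from the final output location using the actual extension, instead of assuming a plain `.ome.zarr` directory, which is what broke for `ome.zarr.zip`. Roughly (a sketch; entity placeholders, and the helper's return values are assumed from context):

```bash
# Illustrative input paths now resolved for the rule, depending on the
# OME-Zarr output mode (extensions assumed from get_extension_ome_zarr()):
#   directory store: sub-01/micr/sub-01_<entities>_SPIM.ome.zarr
#   zip store:       sub-01/micr/sub-01_<entities>_SPIM.ome.zarr.zip
```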