Bigstitcher-spark #52

Merged · 9 commits · Sep 26, 2024
2 changes: 1 addition & 1 deletion README.md
@@ -63,7 +63,7 @@ or for snakemake<8.0, use:
snakemake -c all --use-singularity
```

Note: if you run the workflow on a system with large memory, you will need to set the heap size for the stitching and fusion rules. This can be done with e.g.: `--set-resources bigstitcher:mem_mb=60000 fuse_dataset:mem_mb=100000`
Note: if you run the workflow on a system with large memory, you will need to set the heap size for the stitching and fusion rules. This can be done with e.g.: `--set-resources bigstitcher_spark_stitching:mem_mb=60000 bigstitcher_spark_fusion:mem_mb=100000`
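
For example, a complete invocation with both overrides might look like the following sketch (rule names as renamed in this PR; the memory values are illustrative):

```
snakemake -c all --use-singularity \
  --set-resources bigstitcher_spark_stitching:mem_mb=60000 \
                  bigstitcher_spark_fusion:mem_mb=100000
```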

7. If you want to run the workflow using a batch job submission server, please see the executor plugins here: https://snakemake.github.io/snakemake-plugin-catalog/

31 changes: 16 additions & 15 deletions config/config.yml
@@ -36,34 +36,35 @@ bigstitcher:
downsample_in_x: 4
downsample_in_y: 4
downsample_in_z: 1
method: "phase_corr"
methods:
phase_corr: "Phase Correlation"
optical_flow: "Lucas-Kanade"
method: "phase_corr" #unused
methods: #unused
phase_corr: "Phase Correlation"
optical_flow: "Lucas-Kanade"
filter_pairwise_shifts:
enabled: 1
enabled: 1 #unused
min_r: 0.7
max_shift_total: 50
global_optimization:
enabled: 1
strategy: two_round
strategies:
one_round: "One-Round"
one_round_iterative: "One-Round with iterative dropping of bad links"
two_round: "Two-Round using metadata to align unconnected Tiles"
two_round_iterative: "Two-Round using Metadata to align unconnected Tiles and iterative dropping of bad links"
enabled: 1
method: TWO_ROUND_ITERATIVE
methods: #unused, only for reference
ONE_ROUND_SIMPLE: "One-Round"
ONE_ROUND_ITERATIVE: "One-Round with iterative dropping of bad links"
TWO_ROUND_SIMPLE: "Two-Round using metadata to align unconnected Tiles"
TWO_ROUND_ITERATIVE: "Two-Round using Metadata to align unconnected Tiles and iterative dropping of bad links"


fuse_dataset:
downsampling: 1
block_size_x: 256 # for storage
block_size_y: 256
block_size_z: 1
block_size_z: 8
block_size_factor_x: 1 #e.g. 2 will use 2*block_size for computation
block_size_factor_y: 1
block_size_factor_z: 256
block_size_factor_z: 32
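
As a quick check of how the `fuse_dataset` values above interact: the computation block is the storage block scaled per axis by `block_size_factor_*`, so these defaults give a 256×256×256 computation block. A minimal sketch, assuming per-axis multiplication as the inline comment suggests:

```
# storage block (x,y,z) = 256,256,8; block-size factor = 1,1,32
echo $((256 * 1)) $((256 * 1)) $((8 * 32))   # prints: 256 256 256
```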

ome_zarr:
desc: stitchedflatcorr
desc: sparkstitchedflatcorr
max_downsampling_layers: 5 # e.g. 4 levels: { 0: orig, 1: ds2, 2: ds4, 3: ds8, 4: ds16}
rechunk_size: #z, y, x
- 1
2 changes: 1 addition & 1 deletion workflow/Snakefile
@@ -33,7 +33,7 @@ rule all:
input:
get_all_targets(),
get_bids_toplevel_targets(),
# get_qc_targets(), #need to skip this if using prestitched
get_qc_targets(), #need to skip this if using prestitched
localrule: True


185 changes: 174 additions & 11 deletions workflow/rules/bigstitcher.smk
@@ -112,7 +112,7 @@ rule bigstitcher:
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="bigstitcherproc.sh",
suffix="bigstitcherfiji.sh",
)
),
dataset_xml=temp(
@@ -123,13 +123,13 @@
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="bigstitcher.xml",
suffix="bigstitcherfiji.xml",
)
),
benchmark:
bids(
root="benchmarks",
datatype="bigstitcherproc",
datatype="bigstitcherfiji",
subject="{subject}",
sample="{sample}",
acq="{acq}",
@@ -139,7 +139,7 @@
log:
bids(
root="logs",
datatype="bigstitcherproc",
datatype="bigstitcherfiji",
subject="{subject}",
sample="{sample}",
acq="{acq}",
@@ -161,7 +161,142 @@
" && {output.launcher} |& tee {log} && {params.rm_old_xml}"


rule fuse_dataset:
rule bigstitcher_spark_stitching:
input:
dataset_n5=rules.zarr_to_bdv.output.bdv_n5,
dataset_xml=rules.zarr_to_bdv.output.bdv_xml,
params:
downsampling="--downsampling={dsx},{dsy},{dsz}".format(
dsx=config["bigstitcher"]["calc_pairwise_shifts"]["downsample_in_x"],
dsy=config["bigstitcher"]["calc_pairwise_shifts"]["downsample_in_y"],
dsz=config["bigstitcher"]["calc_pairwise_shifts"]["downsample_in_z"],
),
min_r="--minR={min_r}".format(
min_r=config["bigstitcher"]["filter_pairwise_shifts"]["min_r"]
),
max_shift="--maxShiftTotal={max_shift}".format(
max_shift=config["bigstitcher"]["filter_pairwise_shifts"][
"max_shift_total"
]
),
mem_gb=lambda wildcards, resources: "{mem_gb}".format(
mem_gb=int(resources.mem_mb / 1000)
),
rm_old_xml=lambda wildcards, output: f"rm -f {output.dataset_xml}~?",
output:
dataset_xml=temp(
bids(
root=work,
subject="{subject}",
datatype="micr",
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="bigstitcherstitching.xml",
)
),
benchmark:
bids(
root="benchmarks",
datatype="bigstitcherstitching",
subject="{subject}",
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="benchmark.tsv",
)
log:
bids(
root="logs",
datatype="bigstitcherproc",
subject="{subject}",
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="log.txt",
),
container:
config["containers"]["spimprep"]
resources:
runtime=30,
mem_mb=40000,
threads: config["cores_per_rule"]
group:
"preproc"
shell:
"cp {input.dataset_xml} {output.dataset_xml} && "
"stitching {params.mem_gb} {threads} -x {output.dataset_xml} "
" {params.min_r} {params.downsampling} {params.max_shift} && "
"{params.rm_old_xml}"


rule bigstitcher_spark_solver:
input:
dataset_n5=rules.zarr_to_bdv.output.bdv_n5,
dataset_xml=rules.bigstitcher_spark_stitching.output.dataset_xml,
params:
downsampling="--downsampling={dsx},{dsy},{dsz}".format(
dsx=config["bigstitcher"]["calc_pairwise_shifts"]["downsample_in_x"],
dsy=config["bigstitcher"]["calc_pairwise_shifts"]["downsample_in_y"],
dsz=config["bigstitcher"]["calc_pairwise_shifts"]["downsample_in_z"],
),
method="--method={method}".format(
method=config["bigstitcher"]["global_optimization"]["method"]
),
mem_gb=lambda wildcards, resources: "{mem_gb}".format(
mem_gb=int(resources.mem_mb / 1000)
),
rm_old_xml=lambda wildcards, output: f"rm -f {output.dataset_xml}~?",
output:
dataset_xml=temp(
bids(
root=work,
subject="{subject}",
datatype="micr",
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="bigstitchersolver.xml",
)
),
benchmark:
bids(
root="benchmarks",
datatype="bigstitchersolver",
subject="{subject}",
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="benchmark.tsv",
)
log:
bids(
root="logs",
datatype="bigstitcherproc",
subject="{subject}",
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="log.txt",
),
container:
config["containers"]["spimprep"]
resources:
runtime=30,
mem_mb=40000,
threads: config["cores_per_rule"]
group:
"preproc"
shell:
"cp {input.dataset_xml} {output.dataset_xml} && "
"solver {params.mem_gb} {threads} -x {output.dataset_xml} "
" -s STITCHING --lambda 0.1 "
" {params.method} && "
"{params.rm_old_xml}"
#lambda 0.1 is default (can expose this if needed)
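
Similarly, assuming `method: TWO_ROUND_ITERATIVE` in the config's `global_optimization` block, the solver invocation would render to roughly this sketch (hypothetical file names; 40 GB and 16 threads assumed):

```
cp stitching.xml solver.xml                          # hypothetical names
solver 40 16 -x solver.xml \
  -s STITCHING --lambda 0.1 \
  --method=TWO_ROUND_ITERATIVE
rm -f solver.xml~?
```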


rule bigstitcher_fusion:
input:
dataset_n5=bids(
root=work,
@@ -179,7 +314,11 @@ rule fuse_dataset:
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="bigstitcher.xml",
suffix="bigstitcher{}.xml".format(
"solver"
if config["bigstitcher"]["global_optimization"]["enabled"]
else "stitching"
),
),
ijm=Path(workflow.basedir) / "macros" / "FuseImageMacroZarr.ijm",
params:
@@ -238,7 +377,7 @@ rule fuse_dataset:
config["containers"]["spimprep"]
resources:
runtime=30,
mem_mb=20000,
mem_mb=40000,
threads: config["cores_per_rule"]
group:
"preproc"
@@ -248,7 +387,7 @@
" && {output.launcher} |& tee {log}"


rule fuse_dataset_spark:
rule bigstitcher_spark_fusion:
input:
dataset_n5=bids(
root=work,
@@ -266,9 +405,30 @@ rule fuse_dataset_spark:
sample="{sample}",
acq="{acq}",
desc="{desc}",
suffix="bigstitcher.xml",
suffix="bigstitcher{}.xml".format(
"solver"
if config["bigstitcher"]["global_optimization"]["enabled"]
else "stitching"
),
),
ijm=Path(workflow.basedir) / "macros" / "FuseImageMacroZarr.ijm",
params:
channel=lambda wildcards: "--channelId={channel}".format(
channel=get_stains(wildcards).index(wildcards.stain)
),
block_size="--blockSize={bsx},{bsy},{bsz}".format(
bsx=config["bigstitcher"]["fuse_dataset"]["block_size_x"],
bsy=config["bigstitcher"]["fuse_dataset"]["block_size_y"],
bsz=config["bigstitcher"]["fuse_dataset"]["block_size_z"],
),
block_size_factor="--blockScale={bsfx},{bsfy},{bsfz}".format(
bsfx=config["bigstitcher"]["fuse_dataset"]["block_size_factor_x"],
bsfy=config["bigstitcher"]["fuse_dataset"]["block_size_factor_y"],
bsfz=config["bigstitcher"]["fuse_dataset"]["block_size_factor_z"],
),
mem_gb=lambda wildcards, resources: "{mem_gb}".format(
mem_gb=int(resources.mem_mb / 1000)
),
output:
zarr=temp(
directory(
@@ -310,9 +470,12 @@ rule fuse_dataset_spark:
config["containers"]["spimprep"]
resources:
runtime=30,
mem_mb=20000,
mem_mb=30000,
threads: config["cores_per_rule"]
group:
"preproc"
shell:
"affine-fusion ..."
"affine-fusion {params.mem_gb} {threads} --preserveAnisotropy -x {input.dataset_xml} "
" -o {output.zarr} -d /fused/s0 -s ZARR "
" --UINT16 --minIntensity 0 --maxIntensity 65535 "
"{params.block_size} {params.block_size_factor} {params.channel}"
35 changes: 2 additions & 33 deletions workflow/rules/common.smk
@@ -312,7 +312,7 @@ def get_output_ome_zarr_uri():
datatype="micr",
sample="{sample}",
acq="{acq}",
suffix="SPIM.ome.zarr",
suffix="SPIM.{ext}".format(ext=get_extension_ome_zarr()),
)
else:
return "local://" + _bids(
Expand All @@ -321,7 +321,7 @@ def get_output_ome_zarr_uri():
datatype="micr",
sample="{sample}",
acq="{acq}",
suffix="SPIM.ome.zarr",
suffix="SPIM.{ext}".format(ext=get_extension_ome_zarr()),
)


@@ -368,37 +368,6 @@ def get_output_ome_zarr(acq_type):
}


def get_input_ome_zarr_to_nii():
if is_remote(config["root"]):
return bids(
root=root,
subject="{subject}",
datatype="micr",
sample="{sample}",
acq="{acq}",
suffix="SPIM.{extension}".format(extension=get_extension_ome_zarr()),
)
else:
if config["write_ome_zarr_direct"]:
return bids(
root=root,
subject="{subject}",
datatype="micr",
sample="{sample}",
acq="{acq}",
suffix="SPIM.{extension}".format(extension=get_extension_ome_zarr()),
)
else:
return bids(
root=work,
subject="{subject}",
datatype="micr",
sample="{sample}",
acq="{acq}",
suffix=f"SPIM.ome.zarr",
)


def get_storage_creds():
"""for rules that deal with remote storage directly"""
protocol = Path(config["root"]).protocol
9 changes: 8 additions & 1 deletion workflow/rules/ome_zarr.smk
@@ -118,7 +118,14 @@ rule ome_zarr_to_zipstore:
rule ome_zarr_to_nii:
input:
**get_storage_creds(),
zarr=get_input_ome_zarr_to_nii(),
zarr=bids(
root=root,
subject="{subject}",
datatype="micr",
sample="{sample}",
acq="{acq}",
suffix="SPIM.{ext}".format(ext=get_extension_ome_zarr()),
),
params:
channel_index=lambda wildcards: get_stains(wildcards).index(wildcards.stain),
uri=get_output_ome_zarr_uri(),