From d1fa55ccbe384078b4e79e418f2e15be1946dcfc Mon Sep 17 00:00:00 2001
From: Ali Khan
Date: Wed, 25 Sep 2024 12:18:49 -0400
Subject: [PATCH] WIP spark

---
 config/config.yml                           |  31 ++---
 workflow/rules/bigstitcher.smk              | 127 +++++++++++++++++++-
 workflow/rules/common.smk                   |  38 ++----
 workflow/rules/ome_zarr.smk                 |   8 +-
 workflow/rules/qc.smk                       |   4 +-
 workflow/scripts/generate_volume_qc.py      |   2 +-
 workflow/scripts/generate_whole_slice_qc.py |   2 +-
 workflow/scripts/ome_zarr_to_nii.py         |  14 ++-
 8 files changed, 169 insertions(+), 57 deletions(-)

diff --git a/config/config.yml b/config/config.yml
index 3d6e0a7..6a5c684 100644
--- a/config/config.yml
+++ b/config/config.yml
@@ -6,7 +6,7 @@ work: 'work'
 
 remote_creds: '~/.config/gcloud/application_default_credentials.json' #this is needed so we can pass creds to container
 
-write_ome_zarr_direct: True #use this to skip writing the final zarr output to work first and copying afterwards -- useful when work is not a fast local disk
+write_ome_zarr_direct: False #use this to skip writing the final zarr output to work first and copying afterwards -- useful when work is not a fast local disk
 
 cores_per_rule: 32
 
@@ -42,28 +42,29 @@ bigstitcher:
       optical_flow: "Lucas-Kanade"
   filter_pairwise_shifts:
     enabled: 1
-    min_r: 0.7
+    min_r: 0.2
+    max_shift_total: 50
   global_optimization:
-    enabled: 1
-    strategy: two_round
-    strategies:
-      one_round: "One-Round"
-      one_round_iterative: "One-Round with iterative dropping of bad links"
-      two_round: "Two-Round using metadata to align unconnected Tiles"
-      two_round_iterative: "Two-Round using Metadata to align unconnected Tiles and iterative dropping of bad links"
+    enabled: 1
+    method: ONE_ROUND_SIMPLE
+    methods:
+      ONE_ROUND_SIMPLE: "One-Round"
+      ONE_ROUND_ITERATIVE: "One-Round with iterative dropping of bad links"
+      TWO_ROUND_SIMPLE: "Two-Round using metadata to align unconnected Tiles"
+      TWO_ROUND_ITERATIVE: "Two-Round using Metadata to align unconnected Tiles and iterative dropping of bad links"
 
   fuse_dataset:
     downsampling: 1
     block_size_x: 256 # for storage
     block_size_y: 256
-    block_size_z: 1
+    block_size_z: 8
     block_size_factor_x: 1 #e.g. 2 will use 2*block_size for computation
     block_size_factor_y: 1
-    block_size_factor_z: 256
+    block_size_factor_z: 32
 
 ome_zarr:
-  desc: stitchedflatcorr
+  desc: sparkstitchedflatcorr
   max_downsampling_layers: 5 # e.g. 4 levels: { 0: orig, 1: ds2, 2: ds4, 3: ds8, 4: ds16}
   rechunk_size: #z, y, x
     - 1
@@ -99,7 +100,7 @@ ome_zarr:
     id: 0
     name: spim
   version: "0.4"
-  use_zipstore: False #if True, produce SPIM.ome.zarr.zip instead of SPIM.ome.zarr
+  use_zipstore: True #if True, produce SPIM.ome.zarr.zip instead of SPIM.ome.zarr
 
 nifti:
   levels: #cannot be higher than max_downsampling_layers in ome_zarr
@@ -154,5 +155,7 @@ report:
 
 containers:
-  spimprep: 'docker://khanlab/spimprep-deps:main'
+  #spimprep: 'docker://khanlab/spimprep-deps:main'
+  spimprep: '/local/scratch/spimprep-deps_spark.sif'
+  # updated.sif'
 
diff --git a/workflow/rules/bigstitcher.smk b/workflow/rules/bigstitcher.smk
index 5a47aa2..ebd2d49 100644
--- a/workflow/rules/bigstitcher.smk
+++ b/workflow/rules/bigstitcher.smk
@@ -93,7 +93,7 @@ rule zarr_to_bdv:
     script:
         "../scripts/zarr_to_n5_bdv.py"
 
-
+"""
 rule bigstitcher:
     input:
         dataset_n5=rules.zarr_to_bdv.output.bdv_n5,
@@ -159,6 +159,121 @@ rule bigstitcher:
         " {params.fiji_launcher_cmd} && "
         " echo ' -macro {input.ijm} \"{params.macro_args}\"' >> {output.launcher} "
         " && {output.launcher} |& tee {log} && {params.rm_old_xml}"
+"""
+
+rule bigstitcher_spark_stitching:
+    input:
+        dataset_n5=rules.zarr_to_bdv.output.bdv_n5,
+        dataset_xml=rules.zarr_to_bdv.output.bdv_xml,
+    params:
+        downsampling='--downsampling={dsx},{dsy},{dsz}'.format(dsx=config['bigstitcher']['calc_pairwise_shifts']['downsample_in_x'],
+            dsy=config['bigstitcher']['calc_pairwise_shifts']['downsample_in_y'],
+            dsz=config['bigstitcher']['calc_pairwise_shifts']['downsample_in_z']),
+        min_r='--minR={min_r}'.format(min_r=config['bigstitcher']['filter_pairwise_shifts']['min_r']),
+        max_shift='--maxShiftTotal={max_shift}'.format(max_shift=config['bigstitcher']['filter_pairwise_shifts']['max_shift_total']),
+        mem_gb=lambda wildcards, resources: '{mem_gb}'.format(mem_gb=int(resources.mem_mb/1000))
+    output:
+        dataset_xml=temp(
+            bids(
+                root=work,
+                subject="{subject}",
+                datatype="micr",
+                sample="{sample}",
+                acq="{acq}",
+                desc="{desc}",
+                suffix="bigstitcherstitching.xml",
+            )
+        ),
+    benchmark:
+        bids(
+            root="benchmarks",
+            datatype="bigstitcherstitching",
+            subject="{subject}",
+            sample="{sample}",
+            acq="{acq}",
+            desc="{desc}",
+            suffix="benchmark.tsv",
+        )
+    log:
+        bids(
+            root="logs",
+            datatype="bigstitcherstitching",
+            subject="{subject}",
+            sample="{sample}",
+            acq="{acq}",
+            desc="{desc}",
+            suffix="log.txt",
+        ),
+    container:
+        config["containers"]["spimprep"]
+    resources:
+        runtime=30,
+        mem_mb=40000,
+    threads: config["cores_per_rule"]
+    group:
+        "preproc"
+    shell:
+        "cp {input.dataset_xml} {output.dataset_xml} && "
+        "stitching {params.mem_gb} {threads} -x {output.dataset_xml} "
+        " {params.min_r} {params.max_shift} {params.downsampling} "
+
+
+rule bigstitcher_spark_solver:
+    input:
+        dataset_n5=rules.zarr_to_bdv.output.bdv_n5,
+        dataset_xml=rules.bigstitcher_spark_stitching.output.dataset_xml,
+    params:
+        downsampling='--downsampling={dsx},{dsy},{dsz}'.format(dsx=config['bigstitcher']['calc_pairwise_shifts']['downsample_in_x'],
+            dsy=config['bigstitcher']['calc_pairwise_shifts']['downsample_in_y'],
+            dsz=config['bigstitcher']['calc_pairwise_shifts']['downsample_in_z']),
+        method='--method={method}'.format(method=config['bigstitcher']['global_optimization']['method']),
+        mem_gb=lambda wildcards, resources: '{mem_gb}'.format(mem_gb=int(resources.mem_mb/1000))
+    output:
+        dataset_xml=temp(
+            bids(
+                root=work,
+                subject="{subject}",
+                datatype="micr",
+                sample="{sample}",
+                acq="{acq}",
+                desc="{desc}",
+                suffix="bigstitchersolver.xml",
+            )
+        ),
+    benchmark:
+        bids(
+            root="benchmarks",
+            datatype="bigstitchersolver",
+            subject="{subject}",
+            sample="{sample}",
+            acq="{acq}",
+            desc="{desc}",
+            suffix="benchmark.tsv",
+        )
+    log:
+        bids(
+            root="logs",
+            datatype="bigstitchersolver",
+            subject="{subject}",
+            sample="{sample}",
+            acq="{acq}",
+            desc="{desc}",
+            suffix="log.txt",
+        ),
+    container:
+        config["containers"]["spimprep"]
+    resources:
+        runtime=30,
+        mem_mb=40000,
+    threads: config["cores_per_rule"]
+    group:
+        "preproc"
+    shell:
+        "cp {input.dataset_xml} {output.dataset_xml} && "
+        "solver {params.mem_gb} {threads} -x {output.dataset_xml} "
+        " -s STITCHING --lambda 0.1 " #lambda 0.1 is default (can expose this if needed)
+        " {params.method} "
+
 
 
 rule fuse_dataset:
@@ -179,7 +294,7 @@ rule fuse_dataset:
             sample="{sample}",
             acq="{acq}",
             desc="{desc}",
-            suffix="bigstitcher.xml",
+            suffix="bigstitcher{}.xml".format('solver' if config['bigstitcher']['global_optimization']['enabled'] else 'stitching'),
         ),
         ijm=Path(workflow.basedir) / "macros" / "FuseImageMacroZarr.ijm",
     params:
@@ -238,7 +353,7 @@ rule fuse_dataset:
         config["containers"]["spimprep"]
     resources:
         runtime=30,
-        mem_mb=20000,
+        mem_mb=40000,
     threads: config["cores_per_rule"]
     group:
         "preproc"
@@ -266,7 +381,7 @@ rule fuse_dataset_spark:
             sample="{sample}",
             acq="{acq}",
             desc="{desc}",
-            suffix="bigstitcher.xml",
+            suffix="bigstitcher{}.xml".format('solver' if config['bigstitcher']['global_optimization']['enabled'] else 'stitching'),
         ),
         ijm=Path(workflow.basedir) / "macros" / "FuseImageMacroZarr.ijm",
     params:
@@ -283,7 +398,7 @@ rule fuse_dataset_spark:
             bsfy=config["bigstitcher"]["fuse_dataset"]["block_size_factor_y"],
             bsfz=config["bigstitcher"]["fuse_dataset"]["block_size_factor_z"],
         ),
-        mem_gb=lambda wikdcards, resources: '{mem_gb}'.format(mem_gb=int(resources.mem_mb/1000))
+        mem_gb=lambda wildcards, resources: '{mem_gb}'.format(mem_gb=int(resources.mem_mb/1000))
     output:
         zarr=temp(
             directory(
@@ -325,7 +440,7 @@ rule fuse_dataset_spark:
         config["containers"]["spimprep"]
     resources:
         runtime=30,
-        mem_mb=20000,
+        mem_mb=40000,
     threads: config["cores_per_rule"]
     group:
         "preproc"
diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk
index e6b7ad0..0248dd3 100644
--- a/workflow/rules/common.smk
+++ b/workflow/rules/common.smk
@@ -312,7 +312,7 @@ def get_output_ome_zarr_uri():
             datatype="micr",
             sample="{sample}",
             acq="{acq}",
-            suffix="SPIM.ome.zarr",
+            suffix="SPIM.{ext}".format(ext=get_extension_ome_zarr()),
         )
     else:
         return "local://" + _bids(
@@ -321,7 +321,7 @@ def get_output_ome_zarr_uri():
             datatype="micr",
             sample="{sample}",
             acq="{acq}",
-            suffix="SPIM.ome.zarr",
+            suffix="SPIM.{ext}".format(ext=get_extension_ome_zarr()),
         )
 
 
@@ -368,35 +368,11 @@ def get_output_ome_zarr(acq_type):
     }
 
 
-def get_input_ome_zarr_to_nii():
-    if is_remote(config["root"]):
-        return bids(
-            root=root,
-            subject="{subject}",
-            datatype="micr",
-            sample="{sample}",
-            acq="{acq}",
-            suffix="SPIM.{extension}".format(extension=get_extension_ome_zarr()),
-        )
-    else:
-        if config["write_ome_zarr_direct"]:
-            return bids(
-                root=root,
-                subject="{subject}",
-                datatype="micr",
-                sample="{sample}",
-                acq="{acq}",
-                suffix="SPIM.{extension}".format(extension=get_extension_ome_zarr()),
-            )
-        else:
-            return bids(
-                root=work,
-                subject="{subject}",
-                datatype="micr",
-                sample="{sample}",
-                acq="{acq}",
-                suffix=f"SPIM.ome.zarr",
-            )
+def get_output_ome_zarr_as_input(wildcards):
+    if 'blaze' in wildcards.acq:
+        return get_output_ome_zarr('blaze')
+    elif 'prestitched' in wildcards.acq:
+        return get_output_ome_zarr('prestitched')
 
 
 def get_storage_creds():
diff --git a/workflow/rules/ome_zarr.smk b/workflow/rules/ome_zarr.smk
index 105fb8d..77a58b3 100644
--- a/workflow/rules/ome_zarr.smk
+++ b/workflow/rules/ome_zarr.smk
@@ -117,8 +117,14 @@ rule ome_zarr_to_zipstore:
 
 rule ome_zarr_to_nii:
     input:
+        zarr=bids(
+            root=root,
+            subject="{subject}",
+            datatype="micr",
+            sample="{sample}",
+            acq="{acq}",
+            suffix="SPIM.{ext}".format(ext=get_extension_ome_zarr())),
         **get_storage_creds(),
-        zarr=get_input_ome_zarr_to_nii(),
     params:
         channel_index=lambda wildcards: get_stains(wildcards).index(wildcards.stain),
         uri=get_output_ome_zarr_uri(),
diff --git a/workflow/rules/qc.smk b/workflow/rules/qc.smk
index cb44298..a79289f 100644
--- a/workflow/rules/qc.smk
+++ b/workflow/rules/qc.smk
@@ -73,8 +73,8 @@ rule generate_flatfield_qc:
 rule generate_whole_slice_qc:
     "Generates an html file to view whole slices from preprocessed images"
     input:
+        unpack(get_output_ome_zarr_as_input),
         **get_storage_creds(),
-        ome=get_input_ome_zarr_to_nii(),
         ws_html=config["report"]["resources"]["ws_html"],
     params:
         ws_s_start=config["report"]["whole_slice_viewer"]["slice_start"],
@@ -112,8 +112,8 @@ rule generate_whole_slice_qc:
 rule generate_volume_qc:
     "Generates an html file to view the volume rendered image"
     input:
+        unpack(get_output_ome_zarr_as_input),
         **get_storage_creds(),
-        ome=get_input_ome_zarr_to_nii(),
         vol_viewer_dir=config["report"]["resources"]["vol_viewer_dir"],
     params:
         uri=get_output_ome_zarr_uri(),
diff --git a/workflow/scripts/generate_volume_qc.py b/workflow/scripts/generate_volume_qc.py
index 1fb1398..d5c48d1 100644
--- a/workflow/scripts/generate_volume_qc.py
+++ b/workflow/scripts/generate_volume_qc.py
@@ -15,7 +15,7 @@
 html_dest = snakemake.output.html
 
 # inputted ome-zarr path
-ome_data = snakemake.input.ome
+ome_data = snakemake.input.zarr
 
 # move volume renderer into the subjects directory
 copy_tree(snakemake.input.vol_viewer_dir, resource_dir)
diff --git a/workflow/scripts/generate_whole_slice_qc.py b/workflow/scripts/generate_whole_slice_qc.py
index ed64b21..f721128 100644
--- a/workflow/scripts/generate_whole_slice_qc.py
+++ b/workflow/scripts/generate_whole_slice_qc.py
@@ -20,7 +20,7 @@
 ws_cmap=snakemake.params.ws_cmap
 
 # input ome-zarr file
-ome= snakemake.input.ome
+ome= snakemake.input.zarr
 
 # output paths
 image_dir = snakemake.output.images_dir
diff --git a/workflow/scripts/ome_zarr_to_nii.py b/workflow/scripts/ome_zarr_to_nii.py
index db7ffa8..720f9b4 100644
--- a/workflow/scripts/ome_zarr_to_nii.py
+++ b/workflow/scripts/ome_zarr_to_nii.py
@@ -18,7 +18,19 @@
 fs_args={}
 fs = get_fsspec(uri,**fs_args)
 
-store = zarr.storage.FSStore(Path(uri).path,fs=fs,dimension_separator='/',mode='r')
+
+print(fs)
+print(uri)
+print(Path(uri))
+print(Path(uri).path)
+
+print(Path(uri).suffix)
+if Path(uri).suffix == '.zip':
+    store = zarr.storage.ZipStore(Path(uri).path,dimension_separator='/',mode='r')
+else:
+    store = zarr.storage.FSStore(Path(uri).path,fs=fs,dimension_separator='/',mode='r')
+
+print(store)
 
 zi = zarr.open(store=store,mode='r')
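
The last hunk above makes ome_zarr_to_nii.py pick its zarr store from the output extension, since use_zipstore: True now produces SPIM.ome.zarr.zip instead of a SPIM.ome.zarr directory. A minimal sketch of that dispatch as a standalone helper, assuming zarr v2 and a plain local path (the script itself resolves remote URIs through get_fsspec and hands the resulting filesystem to FSStore via fs=; open_spim_zarr and the example path below are hypothetical, not part of the patch):

    import zarr
    from pathlib import Path

    def open_spim_zarr(path):
        # a .zip suffix means a zipped zarr hierarchy; ZipStore is opened
        # read-only here, which is all the nifti conversion needs
        if Path(path).suffix == '.zip':
            store = zarr.storage.ZipStore(path, mode='r')
        else:
            # directory-style store; '/' separator matches how the
            # workflow writes its chunks
            store = zarr.storage.FSStore(path, dimension_separator='/', mode='r')
        return zarr.open(store=store, mode='r')

    zi = open_spim_zarr('sub-mouse1/micr/sub-mouse1_sample-brain_acq-blaze_SPIM.ome.zarr.zip')

One caveat worth noting: zarr.storage.ZipStore takes no fs= argument, so the zip branch only works when the URI resolves to a local file; a remote zip would have to be fetched first (or opened through fsspec's zip:// chained protocol).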