From a11d11a3f6f85f43e7b5bc0189ce65fa72ef9b06 Mon Sep 17 00:00:00 2001
From: Ali Khan
Date: Wed, 14 Feb 2024 11:49:56 -0500
Subject: [PATCH] fixes and refactoring for linting and snakefmt

---
 workflow/Snakefile                |  86 +---------------
 workflow/rules/atlasreg.smk       |  31 +++---
 workflow/rules/bigstitcher.smk    |  57 +++--------
 workflow/rules/common.smk         | 157 ++++++++++++++++++++++++++++++
 workflow/rules/flatfield_corr.smk |  25 ++++-
 workflow/rules/import.smk         |  54 +++++-----
 workflow/rules/ome_zarr.smk       |  45 ++++-----
 7 files changed, 263 insertions(+), 192 deletions(-)
 create mode 100644 workflow/rules/common.smk

diff --git a/workflow/Snakefile b/workflow/Snakefile
index eeb58e5..57cf784 100644
--- a/workflow/Snakefile
+++ b/workflow/Snakefile
@@ -15,7 +15,7 @@ container: config["containers"]["spimprep"]
 root = os.path.expandvars(config["root"])
 work = os.path.expandvars(config["work"])
 
-#this is needed to use the latest bids spec with the pre-release snakebids
+# this is needed to use the latest bids spec with the pre-release snakebids
 set_bids_spec("v0_10_1")
 
 # read datasets tsv
@@ -33,89 +33,7 @@ datasets = pd.read_csv(
 )
 
 
-def get_all_targets():
-    targets = []
-    for i in range(len(datasets)):
-        targets.extend(
-            expand(
-                bids(
-                    root=root,
-                    subject="{subject}",
-                    datatype="micr",
-                    sample="{sample}",
-                    acq="{acq}",
-                    desc="{desc}",
-                    stain="{stain}",
-                    suffix="spim.ome.zarr.zip",
-                ),
-                subject=datasets.loc[i, "subject"],
-                sample=datasets.loc[i, "sample"],
-                acq=datasets.loc[i, "acq"],
-                desc=config["targets"]["desc"],
-                stain=[datasets.loc[i, "stain_0"], datasets.loc[i, "stain_1"]],
-            )
-        )
-        targets.extend(
-            expand(
-                bids(
-                    root=root,
-                    subject="{subject}",
-                    datatype="micr",
-                    sample="{sample}",
-                    acq="{acq}",
-                    desc="{desc}",
-                    from_="{template}",
-                    suffix="dseg.ome.zarr.zip",
-                ),
-                subject=datasets.loc[i, "subject"],
-                sample=datasets.loc[i, "sample"],
-                acq=datasets.loc[i, "acq"],
-                desc=config["targets"]["desc"],
-                template=config["templates"],
-                stain=[datasets.loc[i, "stain_0"], datasets.loc[i, "stain_1"]],
-            )
-        )
-        targets.extend(
-            expand(
-                bids(
-                    root=root,
-                    subject="{subject}",
-                    datatype="micr",
-                    sample="{sample}",
-                    acq="{acq}",
-                    desc="{desc}",
-                    stain="{stain}",
-                    level="{level}",
-                    suffix="spim.nii",
-                ),
-                subject=datasets.loc[i, "subject"],
-                sample=datasets.loc[i, "sample"],
-                acq=datasets.loc[i, "acq"],
-                desc=config["targets"]["desc"],
-                level=config["nifti"]["levels"],
-                stain=[datasets.loc[i, "stain_0"], datasets.loc[i, "stain_1"]],
-            )
-        )
-
-    return targets
-
-
-def get_dataset_path(wildcards):
-    df = datasets.query(
-        f"subject=='{wildcards.subject}' and sample=='{wildcards.sample}' and acq=='{wildcards.acq}'"
-    )
-    return df.dataset_path.to_list()[0]
-
-
-def get_stains(wildcards):
-    df = datasets.query(
-        f"subject=='{wildcards.subject}' and sample=='{wildcards.sample}' and acq=='{wildcards.acq}'"
-    )
-
-    return [
-        df.stain_0.to_list()[0],
-        df.stain_1.to_list()[0],
-    ]
+include: "rules/common.smk"
 
 
 rule all:
diff --git a/workflow/rules/atlasreg.smk b/workflow/rules/atlasreg.smk
index 5124e8a..11a9766 100644
--- a/workflow/rules/atlasreg.smk
+++ b/workflow/rules/atlasreg.smk
@@ -1,7 +1,3 @@
-def bids_tpl(root, template, **entities):
-    """bids() wrapper for files in tpl-template folder"""
-    return str(Path(bids(root=root, tpl=template)) / bids(tpl=template, **entities))
-
 
 rule import_anat:
     input:
     output:
         anat=bids_tpl(root=root, template="{template}", suffix="anat.nii.gz"),
     log:
-        bids_tpl(root='logs',datatype="import_anat",template="{template}", suffix="log.txt")
+        bids_tpl(
+            root="logs",
+            datatype="import_anat",
+            template="{template}",
+            suffix="log.txt",
+        ),
     shell:
         "cp {input} {output}"
 
 
@@ -20,7 +21,12 @@ rule import_dseg:
     output:
         dseg=bids_tpl(root=root, template="{template}", suffix="dseg.nii.gz"),
     log:
-        bids_tpl(root='logs',datatype="import_dseg",template="{template}", suffix="log.txt")
+        bids_tpl(
+            root="logs",
+            datatype="import_dseg",
+            template="{template}",
+            suffix="log.txt",
+        ),
     shell:
         "cp {input} {output}"
 
 
@@ -31,7 +37,9 @@ rule import_lut:
     output:
         tsv=bids_tpl(root=root, template="{template}", suffix="dseg.tsv"),
     log:
-        bids_tpl(root='logs',datatype="import_lut",template="{template}", suffix="log.txt")
+        bids_tpl(
+            root="logs", datatype="import_lut", template="{template}", suffix="log.txt"
+        ),
     script:
         "../scripts/import_labelmapper_lut.py"
 
 
@@ -75,7 +83,7 @@ rule affine_reg:
         ),
     log:
         bids(
-            root='logs',
+            root="logs",
             subject="{subject}",
             datatype="affine_reg",
             sample="{sample}",
@@ -129,7 +137,7 @@ rule deform_reg:
         ),
     log:
         bids(
-            root='logs',
+            root="logs",
             subject="{subject}",
             datatype="deform_reg",
             sample="{sample}",
@@ -178,7 +186,7 @@ rule resample_labels_to_zarr:
         ),
     log:
         bids(
-            root='logs',
+            root="logs",
             subject="{subject}",
             datatype="resample_labels_to_zarr",
             sample="{sample}",
             acq="{acq}",
             desc="{desc}",
             space="{template}",
             suffix="log.txt",
         ),
-
     script:
         "../scripts/resample_labels_to_zarr.py"
 
 
@@ -234,7 +241,7 @@ rule zarr_to_ome_zarr_labels:
         "preproc"
     log:
         bids(
-            root='logs',
+            root="logs",
             subject="{subject}",
             datatype="zarr_to_ome_zarr_labels",
             sample="{sample}",
diff --git a/workflow/rules/bigstitcher.smk b/workflow/rules/bigstitcher.smk
index 18847a2..2ed066a 100644
--- a/workflow/rules/bigstitcher.smk
+++ b/workflow/rules/bigstitcher.smk
@@ -11,7 +11,7 @@ rule zarr_to_bdv:
         ),
         metadata_json=rules.raw_to_metadata.output.metadata_json,
     params:
-        max_downsampling_layers=5,  #1,2,4,8,16
+        max_downsampling_layers=5,
         temp_h5=str(
             Path(
                 bids(
                     root=work,
                     subject="{subject}",
                     datatype="micr",
                     sample="{sample}",
                     acq="{acq}",
                     desc="{desc}",
                     suffix="bdv",
                 )
             )
             / "dataset.h5"
         ),
-        #only temporary, is promptly deleted
         temp_xml=str(
             Path(
                 bids(
                     root=work,
                     subject="{subject}",
                     datatype="micr",
                     sample="{sample}",
                     acq="{acq}",
                     desc="{desc}",
                     suffix="bdv",
                 )
             )
             / "dataset.xml"
         ),
-        #only temporary, is promptly deleted
     output:
         bdv_n5=temp(
             directory(
             desc="{desc}",
             suffix="benchmark.tsv",
         )
+    log:
+        bids(
+            root="logs",
+            datatype="zarr_to_n5_bdv",
+            subject="{subject}",
+            sample="{sample}",
+            acq="{acq}",
+            desc="{desc}",
+            suffix="log.txt",
+        ),
     threads: 32
     group:
         "preproc"
     container:
         config["containers"]["spimprep"]
     script:
         "../scripts/zarr_to_n5_bdv.py"
 
 
-def get_fiji_launcher_cmd(wildcards, output, threads, resources):
-    launcher_opts_find = "-Xincgc"
-    launcher_opts_replace = f"-XX:+UseG1GC -verbose:gc -XX:+PrintGCDateStamps -XX:ActiveProcessorCount={threads}"
-    pipe_cmds = []
-    pipe_cmds.append("ImageJ-linux64 --dry-run --headless --console")
-    pipe_cmds.append(f"sed 's/{launcher_opts_find}/{launcher_opts_replace}'/")
-    pipe_cmds.append(
-        f"sed 's/-Xmx[0-9a-z]\+/-Xmx{resources.mem_mb}m -Xms{resources.mem_mb}m/'"
-    )
-    pipe_cmds.append("tr --delete '\\n'")
-    return "|".join(pipe_cmds) + f" > {output.launcher} && chmod a+x {output.launcher} "
-
-
-def get_macro_args_bigstitcher(wildcards, input, output):
-    return "{dataset_xml} {ds_x} {ds_y} {ds_z} {min_r}".format(
-        dataset_xml=input.dataset_xml,
-        ds_x=config["bigstitcher"]["calc_pairwise_shifts"]["downsample_in_x"],
-        ds_y=config["bigstitcher"]["calc_pairwise_shifts"]["downsample_in_y"],
-        ds_z=config["bigstitcher"]["calc_pairwise_shifts"]["downsample_in_z"],
-        min_r=config["bigstitcher"]["filter_pairwise_shifts"]["min_r"],
-    )
-
-
 rule bigstitcher:
     input:
         dataset_n5=rules.zarr_to_bdv.output.bdv_n5,
     container:
         config["containers"]["spimprep"]
     resources:
-        runtime=30,  #this should be proportional to the number of images and image size
+        runtime=30,
         mem_mb=10000,
     threads: 32
     group:
         "preproc"
     shell:
         " && {output.launcher} &> {log} && {params.rm_old_xml}"
 
 
-def get_macro_args_zarr_fusion(wildcards, input, output):
-    return "{dataset_xml} {downsampling} {channel:02d} {output_zarr} {bsx} {bsy} {bsz} {bsfx} {bsfy} {bsfz}".format(
-        dataset_xml=input.dataset_xml,
-        downsampling=config["bigstitcher"]["fuse_dataset"]["downsampling"],
-        channel=get_stains(wildcards).index(wildcards.stain),
-        output_zarr=output.zarr,
-        bsx=config["bigstitcher"]["fuse_dataset"]["block_size_x"],
-        bsy=config["bigstitcher"]["fuse_dataset"]["block_size_y"],
-        bsz=config["bigstitcher"]["fuse_dataset"]["block_size_z"],
-        bsfx=config["bigstitcher"]["fuse_dataset"]["block_size_factor_x"],
-        bsfy=config["bigstitcher"]["fuse_dataset"]["block_size_factor_y"],
-        bsfz=config["bigstitcher"]["fuse_dataset"]["block_size_factor_z"],
-    )
-
-
 rule fuse_dataset:
     input:
         dataset_n5=bids(
 rule fuse_dataset_spark:
             suffix="bigstitcher.xml",
         ),
         ijm=Path(workflow.basedir) / "macros" / "FuseImageMacroZarr.ijm",
-    params:
     output:
         zarr=temp(
             directory(
     container:
         config["containers"]["spimprep"]
     group:
         "preproc"
     shell:
         "affine-fusion ..."
-
-
diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk
new file mode 100644
index 0000000..08623ba
--- /dev/null
+++ b/workflow/rules/common.smk
@@ -0,0 +1,157 @@
+# targets
+def get_all_targets():
+    targets = []
+    for i in range(len(datasets)):
+        targets.extend(
+            expand(
+                bids(
+                    root=root,
+                    subject="{subject}",
+                    datatype="micr",
+                    sample="{sample}",
+                    acq="{acq}",
+                    desc="{desc}",
+                    stain="{stain}",
+                    suffix="spim.ome.zarr.zip",
+                ),
+                subject=datasets.loc[i, "subject"],
+                sample=datasets.loc[i, "sample"],
+                acq=datasets.loc[i, "acq"],
+                desc=config["targets"]["desc"],
+                stain=[datasets.loc[i, "stain_0"], datasets.loc[i, "stain_1"]],
+            )
+        )
+        targets.extend(
+            expand(
+                bids(
+                    root=root,
+                    subject="{subject}",
+                    datatype="micr",
+                    sample="{sample}",
+                    acq="{acq}",
+                    desc="{desc}",
+                    from_="{template}",
+                    suffix="dseg.ome.zarr.zip",
+                ),
+                subject=datasets.loc[i, "subject"],
+                sample=datasets.loc[i, "sample"],
+                acq=datasets.loc[i, "acq"],
+                desc=config["targets"]["desc"],
+                template=config["templates"],
+                stain=[datasets.loc[i, "stain_0"], datasets.loc[i, "stain_1"]],
+            )
+        )
+        targets.extend(
+            expand(
+                bids(
+                    root=root,
+                    subject="{subject}",
+                    datatype="micr",
+                    sample="{sample}",
+                    acq="{acq}",
+                    desc="{desc}",
+                    stain="{stain}",
+                    level="{level}",
+                    suffix="spim.nii",
+                ),
+                subject=datasets.loc[i, "subject"],
+                sample=datasets.loc[i, "sample"],
+                acq=datasets.loc[i, "acq"],
+                desc=config["targets"]["desc"],
+                level=config["nifti"]["levels"],
+                stain=[datasets.loc[i, "stain_0"], datasets.loc[i, "stain_1"]],
+            )
+        )
+
+    return targets
+
+
+# import
+def cmd_get_dataset(wildcards, input, output):
+    cmds = []
+    import tarfile
+
+    # supports tar, tar.gz, tgz, or a folder name
+    dataset_path = Path(input.dataset_path)
+    suffix = dataset_path.suffix
+    if dataset_path.is_dir():
+        # we have a directory:
+        # return command to symlink the folder
+        cmds.append(f"ln -sr {input} {output}")
+
+    elif tarfile.is_tarfile(dataset_path):
+        # we have a tar file
+        # check if gzipped:
+        cmds.append(f"mkdir -p {output}")
+        if suffix in (".gz", ".tgz"):  # Path.suffix includes the leading dot
+            cmds.append(f"tar -xzf {input} -C {output}")
+        else:
+            cmds.append(f"tar -xf {input} -C {output}")
+
+    else:
+        print(f"unsupported input: {dataset_path}")
+
+    return " && ".join(cmds)
+
+
+def get_dataset_path(wildcards):
+    df = datasets.query(
+        f"subject=='{wildcards.subject}' and sample=='{wildcards.sample}' and acq=='{wildcards.acq}'"
+    )
+    return df.dataset_path.to_list()[0]
+
+
+def get_stains(wildcards):
+    df = datasets.query(
+        f"subject=='{wildcards.subject}' and sample=='{wildcards.sample}' and acq=='{wildcards.acq}'"
+    )
+
+    return [
+        df.stain_0.to_list()[0],
+        df.stain_1.to_list()[0],
+    ]
+
+
+# bids
+def bids_tpl(root, template, **entities):
+    """bids() wrapper for files in tpl-template folder"""
+    return str(Path(bids(root=root, tpl=template)) / bids(tpl=template, **entities))
+
+
+# bigstitcher
+def get_fiji_launcher_cmd(wildcards, output, threads, resources):
+    launcher_opts_find = "-Xincgc"
+    launcher_opts_replace = f"-XX:+UseG1GC -verbose:gc -XX:+PrintGCDateStamps -XX:ActiveProcessorCount={threads}"
+    pipe_cmds = []
+    pipe_cmds.append("ImageJ-linux64 --dry-run --headless --console")
+    pipe_cmds.append(f"sed 's/{launcher_opts_find}/{launcher_opts_replace}/'")
+    pipe_cmds.append(
+        rf"sed 's/-Xmx[0-9a-z]\+/-Xmx{resources.mem_mb}m -Xms{resources.mem_mb}m/'"
+    )
+    pipe_cmds.append("tr --delete '\\n'")
+    return "|".join(pipe_cmds) + f" > {output.launcher} && chmod a+x {output.launcher} "
+
+
+def get_macro_args_bigstitcher(wildcards, input, output):
+    return "{dataset_xml} {ds_x} {ds_y} {ds_z} {min_r}".format(
+        dataset_xml=input.dataset_xml,
+        ds_x=config["bigstitcher"]["calc_pairwise_shifts"]["downsample_in_x"],
+        ds_y=config["bigstitcher"]["calc_pairwise_shifts"]["downsample_in_y"],
+        ds_z=config["bigstitcher"]["calc_pairwise_shifts"]["downsample_in_z"],
+        min_r=config["bigstitcher"]["filter_pairwise_shifts"]["min_r"],
+    )
+
+
+def get_macro_args_zarr_fusion(wildcards, input, output):
+    return "{dataset_xml} {downsampling} {channel:02d} {output_zarr} {bsx} {bsy} {bsz} {bsfx} {bsfy} {bsfz}".format(
+        dataset_xml=input.dataset_xml,
+        downsampling=config["bigstitcher"]["fuse_dataset"]["downsampling"],
+        channel=get_stains(wildcards).index(wildcards.stain),
+        output_zarr=output.zarr,
+        bsx=config["bigstitcher"]["fuse_dataset"]["block_size_x"],
+        bsy=config["bigstitcher"]["fuse_dataset"]["block_size_y"],
+        bsz=config["bigstitcher"]["fuse_dataset"]["block_size_z"],
+        bsfx=config["bigstitcher"]["fuse_dataset"]["block_size_factor_x"],
+        bsfy=config["bigstitcher"]["fuse_dataset"]["block_size_factor_y"],
+        bsfz=config["bigstitcher"]["fuse_dataset"]["block_size_factor_z"],
+    )
diff --git a/workflow/rules/flatfield_corr.smk b/workflow/rules/flatfield_corr.smk
index 034b8e4..d5a9335 100644
--- a/workflow/rules/flatfield_corr.smk
+++ b/workflow/rules/flatfield_corr.smk
@@ -13,7 +13,7 @@ rule fit_basic_flatfield_corr:
         ),
     params:
         channel=lambda wildcards: get_stains(wildcards).index(wildcards.stain),
-        max_n_images=config["basic_flatfield_corr"]["max_n_images"],  #sets maximum number of images to use for fitting (selected randomly)
+        max_n_images=config["basic_flatfield_corr"]["max_n_images"],
        basic_opts=config["basic_flatfield_corr"]["fitting_opts"],
     output:
         model_dir=temp(
             directory(
                 bids(
                     root=work,
                     subject="{subject}",
                     datatype="micr",
                     sample="{sample}",
                     acq="{acq}",
                     stain="{stain}",
                     suffix="basicmodel",
                 )
             )
         ),
     resources:
-        runtime=90,  #this should be proportional to the number of images and image size
+        runtime=90,
         mem_mb=64000,
     threads: 8
     benchmark:
         bids(
             root="benchmarks",
             datatype="fit_basic_flatfield",
             subject="{subject}",
             sample="{sample}",
             acq="{acq}",
             stain="{stain}",
             suffix="benchmark.tsv",
         )
+    log:
+        bids(
+            root="logs",
+            datatype="fit_basic_flatfield",
+            subject="{subject}",
+            sample="{sample}",
+            acq="{acq}",
+            stain="{stain}",
+            suffix="log.txt",
+        ),
     group:
         "preproc"
     container:
         config["containers"]["spimprep"]
 rule apply_basic_flatfield_corr:
             acq="{acq}",
             suffix="benchmark.tsv",
         )
+    log:
+        bids(
+            root="logs",
+            datatype="apply_basic_flatfield",
+            subject="{subject}",
+            sample="{sample}",
+            acq="{acq}",
+            suffix="log.txt",
+        ),
     resources:
-        runtime=60,  #this should be proportional to the number of images and image size
+        runtime=60,
         mem_mb=32000,
     threads: 32
     group:
diff --git a/workflow/rules/import.smk b/workflow/rules/import.smk
index 71ca961..bf552e2 100644
--- a/workflow/rules/import.smk
+++ b/workflow/rules/import.smk
@@ -1,29 +1,3 @@
-def cmd_get_dataset(wildcards, input, output):
-    cmds = []
-    import tarfile
-
-    # supports tar, tar.gz, tgz, zip, or folder name
-    dataset_path = Path(input.dataset_path)
-    suffix = dataset_path.suffix
-    if dataset_path.is_dir():
-        # we have a directory:
-        # return command to copy folder
-        cmds.append(f"ln -sr {input} {output}")
-
-    elif tarfile.is_tarfile(dataset_path):
-        # we have a tar file
-        # check if gzipped:
-        cmds.append(f"mkdir -p {output}")
-        if suffix == "gz" or suffix == "tgz":
-            cmds.append(f"tar -xzf {input} -C {output}")
-        else:
-            cmds.append(f"tar -xf {input} -C {output}")
-
-    else:
-        print(f"unsupported input: {dataset_path}")
-
-    return " && ".join(cmds)
-
 
 rule get_dataset:
     input:
         dataset_path=get_dataset_path,
     params:
         cmd=cmd_get_dataset,
     output:
         touch(
             bids(
                 root=work,
                 subject="{subject}",
                 datatype="micr",
                 sample="{sample}",
                 acq="{acq}",
                 desc="raw",
                 suffix="spim",
             )
         ),
     group:
         "preproc"
+    log:
+        bids(
+            root="logs",
+            subject="{subject}",
+            datatype="get_dataset",
+            sample="{sample}",
+            acq="{acq}",
+            desc="raw",
+            suffix="log.txt",
+        ),
     shell:
         "{params.cmd}"
 rule raw_to_metadata:
             acq="{acq}",
             suffix="benchmark.tsv",
         )
+    log:
+        bids(
+            root="logs",
+            datatype="raw_to_metadata",
+            subject="{subject}",
+            sample="{sample}",
+            acq="{acq}",
+            suffix="log.txt",
+        ),
     group:
         "preproc"
     container:
         config["containers"]["spimprep"]
 rule tif_to_zarr:
             acq="{acq}",
             suffix="benchmark.tsv",
         )
+    log:
+        bids(
+            root="logs",
+            datatype="tif_to_zarr",
+            subject="{subject}",
+            sample="{sample}",
+            acq="{acq}",
+            suffix="log.txt",
+        ),
     group:
         "preproc"
     threads: 32
diff --git a/workflow/rules/ome_zarr.smk b/workflow/rules/ome_zarr.smk
index e19717e..b403eb3 100644
--- a/workflow/rules/ome_zarr.smk
+++ b/workflow/rules/ome_zarr.smk
@@ -32,6 +32,17 @@ rule zarr_to_ome_zarr:
             )
         ),
     threads: 32
+    log:
+        bids(
+            root="logs",
+            subject="{subject}",
+            datatype="zarr_to_ome_zarr",
+            sample="{sample}",
+            acq="{acq}",
+            desc="{desc}",
+            stain="{stain}",
+            suffix="log.txt",
+        ),
     container:
         config["containers"]["spimprep"]
     group:
         "preproc"
@@ -90,6 +101,18 @@ rule ome_zarr_to_nii:
             level="{level}",
             suffix="benchmark.tsv",
         )
+    log:
+        bids(
+            root="logs",
+            datatype="ome_zarr_to_nifti",
+            subject="{subject}",
+            sample="{sample}",
+            acq="{acq}",
+            desc="{desc}",
+            stain="{stain}",
+            level="{level}",
+            suffix="log.txt",
+        ),
     group:
         "preproc"
     threads: 32
     container:
         config["containers"]["spimprep"]
     script:
         "../scripts/ome_zarr_to_nii.py"
-
-
-rule zarr_masking_wip:
-    input:
-        zarr=rules.zarr_to_ome_zarr.output.zarr,
-    output:
-        zarr=directory(
-            bids(
-                root=root,
-                subject="{subject}",
-                datatype="micr",
-                sample="{sample}",
-                acq="{acq}",
-                desc="{desc}",
-                stain="{stain}",
-                suffix="mask.ome.zarr",
-            )
-        ),
-    container:
-        config["containers"]["spimprep"]
-    notebook:
"../notebooks/zarr_masking.py.ipynb"