Allow for cloud URI (GCS) as input to the pipeline (#42)
This adds support for inputs coming from a GCS bucket.
This is mainly intended for execution on Coiled; the wrapper (being developed in a separate repo) makes use of it.
- adds a global cores_per_rule config option
- adds new rules that read metadata and tif files directly from cloud storage

Note: tar files in the cloud are not supported, only folders containing the tif files.
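
For illustration, cloud inputs are told apart from local ones purely by the URI scheme; a minimal sketch of the check added in workflow/lib/cloud_io.py (paths are hypothetical, and the universal_pathlib/gcsfs packages used by the pipeline are assumed to be installed):

```python
from upath import UPath  # universal_pathlib, as used in workflow/lib/cloud_io.py

def is_remote_gcs(uri_string):
    # mirrors the helper added below: cloud inputs are detected by the gcs:// scheme
    return UPath(uri_string).protocol == "gcs"

print(is_remote_gcs("gcs://my-bucket/sub-01/tifs"))  # True  -> tifs and metadata read straight from the bucket
print(is_remote_gcs("/data/sub-01/tifs"))            # False -> local folder (or tar file), handled as before
```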
akhanf authored Aug 25, 2024
1 parent 15a01fc commit 923aa84
Showing 10 changed files with 384 additions and 34 deletions.
2 changes: 2 additions & 0 deletions config/config.yml
@@ -8,6 +8,8 @@ remote_creds: '~/.config/gcloud/application_default_credentials.json' #this is n

write_ome_zarr_direct: True #use this to skip writing the final zarr output to work first and copying afterwards -- useful when work is not a fast local disk

cores_per_rule: 32

#import wildcards: tilex, tiley, channel, zslice (and prefix - unused)
import_blaze:
raw_tif_pattern: "{prefix}_Blaze[{tilex} x {tiley}]_C{channel}_xyz-Table Z{zslice}.ome.tif"
8 changes: 8 additions & 0 deletions workflow/lib/cloud_io.py
@@ -7,6 +7,14 @@ def is_remote(uri_string):
else:
return False

def is_remote_gcs(uri_string):
uri = Path(uri_string)
if uri.protocol == 'gcs':
return True
else:
return False


def get_fsspec(uri_string,storage_provider_settings=None,creds=None):
uri = Path(uri_string)
if uri.protocol == 'gcs':
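
The body of get_fsspec is truncated above; a minimal sketch of what its gcs branch could return, assuming gcsfs and the application-default credentials file configured as remote_creds in config.yml (illustration only, not the committed implementation):

```python
import os
import gcsfs

def gcs_filesystem(creds_json):
    # gcsfs accepts a path to a credentials/token JSON file via the token argument
    return gcsfs.GCSFileSystem(token=os.path.expanduser(creds_json))

# fs = gcs_filesystem("~/.config/gcloud/application_default_credentials.json")
# fs.ls("my-bucket")  # hypothetical bucket
```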
10 changes: 5 additions & 5 deletions workflow/rules/bigstitcher.smk
@@ -9,7 +9,7 @@ rule zarr_to_bdv:
desc="{desc}",
suffix="SPIM.zarr",
),
metadata_json=rules.blaze_to_metadata.output.metadata_json,
metadata_json=rules.copy_blaze_metadata.output.metadata_json,
params:
max_downsampling_layers=5,
temp_h5=str(
@@ -85,7 +85,7 @@ rule zarr_to_bdv:
desc="{desc}",
suffix="log.txt",
),
threads: 32
threads: config["cores_per_rule"]
group:
"preproc"
container:
@@ -151,7 +151,7 @@ rule bigstitcher:
resources:
runtime=30,
mem_mb=10000,
threads: 32
threads: config["cores_per_rule"]
group:
"preproc"
shell:
@@ -239,7 +239,7 @@ rule fuse_dataset:
resources:
runtime=30,
mem_mb=20000,
threads: 32
threads: config["cores_per_rule"]
group:
"preproc"
shell:
@@ -311,7 +311,7 @@ rule fuse_dataset_spark:
resources:
runtime=30,
mem_mb=20000,
threads: 32
threads: config["cores_per_rule"]
group:
"preproc"
shell:
39 changes: 33 additions & 6 deletions workflow/rules/common.smk
@@ -1,7 +1,7 @@
import tarfile
from snakebids import bids as _bids
from upath import UPath as Path
from lib.cloud_io import is_remote
from lib.cloud_io import is_remote, is_remote_gcs


def bids(root, *args, **kwargs):
@@ -129,16 +129,20 @@ def get_bids_toplevel_targets():
return targets


def dataset_is_remote(wildcards):
return is_remote_gcs(Path(get_dataset_path(wildcards)))


def get_input_dataset(wildcards):
"""returns path to extracted dataset or path to provided input folder"""
in_dataset = get_dataset_path(wildcards)

dataset_path = Path(get_dataset_path(wildcards))
suffix = dataset_path.suffix

if is_remote_gcs(dataset_path):
return rules.cp_from_gcs.output.ome_dir.format(**wildcards)

if dataset_path.is_dir():
# we have a directory already, just point to it
return str(dataset_path)
return get_dataset_path_remote(wildcards)

elif tarfile.is_tarfile(dataset_path):
# dataset was a tar file, so point to the extracted folder
@@ -148,6 +152,17 @@ def get_input_dataset(wildcards):
print(f"unsupported input: {dataset_path}")


def get_metadata_json(wildcards):
"""returns path to metadata, extracted from local or gcs"""
dataset_path = Path(get_dataset_path(wildcards))
suffix = dataset_path.suffix

if is_remote_gcs(dataset_path):
return rules.blaze_to_metadata_gcs.output.metadata_json.format(**wildcards)
else:
return rules.blaze_to_metadata.output.metadata_json.format(**wildcards)


# import
def cmd_extract_dataset(wildcards, input, output):
cmds = []
@@ -174,6 +189,19 @@ def cmd_extract_dataset(wildcards, input, output):
return " && ".join(cmds)


def get_dataset_path_remote(wildcards):
path = get_dataset_path(wildcards)
if is_remote(path):
return storage(path)
else:
return path


def get_dataset_path_gs(wildcards):
path = Path(get_dataset_path(wildcards)).path
return f"gs://{path}"


def get_dataset_path(wildcards):
df = datasets.query(
f"subject=='{wildcards.subject}' and sample=='{wildcards.sample}' and acq=='{wildcards.acq}'"
@@ -182,7 +210,6 @@ def get_dataset_path(wildcards):


def get_stains_by_row(i):

# Select columns that match the pattern 'stain_'
stain_columns = datasets.filter(like="stain_").columns

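
For context, a hedged sketch of a sample-table row that points at a cloud input; the dataset-path column name and all values are hypothetical (the real column name is not visible in this diff), and only the query mirrors get_dataset_path above:

```python
import pandas as pd

# hypothetical sample table; "dataset_path" is an assumed column name
datasets = pd.DataFrame(
    {
        "subject": ["01"],
        "sample": ["brain"],
        "acq": ["blaze"],
        "dataset_path": ["gcs://my-bucket/sub-01/blaze_tifs"],  # cloud input enabled by this commit
        "stain_0": ["PI"],
        "stain_1": ["AutoF"],
    }
)

# same style of lookup as get_dataset_path() above
row = datasets.query("subject=='01' and sample=='brain' and acq=='blaze'")
print(row["dataset_path"].iloc[0])  # gcs://my-bucket/sub-01/blaze_tifs
```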
14 changes: 7 additions & 7 deletions workflow/rules/flatfield_corr.smk
@@ -2,15 +2,15 @@
rule fit_basic_flatfield_corr:
""" BaSiC flatfield correction"""
input:
zarr=bids(
zarr=lambda wildcards: bids(
root=work,
subject="{subject}",
datatype="micr",
sample="{sample}",
acq="{acq}",
desc="raw",
desc="rawfromgcs" if dataset_is_remote(wildcards) else "raw",
suffix="SPIM.zarr",
),
).format(**wildcards),
params:
channel=lambda wildcards: get_stains(wildcards).index(wildcards.stain),
max_n_images=config["basic_flatfield_corr"]["max_n_images"],
@@ -64,15 +64,15 @@ rule fit_basic_flatfield_corr:
rule apply_basic_flatfield_corr:
""" apply BaSiC flatfield correction """
input:
zarr=bids(
zarr=lambda wildcards: bids(
root=work,
subject="{subject}",
datatype="micr",
sample="{sample}",
acq="{acq}",
desc="raw",
desc="rawfromgcs" if dataset_is_remote(wildcards) else "raw",
suffix="SPIM.zarr",
),
).format(**wildcards),
model_dirs=lambda wildcards: expand(
rules.fit_basic_flatfield_corr.output.model_dir,
stain=get_stains(wildcards),
@@ -113,7 +113,7 @@ rule apply_basic_flatfield_corr:
resources:
runtime=60,
mem_mb=32000,
threads: 32
threads: config["cores_per_rule"]
group:
"preproc"
script:
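
The zarr inputs above are wrapped in lambdas that end in .format(**wildcards): Snakemake does not substitute wildcard placeholders in strings returned from input functions, so the bids() template has to be formatted explicitly. A minimal sketch of the same pattern with an illustrative plain-string template:

```python
def pick_raw_zarr(wildcards):
    # dataset_is_remote() is the helper added in workflow/rules/common.smk
    desc = "rawfromgcs" if dataset_is_remote(wildcards) else "raw"
    template = "sub-{subject}_sample-{sample}_acq-{acq}_desc-" + desc + "_SPIM.zarr"
    # explicit substitution is required inside an input function
    return template.format(**wildcards)
```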
143 changes: 133 additions & 10 deletions workflow/rules/import.smk
@@ -1,7 +1,6 @@

rule extract_dataset:
input:
dataset_path=get_dataset_path,
dataset_path=get_dataset_path_remote,
params:
cmd=cmd_extract_dataset,
output:
@@ -34,23 +33,69 @@ rule extract_dataset:
"{params.cmd}"


rule blaze_to_metadata:
rule blaze_to_metadata_gcs:
input:
ome_dir=get_input_dataset,
creds=os.path.expanduser(config["remote_creds"]),
params:
in_tif_pattern=lambda wildcards, input: os.path.join(
input.ome_dir,
config["import_blaze"]["raw_tif_pattern"],
),
dataset_path=get_dataset_path_gs,
in_tif_pattern=lambda wildcards: config["import_blaze"]["raw_tif_pattern"],
storage_provider_settings=workflow.storage_provider_settings,
output:
metadata_json=bids(
root=root,
desc="gcs",
subject="{subject}",
datatype="micr",
sample="{sample}",
acq="{acq,[a-zA-Z0-9]*blaze[a-zA-Z0-9]*}",
suffix="SPIM.json",
),
benchmark:
bids(
root="benchmarks",
datatype="blaze_to_metadata_gcs",
subject="{subject}",
sample="{sample}",
acq="{acq}",
suffix="benchmark.tsv",
)
log:
bids(
root="logs",
datatype="blaze_to_metadata_gcs",
subject="{subject}",
sample="{sample}",
acq="{acq}",
suffix="log.txt",
),
group:
"preproc"
container:
config["containers"]["spimprep"]
script:
"../scripts/blaze_to_metadata_gcs.py"


rule blaze_to_metadata:
input:
ome_dir=get_input_dataset,
params:
in_tif_pattern=lambda wildcards, input: os.path.join(
input.ome_dir,
config["import_blaze"]["raw_tif_pattern"],
),
output:
metadata_json=temp(
bids(
root=work,
subject="{subject}",
desc="local",
datatype="micr",
sample="{sample}",
acq="{acq,[a-zA-Z0-9]*blaze[a-zA-Z0-9]*}",
suffix="SPIM.json",
)
),
benchmark:
bids(
root="benchmarks",
@@ -77,6 +122,31 @@ rule blaze_to_metadata:
"../scripts/blaze_to_metadata.py"


rule copy_blaze_metadata:
input:
json=get_metadata_json,
output:
metadata_json=bids(
root=root,
subject="{subject}",
datatype="micr",
sample="{sample}",
acq="{acq,[a-zA-Z0-9]*blaze[a-zA-Z0-9]*}",
suffix="SPIM.json",
),
log:
bids(
root="logs",
datatype="copy_blaze_metadata",
subject="{subject}",
sample="{sample}",
acq="{acq}",
suffix="log.txt",
),
shell:
"cp {input} {output} &> {log}"


rule prestitched_to_metadata:
input:
ome_dir=get_input_dataset,
@@ -125,7 +195,7 @@ rule tif_to_zarr:
images as the chunks"""
input:
ome_dir=get_input_dataset,
metadata_json=rules.blaze_to_metadata.output.metadata_json,
metadata_json=rules.copy_blaze_metadata.output.metadata_json,
params:
in_tif_pattern=lambda wildcards, input: os.path.join(
input.ome_dir,
@@ -166,8 +236,61 @@ rule tif_to_zarr:
),
group:
"preproc"
threads: 32
threads: config["cores_per_rule"]
container:
config["containers"]["spimprep"]
script:
"../scripts/tif_to_zarr.py"


rule tif_to_zarr_gcs:
""" use dask to load tifs in parallel and write to zarr
output shape is (tiles,channels,z,y,x), with the 2d
images as the chunks"""
input:
metadata_json=rules.copy_blaze_metadata.output.metadata_json,
creds=os.path.expanduser(config["remote_creds"]),
params:
dataset_path=get_dataset_path_gs,
in_tif_pattern=lambda wildcards: config["import_blaze"]["raw_tif_pattern"],
intensity_rescaling=config["import_blaze"]["intensity_rescaling"],
storage_provider_settings=workflow.storage_provider_settings,
output:
zarr=temp(
directory(
bids(
root=work,
subject="{subject}",
datatype="micr",
sample="{sample}",
acq="{acq}",
desc="rawfromgcs",
suffix="SPIM.zarr",
)
)
),
benchmark:
bids(
root="benchmarks",
datatype="tif_to_zarr",
subject="{subject}",
sample="{sample}",
acq="{acq}",
suffix="benchmark.tsv",
)
log:
bids(
root="logs",
datatype="tif_to_zarr",
subject="{subject}",
sample="{sample}",
acq="{acq}",
suffix="log.txt",
),
group:
"preproc"
threads: config["cores_per_rule"]
container:
config["containers"]["spimprep"]
script:
"../scripts/tif_to_zarr_gcs.py"
(The remaining 4 changed files are not shown here.)
