major refactoring of cloud outputs
- Instead of a final() wrapper, the bids() function is now overloaded to
  apply storage() whenever the file path contains a remote URI.
- This way, we can simply add a gcs:// or s3:// prefix to the root
  (output folder) config variable and avoid a separate write_to_remote
  flag.
- However, this complicates a couple of other things: e.g., expand()
  cannot be applied to files carrying the storage tag, so we add another
  wrapper, expand_bids(), to ensure storage() is applied after expanding.
- Also refactored the fsspec code, which now lives in
  workflow/lib/cloud_io.py. I considered moving it to zarrnii, but it is
  actually Snakemake-specific, so it is better kept as a helper in the
  Snakemake workflow.
akhanf committed Jul 17, 2024
1 parent 029746f commit 3f78839
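
For illustration: with the new scheme, whether outputs are local or remote is decided purely by the protocol that universal_pathlib reports for the configured root. A minimal, runnable sketch (bucket names are hypothetical):

    from upath import UPath

    # The same check the new helpers use: a gcs:// or s3:// protocol means remote.
    for root in ["gcs://my-bucket/bids", "s3://my-bucket/bids", "bids"]:
        print(root, "->", UPath(root).protocol or "local")
    # gcs://my-bucket/bids -> gcs
    # s3://my-bucket/bids -> s3
    # bids -> local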
Showing 10 changed files with 218 additions and 238 deletions.
4 changes: 1 addition & 3 deletions config/config.yml
@@ -1,11 +1,9 @@
 datasets: 'config/datasets.tsv'
 
 
-root: 'bids'
+root: 'bids' # can use a s3:// or gcs:// prefix to write output to cloud storage
 work: 'work'
 
-write_to_remote: False #save files marked as final() to cloud store
-remote_prefix: 'gcs://khanlab-bucket/lightsheet-data'
 remote_creds: '~/.config/gcloud/application_default_credentials.json' #this is needed so we can pass creds to container
 
 write_ome_zarr_direct: True #use this to skip writing the final zarr output to work first and copying afterwards -- useful when work is not a fast local disk
4 changes: 2 additions & 2 deletions workflow/Snakefile
@@ -1,6 +1,6 @@
 import json
-from pathlib import Path
-from snakebids import bids, set_bids_spec
+from upath import UPath as Path
+from snakebids import set_bids_spec
 import pandas as pd
 from collections import defaultdict
 import os
33 changes: 33 additions & 0 deletions workflow/lib/cloud_io.py
@@ -0,0 +1,33 @@
+from upath import UPath as Path
+
+
+def is_remote(uri_string):
+    uri = Path(uri_string)
+    if uri.protocol == 'gcs' or uri.protocol == 's3':
+        return True
+    else:
+        return False
+
+
+def get_fsspec(uri_string, storage_provider_settings=None, creds=None):
+    uri = Path(uri_string)
+    if uri.protocol == 'gcs':
+        from gcsfs import GCSFileSystem
+        gcsfs_opts = {
+            'project': storage_provider_settings['gcs'].get_settings().project,
+            'token': creds,
+        }
+        fs = GCSFileSystem(**gcsfs_opts)
+    elif uri.protocol == 's3':
+        from s3fs import S3FileSystem
+        s3fs_opts = {'anon': False}
+        fs = S3FileSystem(**s3fs_opts)
+    elif uri.protocol in ('file', 'local', ''):
+        # assumed to be a local file
+        from fsspec import LocalFileSystem
+        fs = LocalFileSystem()
+    else:
+        # fail loudly rather than returning an unbound variable
+        raise ValueError(f'unsupported protocol {uri.protocol} for remote data')
+    return fs
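
A minimal usage sketch of the new helper (a local path, so no credentials or provider settings are needed; the filename is hypothetical):

    from lib.cloud_io import get_fsspec, is_remote

    uri = "samples.tsv"              # hypothetical local path -> LocalFileSystem
    print(is_remote(uri))            # False
    fs = get_fsspec(uri)             # the local branch needs no creds
    with fs.open(uri, "w") as f:     # fsspec filesystems expose open()
        f.write("sample\tpath\n")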




10 changes: 5 additions & 5 deletions workflow/rules/bids.smk
@@ -5,7 +5,7 @@ rule raw_dataset_desc:
     params:
         dd=config["bids"]["raw"],
     output:
-        json=final(Path(root) / "dataset_description.json"),
+        json=bids_toplevel(root, "dataset_description.json"),
     log:
         "logs/dd_raw.log",
     localrule: True
@@ -18,7 +18,7 @@ rule resampled_dataset_desc:
     params:
         dd=config["bids"]["resampled"],
    output:
-        json=final(Path(resampled) / "dataset_description.json"),
+        json=bids_toplevel(resampled, "dataset_description.json"),
     log:
         "logs/dd_raw.log",
     localrule: True
@@ -31,7 +31,7 @@ rule bids_readme:
     input:
         config["bids"]["readme_md"],
     output:
-        final(Path(root) / "README.md"),
+        bids_toplevel(root, "README.md"),
     log:
         "logs/bids_readme.log",
     localrule: True
@@ -43,7 +43,7 @@ rule bids_samples_json:
     input:
         config["bids"]["samples_json"],
     output:
-        final(Path(root) / "samples.json"),
+        bids_toplevel(root, "samples.json"),
     log:
         "logs/bids_samples_json.log",
     localrule: True
@@ -55,7 +55,7 @@ rule create_samples_tsv:
     params:
         datasets_df=datasets,
     output:
-        tsv=final(Path(root) / "samples.tsv"),
+        tsv=bids_toplevel(root, "samples.tsv"),
     log:
         "logs/bids_samples_tsv.log",
     localrule: True
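
For illustration, bids_toplevel() resolves these outputs to a plain path or a storage()-tagged URI depending on the root. Hypothetical resolutions, evaluated in Snakefile context where storage() is a Snakemake global (the bucket name is illustrative):

    bids_toplevel("bids", "samples.tsv")
    #   -> "bids/samples.tsv"                           (plain local path)
    bids_toplevel("gcs://my-bucket/bids", "samples.tsv")
    #   -> storage("gcs://my-bucket/bids/samples.tsv")  (remote-tagged target)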
185 changes: 99 additions & 86 deletions workflow/rules/common.smk
@@ -1,8 +1,49 @@
 import tarfile
+from snakebids import bids as _bids
+from upath import UPath as Path
+from lib.cloud_io import is_remote
+
+
+def expand_bids(expand_kwargs, **bids_kwargs):
+
+    files = expand(_bids(**bids_kwargs), **expand_kwargs)
+
+    if is_remote(files[0]):
+        return [storage(f) for f in files]
+    else:
+        return files
+
+
+def directory_bids(root, *args, **kwargs):
+    bids_str = _bids(root=root, *args, **kwargs)
+
+    if is_remote(root):
+        return storage(directory(bids_str))
+    else:
+        return bids_str
+
+
+def bids_toplevel(root, filename):
+    bids_str = str(Path(_bids(root=root)) / filename)
+
+    if is_remote(root):
+        return storage(bids_str)
+    else:
+        return bids_str
+
+
+def bids(root, *args, **kwargs):
+    bids_str = _bids(root=root, *args, **kwargs)
+
+    if is_remote(root):
+        return storage(bids_str)
+    else:
+        return bids_str
 
 
 def get_extension_ome_zarr():
-    if config["write_to_remote"]:
+
+    if is_remote(config["root"]):
         return "ome.zarr/.snakemake_touch"
     else:
         if config["ome_zarr"]["use_zipstore"]:
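
The ordering inside expand_bids() is the point of the commit message's caveat: expand() operates on plain string patterns, so storage() must be applied per file after expansion. A sketch of the two steps as they would run inside a Snakefile, where expand() and storage() are globals (bucket hypothetical):

    # expand() first, on plain strings ...
    pattern = "gcs://my-bucket/bids/sub-{subject}/SPIM.ome.zarr"  # hypothetical
    files = expand(pattern, subject=["01", "02"])
    # ... then tag each expanded path for remote storage:
    targets = [storage(f) for f in files]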
@@ -11,60 +52,43 @@ def get_extension_ome_zarr():
             return "ome.zarr"
 
 
-def final(path_or_paths):
-    if config["write_to_remote"]:
-        if type(path_or_paths) == list:
-            out_paths = []
-            for path in path_or_paths:
-                out_paths.append(storage(os.path.join(config["remote_prefix"], path)))
-            return out_paths
-        else:
-            return storage(os.path.join(config["remote_prefix"], path_or_paths))
-    else:
-        return path_or_paths
-
-
 # targets
 def get_all_targets():
     targets = []
     for i in range(len(datasets)):
         targets.extend(
-            final(
-                expand(
-                    bids(
-                        root=root,
-                        subject="{subject}",
-                        datatype="micr",
-                        sample="{sample}",
-                        acq="{acq}",
-                        suffix="SPIM.{extension}",
-                    ),
+            expand_bids(
+                root=root,
+                subject="{subject}",
+                datatype="micr",
+                sample="{sample}",
+                acq="{acq}",
+                suffix="SPIM.{extension}",
+                expand_kwargs=dict(
                     subject=datasets.loc[i, "subject"],
                     sample=datasets.loc[i, "sample"],
                    acq=datasets.loc[i, "acq"],
                     extension=[get_extension_ome_zarr(), "json"],
-                )
+                ),
             )
         )
         targets.extend(
-            final(
-                expand(
-                    bids(
-                        root=resampled,
-                        subject="{subject}",
-                        datatype="micr",
-                        sample="{sample}",
-                        acq="{acq}",
-                        res="{level}x",
-                        stain="{stain}",
-                        suffix="SPIM.nii",
-                    ),
+            expand_bids(
+                root=resampled,
+                subject="{subject}",
+                datatype="micr",
+                sample="{sample}",
+                acq="{acq}",
+                res="{level}x",
+                stain="{stain}",
+                suffix="SPIM.nii",
+                expand_kwargs=dict(
                     subject=datasets.loc[i, "subject"],
                     sample=datasets.loc[i, "sample"],
                     acq=datasets.loc[i, "acq"],
                     level=config["nifti"]["levels"],
                     stain=get_stains_by_row(i),
-                )
+                ),
             )
         )
@@ -73,12 +97,12 @@ def get_all_targets():
 
 def get_bids_toplevel_targets():
     targets = []
-    targets.append(Path(root) / "README.md")
-    targets.append(Path(root) / "dataset_description.json")
-    targets.append(Path(root) / "samples.tsv")
-    targets.append(Path(root) / "samples.json")
-    targets.append(Path(resampled) / "dataset_description.json")
-    return [final(target) for target in targets]
+    targets.append(bids_toplevel(root, "README.md"))
+    targets.append(bids_toplevel(root, "dataset_description.json"))
+    targets.append(bids_toplevel(root, "samples.tsv"))
+    targets.append(bids_toplevel(root, "samples.json"))
+    targets.append(bids_toplevel(resampled, "dataset_description.json"))
+    return targets
 
 
 def get_input_dataset(wildcards):
@@ -201,37 +225,31 @@ def get_macro_args_zarr_fusion(wildcards, input, output):
 
 
 def get_output_ome_zarr(acq_type):
-    if config["write_to_remote"]:
+    if is_remote(config["root"]):
         return {
             "zarr": touch(
-                final(
-                    bids(
-                        root=root,
-                        subject="{subject}",
-                        datatype="micr",
-                        sample="{sample}",
-                        acq=f"{{acq,[a-zA-Z0-9]*{acq_type}[a-zA-Z0-9]*}}",
-                        suffix="SPIM.{extension}".format(
-                            extension=get_extension_ome_zarr()
-                        ),
-                    )
+                bids(
+                    root=root,
+                    subject="{subject}",
+                    datatype="micr",
+                    sample="{sample}",
+                    acq=f"{{acq,[a-zA-Z0-9]*{acq_type}[a-zA-Z0-9]*}}",
+                    suffix="SPIM.{extension}".format(
+                        extension=get_extension_ome_zarr()
+                    ),
                 )
             )
         }
     else:
         if config["write_ome_zarr_direct"]:
             return {
-                "zarr": final(
-                    directory(
-                        bids(
-                            root=root,
-                            subject="{subject}",
-                            datatype="micr",
-                            sample="{sample}",
-                            acq=f"{{acq,[a-zA-Z0-9]*{acq_type}[a-zA-Z0-9]*}}",
-                            suffix="SPIM.ome.zarr",
-                        )
-                    )
+                "zarr": directory_bids(
+                    root=root,
+                    subject="{subject}",
+                    datatype="micr",
+                    sample="{sample}",
+                    acq=f"{{acq,[a-zA-Z0-9]*{acq_type}[a-zA-Z0-9]*}}",
+                    suffix="SPIM.ome.zarr",
                 )
             }
         else:
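
A note on the remote branch above: object stores like gcs/s3 have no real directories, so the zarr store (a directory tree) cannot itself be the tracked output; instead the tracked output is a sentinel file inside the store, created via touch(). A hypothetical resolution sketch:

    ext = get_extension_ome_zarr()
    # remote root -> "ome.zarr/.snakemake_touch": a touch()-able sentinel file
    #                inside the zarr store stands in for the directory output
    # local root  -> "ome.zarr" (a real directory; the zipstore variant's
    #                return value is truncated in this diff)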
@@ -252,31 +270,25 @@ def get_output_ome_zarr(acq_type):
 
 
 def get_input_ome_zarr_to_nii():
-    if config["write_to_remote"]:
-        return final(
-            bids(
-                root=root,
-                subject="{subject}",
-                datatype="micr",
-                sample="{sample}",
-                acq="{acq}",
-                suffix="SPIM.{extension}".format(extension=get_extension_ome_zarr()),
-            )
+    if is_remote(config["root"]):
+        return bids(
+            root=root,
+            subject="{subject}",
+            datatype="micr",
+            sample="{sample}",
+            acq="{acq}",
+            suffix="SPIM.{extension}".format(extension=get_extension_ome_zarr()),
         )
     else:
         if config["write_ome_zarr_direct"]:
-            return final(
-                bids(
-                    root=root,
-                    subject="{subject}",
-                    datatype="micr",
-                    sample="{sample}",
-                    acq="{acq}",
-                    suffix="SPIM.{extension}".format(
-                        extension=get_extension_ome_zarr()
-                    ),
-                )
-            )
+            return bids(
+                root=root,
+                subject="{subject}",
+                datatype="micr",
+                sample="{sample}",
+                acq="{acq}",
+                suffix="SPIM.{extension}".format(extension=get_extension_ome_zarr()),
+            )
         else:
             return bids(
                 root=work,
@@ -290,7 +302,8 @@ def get_input_ome_zarr_to_nii():
 
 def get_storage_creds():
     """for rules that deal with remote storage directly"""
-    if config["write_to_remote"]:
+    protocol = Path(config["root"]).protocol
+    if protocol == "gcs":
         # currently only works with gcs
         creds = os.path.expanduser(config["remote_creds"])
         return {"creds": creds}
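
The call sites that consume this creds dict are truncated below, so the wiring here is an assumption; a hedged sketch of a plausible hookup with get_fsspec():

    # Hypothetical plumbing: assumes get_storage_creds() returns an empty dict
    # for non-gcs roots (its tail is cut off in this diff) and that provider
    # settings come from the workflow object.
    creds_kwargs = get_storage_creds()
    fs = get_fsspec(
        config["root"],
        storage_provider_settings=workflow.storage_provider_settings,
        creds=creds_kwargs.get("creds"),
    )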