From 407827ce5074c1cca68e71ff3a254b7295bf28c3 Mon Sep 17 00:00:00 2001
From: Ali Khan
Date: Fri, 23 Aug 2024 09:10:32 -0400
Subject: [PATCH] Process files directly from cloud (gcs) (#43)

* avoid copying from gcs: tif files are now read directly from the cloud,
  both for metadata and for dask array creation.
---
 workflow/rules/bigstitcher.smk            |   2 +-
 workflow/rules/common.smk                 |  15 +++
 workflow/rules/flatfield_corr.smk         |  12 +-
 workflow/rules/import.smk                 | 148 +++++++++++++++++-----
 workflow/rules/ome_zarr.smk               |   2 +-
 workflow/rules/qc.smk                     |   6 +-
 workflow/scripts/blaze_to_metadata_gcs.py | 107 ++++++++++++++++
 workflow/scripts/tif_to_zarr_gcs.py       |  87 +++++++++++++
 8 files changed, 337 insertions(+), 42 deletions(-)
 create mode 100644 workflow/scripts/blaze_to_metadata_gcs.py
 create mode 100644 workflow/scripts/tif_to_zarr_gcs.py

diff --git a/workflow/rules/bigstitcher.smk b/workflow/rules/bigstitcher.smk
index 57862ac..3f4f01c 100644
--- a/workflow/rules/bigstitcher.smk
+++ b/workflow/rules/bigstitcher.smk
@@ -9,7 +9,7 @@ rule zarr_to_bdv:
             desc="{desc}",
             suffix="SPIM.zarr",
         ),
-        metadata_json=rules.blaze_to_metadata.output.metadata_json,
+        metadata_json=rules.copy_blaze_metadata.output.metadata_json,
     params:
         max_downsampling_layers=5,
         temp_h5=str(
diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk
index 82c22a6..4e52bc7 100644
--- a/workflow/rules/common.smk
+++ b/workflow/rules/common.smk
@@ -129,6 +129,10 @@ def get_bids_toplevel_targets():
     return targets
 
 
+def dataset_is_remote(wildcards):
+    return is_remote_gcs(Path(get_dataset_path(wildcards)))
+
+
 def get_input_dataset(wildcards):
     """returns path to extracted dataset or path to provided input folder"""
     dataset_path = Path(get_dataset_path(wildcards))
@@ -148,6 +152,17 @@ def get_input_dataset(wildcards):
         print(f"unsupported input: {dataset_path}")
 
 
+def get_metadata_json(wildcards):
+    """returns path to metadata, extracted from local or gcs"""
+    dataset_path = Path(get_dataset_path(wildcards))
+    suffix = dataset_path.suffix
+
+    if is_remote_gcs(dataset_path):
+        return rules.blaze_to_metadata_gcs.output.metadata_json.format(**wildcards)
+    else:
+        return rules.blaze_to_metadata.output.metadata_json.format(**wildcards)
+
+
 # import
 def cmd_extract_dataset(wildcards, input, output):
     cmds = []
diff --git a/workflow/rules/flatfield_corr.smk b/workflow/rules/flatfield_corr.smk
index e7bb666..6dc6786 100644
--- a/workflow/rules/flatfield_corr.smk
+++ b/workflow/rules/flatfield_corr.smk
@@ -2,15 +2,15 @@ rule fit_basic_flatfield_corr:
     """ BaSiC flatfield correction"""
     input:
-        zarr=bids(
+        zarr=lambda wildcards: bids(
             root=work,
             subject="{subject}",
             datatype="micr",
             sample="{sample}",
             acq="{acq}",
-            desc="raw",
+            desc="rawfromgcs" if dataset_is_remote(wildcards) else "raw",
             suffix="SPIM.zarr",
-        ),
+        ).format(**wildcards),
     params:
         channel=lambda wildcards: get_stains(wildcards).index(wildcards.stain),
         max_n_images=config["basic_flatfield_corr"]["max_n_images"],
@@ -64,15 +64,15 @@ rule fit_basic_flatfield_corr:
 rule apply_basic_flatfield_corr:
     """ apply BaSiC flatfield correction """
     input:
-        zarr=bids(
+        zarr=lambda wildcards: bids(
             root=work,
             subject="{subject}",
             datatype="micr",
             sample="{sample}",
             acq="{acq}",
-            desc="raw",
+            desc="rawfromgcs" if dataset_is_remote(wildcards) else "raw",
             suffix="SPIM.zarr",
-        ),
+        ).format(**wildcards),
         model_dirs=lambda wildcards: expand(
             rules.fit_basic_flatfield_corr.output.model_dir,
             stain=get_stains(wildcards),
diff --git a/workflow/rules/import.smk b/workflow/rules/import.smk
index 6371a89..8a9074a 100644
--- a/workflow/rules/import.smk
+++ b/workflow/rules/import.smk
@@ -33,42 +33,47 @@ rule extract_dataset:
         "{params.cmd}"
 
 
-rule cp_from_gcs:
+rule blaze_to_metadata_gcs:
+    input:
+        creds=os.path.expanduser(config["remote_creds"]),
     params:
         dataset_path=get_dataset_path_gs,
+        in_tif_pattern=lambda wildcards: config["import_blaze"]["raw_tif_pattern"],
+        storage_provider_settings=workflow.storage_provider_settings,
     output:
-        ome_dir=temp(
-            directory(
-                bids(
-                    root=work,
-                    subject="{subject}",
-                    datatype="micr",
-                    sample="{sample}",
-                    acq="{acq}",
-                    desc="rawfromgcs",
-                    suffix="SPIM",
-                )
-            )
+        metadata_json=bids(
+            root=root,
+            desc="gcs",
+            subject="{subject}",
+            datatype="micr",
+            sample="{sample}",
+            acq="{acq,[a-zA-Z0-9]*blaze[a-zA-Z0-9]*}",
+            suffix="SPIM.json",
         ),
-    threads: config["cores_per_rule"]
-    group:
-        "preproc"
+    benchmark:
+        bids(
+            root="benchmarks",
+            datatype="blaze_to_metadata_gcs",
+            subject="{subject}",
+            sample="{sample}",
+            acq="{acq}",
+            suffix="benchmark.tsv",
+        )
     log:
         bids(
             root="logs",
+            datatype="blaze_to_metadata_gcs",
             subject="{subject}",
-            datatype="cp_from_gcs",
             sample="{sample}",
             acq="{acq}",
-            desc="raw",
             suffix="log.txt",
         ),
+    group:
+        "preproc"
     container:
-        None
-    conda:
-        "../envs/google_cloud.yaml"
-    shell:
-        "mkdir -p {output} && gcloud storage cp --recursive {params.dataset_path}/* {output}"
+        config["containers"]["spimprep"]
+    script:
+        "../scripts/blaze_to_metadata_gcs.py"
 
 
 rule blaze_to_metadata:
@@ -80,13 +85,16 @@ rule blaze_to_metadata:
             config["import_blaze"]["raw_tif_pattern"],
         ),
     output:
-        metadata_json=bids(
-            root=root,
-            subject="{subject}",
-            datatype="micr",
-            sample="{sample}",
-            acq="{acq,[a-zA-Z0-9]*blaze[a-zA-Z0-9]*}",
-            suffix="SPIM.json",
+        metadata_json=temp(
+            bids(
+                root=work,
+                subject="{subject}",
+                desc="local",
+                datatype="micr",
+                sample="{sample}",
+                acq="{acq,[a-zA-Z0-9]*blaze[a-zA-Z0-9]*}",
+                suffix="SPIM.json",
+            )
         ),
     benchmark:
         bids(
@@ -114,6 +122,31 @@ rule blaze_to_metadata:
         "../scripts/blaze_to_metadata.py"
 
 
+rule copy_blaze_metadata:
+    input:
+        json=get_metadata_json,
+    output:
+        metadata_json=bids(
+            root=root,
+            subject="{subject}",
+            datatype="micr",
+            sample="{sample}",
+            acq="{acq,[a-zA-Z0-9]*blaze[a-zA-Z0-9]*}",
+            suffix="SPIM.json",
+        ),
+    log:
+        bids(
+            root="logs",
+            datatype="copy_blaze_metadata",
+            subject="{subject}",
+            sample="{sample}",
+            acq="{acq}",
+            suffix="log.txt",
+        ),
+    shell:
+        "cp {input} {output} &> {log}"
+
+
 rule prestitched_to_metadata:
     input:
         ome_dir=get_input_dataset,
@@ -162,7 +195,7 @@ rule tif_to_zarr:
     images as the chunks"""
     input:
         ome_dir=get_input_dataset,
-        metadata_json=rules.blaze_to_metadata.output.metadata_json,
+        metadata_json=rules.copy_blaze_metadata.output.metadata_json,
     params:
         in_tif_pattern=lambda wildcards, input: os.path.join(
             input.ome_dir,
@@ -208,3 +241,56 @@ rule tif_to_zarr:
         config["containers"]["spimprep"]
     script:
         "../scripts/tif_to_zarr.py"
+
+
+rule tif_to_zarr_gcs:
+    """ use dask to load tifs in parallel and write to zarr
+    output shape is (tiles,channels,z,y,x), with the 2d
+    images as the chunks"""
+    input:
+        metadata_json=rules.copy_blaze_metadata.output.metadata_json,
+        creds=os.path.expanduser(config["remote_creds"]),
+    params:
+        dataset_path=get_dataset_path_gs,
+        in_tif_pattern=lambda wildcards: config["import_blaze"]["raw_tif_pattern"],
+        intensity_rescaling=config["import_blaze"]["intensity_rescaling"],
+        storage_provider_settings=workflow.storage_provider_settings,
+    output:
+        zarr=temp(
+            directory(
+                bids(
+                    root=work,
subject="{subject}", + datatype="micr", + sample="{sample}", + acq="{acq}", + desc="rawfromgcs", + suffix="SPIM.zarr", + ) + ) + ), + benchmark: + bids( + root="benchmarks", + datatype="tif_to_zarr", + subject="{subject}", + sample="{sample}", + acq="{acq}", + suffix="benchmark.tsv", + ) + log: + bids( + root="logs", + datatype="tif_to_zarr", + subject="{subject}", + sample="{sample}", + acq="{acq}", + suffix="log.txt", + ), + group: + "preproc" + threads: config["cores_per_rule"] + container: + config["containers"]["spimprep"] + script: + "../scripts/tif_to_zarr_gcs.py" diff --git a/workflow/rules/ome_zarr.smk b/workflow/rules/ome_zarr.smk index a390513..51c635e 100644 --- a/workflow/rules/ome_zarr.smk +++ b/workflow/rules/ome_zarr.smk @@ -16,7 +16,7 @@ rule zarr_to_ome_zarr: desc=config["ome_zarr"]["desc"], allow_missing=True, ), - metadata_json=rules.blaze_to_metadata.output.metadata_json, + metadata_json=rules.copy_blaze_metadata.output.metadata_json, params: max_downsampling_layers=config["ome_zarr"]["max_downsampling_layers"], rechunk_size=config["ome_zarr"]["rechunk_size"], diff --git a/workflow/rules/qc.smk b/workflow/rules/qc.smk index b98bd7f..415b3d7 100644 --- a/workflow/rules/qc.smk +++ b/workflow/rules/qc.smk @@ -1,15 +1,15 @@ rule generate_flatfield_qc: "Generates an html file for comparing before and after flatfield correction" input: - uncorr=bids( + uncorr=lambda wildcards: bids( root=work, subject="{subject}", datatype="micr", sample="{sample}", acq="{acq}", - desc="raw", + desc="rawfromgcs" if dataset_is_remote(wildcards) else "raw", suffix="SPIM.zarr", - ), + ).format(**wildcards), corr=bids( root=work, subject="{subject}", diff --git a/workflow/scripts/blaze_to_metadata_gcs.py b/workflow/scripts/blaze_to_metadata_gcs.py new file mode 100644 index 0000000..24e2ec2 --- /dev/null +++ b/workflow/scripts/blaze_to_metadata_gcs.py @@ -0,0 +1,107 @@ +import tifffile +import xmltodict +import json +import re +import os +from itertools import product +from snakemake.io import glob_wildcards +import gcsfs +from lib.cloud_io import get_fsspec + +dataset_uri = snakemake.params.dataset_path +in_tif_pattern = snakemake.params.in_tif_pattern + + +gcsfs_opts={'project': snakemake.params.storage_provider_settings['gcs'].get_settings().project, + 'token': snakemake.input.creds} +fs = gcsfs.GCSFileSystem(**gcsfs_opts) + +tifs = fs.glob(f"{dataset_uri}/*.tif") + +#parse the filenames to get number of channels, tiles etc.. 
+prefix, tilex, tiley, channel, zslice = glob_wildcards(in_tif_pattern,files=tifs)
+
+tiles_x = sorted(list(set(tilex)))
+tiles_y = sorted(list(set(tiley)))
+channels = sorted(list(set(channel)))
+zslices = sorted(list(set(zslice)))
+prefixes = sorted(list(set(prefix)))
+print(tiles_x)
+print(tiles_y)
+#read in series metadata from first file
+in_tif = in_tif_pattern.format(tilex=tiles_x[0],tiley=tiles_y[0],prefix=prefixes[0],channel=channels[0],zslice=zslices[0])
+
+print(in_tif)
+print(f"gcs://{in_tif}")
+
+with fs.open(f"gcs://{in_tif}", 'rb') as tif_file:
+    raw_tif = tifffile.TiffFile(tif_file,mode='r')
+
+    axes = raw_tif.series[0].get_axes()
+    shape = raw_tif.series[0].get_shape()
+    print(axes)
+
+    ome_dict = xmltodict.parse(raw_tif.ome_metadata)
+
+
+physical_size_x = ome_dict['OME']['Image']['Pixels']['@PhysicalSizeX']
+physical_size_y = ome_dict['OME']['Image']['Pixels']['@PhysicalSizeY']
+physical_size_z = ome_dict['OME']['Image']['Pixels']['@PhysicalSizeZ']
+custom_metadata = ome_dict['OME']['Image']['ca:CustomAttributes']
+
+
+#read tile configuration from the microscope metadata
+if axes == 'CZYX':
+    tile_config_pattern=r"Blaze\[(?P<tilex>[0-9]+) x (?P<tiley>[0-9]+)\]_C(?P<channel>[0-9]+)_xyz-Table Z(?P<zslice>[0-9]+).ome.tif;;\((?P<x>\S+), (?P<y>\S+),(?P<chan>\S+), (?P<z>\S+)\)"
+elif axes == 'ZYX':
+    tile_config_pattern=r"Blaze\[(?P<tilex>[0-9]+) x (?P<tiley>[0-9]+)\]_C(?P<channel>[0-9]+)_xyz-Table Z(?P<zslice>[0-9]+).ome.tif;;\((?P<x>\S+), (?P<y>\S+), (?P<z>\S+)\)"
+
+tile_pattern = re.compile(tile_config_pattern)
+
+#put it in 3 maps, one for each coord, indexed by tilex, tiley, channel, and zslice
+map_x=dict()
+map_y=dict()
+map_z=dict()
+
+map_tiles_to_chunk=dict()
+chunks = []
+for chunk,(tilex,tiley) in enumerate(product(tiles_x,tiles_y)):
+    map_tiles_to_chunk[tilex+tiley] = chunk
+    chunks.append(chunk)
+
+for line in custom_metadata['TileConfiguration']['@TileConfiguration'].split(' ')[1:]:
+
+    d = re.search(tile_pattern,line).groupdict()
+    chunk = map_tiles_to_chunk[d['tilex']+d['tiley']] # want the key to have chunk instead of tilex,tiley, so map to that first
+
+    #key is: tile-{chunk}_chan-{channel}_z-{zslice}
+    key = f"tile-{chunk}_chan-{d['channel']}_z-{d['zslice']}"
+
+    map_x[key] = float(d['x'])
+    map_y[key] = float(d['y'])
+    map_z[key] = float(d['z'])
+
+
+metadata={}
+metadata['tiles_x'] = tiles_x
+metadata['tiles_y'] = tiles_y
+metadata['channels'] = channels
+metadata['zslices'] = zslices
+metadata['prefixes'] = prefixes
+metadata['chunks'] = chunks
+metadata['axes'] = axes
+metadata['shape'] = shape
+metadata['physical_size_x'] = float(physical_size_x)
+metadata['physical_size_y'] = float(physical_size_y)
+metadata['physical_size_z'] = float(physical_size_z)
+metadata['lookup_tile_offset_x'] = map_x
+metadata['lookup_tile_offset_y'] = map_y
+metadata['lookup_tile_offset_z'] = map_z
+metadata['ome_full_metadata'] = ome_dict
+metadata['PixelSize'] = [ metadata['physical_size_z']/1000.0, metadata['physical_size_y']/1000.0, metadata['physical_size_x']/1000.0 ] #zyx since OME-Zarr is ZYX
+metadata['PixelSizeUnits'] = 'mm'
+
+#write metadata to json
+with open(snakemake.output.metadata_json, 'w') as fp:
+    json.dump(metadata, fp,indent=4)
diff --git a/workflow/scripts/tif_to_zarr_gcs.py b/workflow/scripts/tif_to_zarr_gcs.py
new file mode 100644
index 0000000..01a9af1
--- /dev/null
+++ b/workflow/scripts/tif_to_zarr_gcs.py
@@ -0,0 +1,87 @@
+import tifffile
+import json
+import dask
+import dask.array as da
+from itertools import product
+from dask.diagnostics import ProgressBar
+import gcsfs
+
+gcsfs_opts={'project': snakemake.params.storage_provider_settings['gcs'].get_settings().project,
+            'token': snakemake.input.creds}
+fs = gcsfs.GCSFileSystem(**gcsfs_opts)
+
+
+def replace_square_brackets(pattern):
+    """replace all [ and ] in the string (have to use
+    intermediate variable to avoid conflicts)"""
+    pattern = pattern.replace('[','##LEFTBRACKET##')
+    pattern = pattern.replace(']','##RIGHTBRACKET##')
+    pattern = pattern.replace('##LEFTBRACKET##','[[]')
+    pattern = pattern.replace('##RIGHTBRACKET##','[]]')
+    return pattern
+
+def read_tiff_slice(fs, gcs_uri, key=0):
+    """Read a single TIFF slice from GCS."""
+    with fs.open(gcs_uri, 'rb') as file:
+        return tifffile.imread(file, key=key)
+
+def build_zstack(gcs_uris, fs):
+    """Build a z-stack from a list of GCS URIs."""
+    lazy_arrays = [
+        dask.delayed(read_tiff_slice)(fs, uri) for uri in gcs_uris
+    ]
+    sample_array = read_tiff_slice(fs, gcs_uris[0])  # Read a sample to get shape and dtype
+    shape = (len(gcs_uris),) + sample_array.shape
+    dtype = sample_array.dtype
+
+    # Convert the list of delayed objects into a Dask array
+    return da.stack([da.from_delayed(lazy_array, shape=sample_array.shape, dtype=dtype) for lazy_array in lazy_arrays], axis=0)
+
+
+
+#use tif pattern but replace the [ and ] with [[] and []] so glob doesn't choke
+in_tif_glob = replace_square_brackets(str(snakemake.params.in_tif_pattern))
+
+
+#read metadata json
+with open(snakemake.input.metadata_json) as fp:
+    metadata = json.load(fp)
+
+#TODO: put these in top-level metadata for easier access..
+size_x=metadata['ome_full_metadata']['OME']['Image']['Pixels']['@SizeX']
+size_y=metadata['ome_full_metadata']['OME']['Image']['Pixels']['@SizeY']
+size_z=metadata['ome_full_metadata']['OME']['Image']['Pixels']['@SizeZ']
+size_c=metadata['ome_full_metadata']['OME']['Image']['Pixels']['@SizeC']
+size_tiles=len(metadata['tiles_x'])*len(metadata['tiles_y'])
+
+
+#build a lazy zstack for each tile and channel
+tiles=[]
+for i_tile,(tilex,tiley) in enumerate(product(metadata['tiles_x'],metadata['tiles_y'])):
+
+    zstacks=[]
+    for i_chan,channel in enumerate(metadata['channels']):
+
+        zstacks.append(build_zstack(fs.glob('gcs://'+in_tif_glob.format(tilex=tilex,tiley=tiley,prefix=metadata['prefixes'][0],channel=channel,zslice='*')),fs=fs))
+
+    #have list of zstack dask arrays for the tile, one for each channel
+    #stack them up and append to list of tiles
+    tiles.append(da.stack(zstacks))
+
+
+#now we have list of tiles, each a dask array
+#stack them up to get our final array
+darr = da.stack(tiles)
+
+#rescale intensities, and recast
+darr = darr * snakemake.params.intensity_rescaling
+darr = darr.astype('uint16')
+
+#now we can do the computation itself, storing to zarr
+print('writing images to zarr with dask')
+with ProgressBar():
+    da.to_zarr(darr,snakemake.output.zarr,overwrite=True,dimension_separator='/')
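
Reviewer note (not part of the patch): the core trick in tif_to_zarr_gcs.py is wrapping each per-slice GCS read in dask.delayed, so nothing is downloaded until the final to_zarr computation. Below is a minimal, self-contained sketch of that pattern under stated assumptions: the project name, token file, and bucket glob are hypothetical placeholders, and only gcsfs, tifffile, and dask are required.

    import dask
    import dask.array as da
    import gcsfs
    import tifffile

    # hypothetical project and credentials -- substitute your own
    fs = gcsfs.GCSFileSystem(project="my-project", token="creds.json")

    def read_tiff_slice(fs, gcs_uri, key=0):
        """Read a single 2D TIFF page from GCS."""
        with fs.open(gcs_uri, "rb") as f:
            return tifffile.imread(f, key=key)

    # hypothetical glob for one tile/channel, with the z-slice wildcarded
    uris = fs.glob("my-bucket/dataset/*_C00_*Z*.ome.tif")

    sample = read_tiff_slice(fs, uris[0])  # one eager read to learn shape/dtype
    zstack = da.stack(
        [
            da.from_delayed(
                dask.delayed(read_tiff_slice)(fs, uri),
                shape=sample.shape,
                dtype=sample.dtype,
            )
            for uri in uris
        ],
        axis=0,
    )
    print(zstack)  # lazy (Z, Y, X) array; slices are fetched only on .compute()

The single eager read is what fixes the shape and dtype for the lazy task graph; build_zstack in the patch makes the same trade, paying one synchronous GET per tile/channel to keep the rest of the pipeline fully lazy.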