From 110e01e7a530e45edad84135426d84c29ad34cab Mon Sep 17 00:00:00 2001
From: Ross Barnowski
Date: Mon, 1 Apr 2024 23:05:49 -0700
Subject: [PATCH] Add deepcell-users auth layer for weights.

---
 cellSAM/_auth.py | 228 +++++++++++++++++++++++++++++++++++++++++++++++
 cellSAM/model.py |  44 ++++-------
 2 files changed, 245 insertions(+), 27 deletions(-)
 create mode 100644 cellSAM/_auth.py

diff --git a/cellSAM/_auth.py b/cellSAM/_auth.py
new file mode 100644
index 0000000..e4e9b5c
--- /dev/null
+++ b/cellSAM/_auth.py
@@ -0,0 +1,228 @@
+"""User interface to authentication layer for data/models."""
+
+import os
+import requests
+from pathlib import Path
+from hashlib import md5
+from tqdm import tqdm
+import logging
+import tarfile
+import zipfile
+
+
+_api_endpoint = "https://users.deepcell.org/api/getData/"
+_asset_location = Path.home() / ".deepcell"
+
+# Multipliers for the human-readable sizes ("12.3 MB") sent by the server.
+_size_suffixes = {"B": 1, "KB": 1e3, "MB": 1e6, "GB": 1e9}
+
+
+def _md5sum(fpath, chunk_size=2 ** 20):
+    """Return the md5 hex digest of the file at ``fpath``.
+
+    The file is hashed in ``chunk_size`` pieces so that large model weights
+    are never held in memory all at once.
+    """
+    hasher = md5()
+    with open(fpath, "rb") as fh:
+        for chunk in iter(lambda: fh.read(chunk_size), b""):
+            hasher.update(chunk)
+    return hasher.hexdigest()
+
+
+def _parse_file_size(file_size):
+    """Convert a size string such as ``"12.3 MB"`` to a number of bytes.
+
+    Returns ``None`` for anything not of that form: the size only feeds the
+    progress bar, so an unexpected value must not abort the download.
+    """
+    try:
+        val, suff = str(file_size).split(" ")
+        return int(float(val) * _size_suffixes[suff])
+    except (ValueError, KeyError):
+        return None
+
+
+def _json_or_empty(resp):
+    """Return the JSON object carried by ``resp``, or ``{}`` if there is none."""
+    try:
+        payload = resp.json()
+    except ValueError:
+        return {}
+    return payload if isinstance(payload, dict) else {}
+
+
+def fetch_data(asset_key: str, cache_subdir=None, file_hash=None):
+    """Fetch assets through users.deepcell.org authentication system.
+
+    Download assets from the deepcell suite of datasets and models which
+    require user-authentication.
+
+    .. note::
+
+       You must have a Deepcell Access Token set as an environment variable
+       with the name ``DEEPCELL_ACCESS_TOKEN`` in order to access assets.
+
+       Access tokens can be created at https://users.deepcell.org
+
+    Args:
+        asset_key: Key of the file to download. The list of available
+            assets can be found on the users.deepcell.org homepage.
+        cache_subdir: `str` indicating directory relative to `~/.deepcell`
+            where downloaded data will be cached. The default is `None`,
+            which means cache the data in `~/.deepcell`.
+        file_hash: `str` representing the md5 checksum of the datafile. The
+            checksum is used to perform data caching. If no checksum is
+            provided or the checksum differs from that found in the data
+            cache, the data will be (re)-downloaded.
+
+    Returns:
+        `pathlib.Path` to the cached or freshly downloaded file.
+
+    Raises:
+        ValueError: If ``DEEPCELL_ACCESS_TOKEN`` is unset, the server rejects
+            the token, or ``asset_key`` does not exist.
+        requests.HTTPError: For any other unsuccessful server response.
+    """
+    logging.basicConfig(level=logging.INFO)
+
+    download_location = _asset_location
+    if cache_subdir is not None:
+        download_location /= cache_subdir
+    download_location.mkdir(exist_ok=True, parents=True)
+
+    # Extract the filename from the asset_key, which can be a full path
+    fname = os.path.split(asset_key)[-1]
+    fpath = download_location / fname
+
+    # Check for cached data
+    if file_hash is not None:
+        logging.info('Checking for cached data')
+        try:
+            md5sum = _md5sum(fpath)
+        except FileNotFoundError:
+            pass
+        else:
+            logging.info(f"Checking {fname} against provided file_hash...")
+            if md5sum == file_hash:
+                logging.info(
+                    f"{fname} with hash {file_hash} already available."
+                )
+                return fpath
+            logging.info(
+                f"{fname} with hash {file_hash} not found in {download_location}"
+            )
+
+    # Check for access token
+    access_token = os.environ.get("DEEPCELL_ACCESS_TOKEN")
+    if access_token is None:
+        raise ValueError(
+            "\nDEEPCELL_ACCESS_TOKEN not found.\n"
+            "Please set your access token to the DEEPCELL_ACCESS_TOKEN\n"
+            "environment variable.\n"
+            "For example:\n\n"
+            "\texport DEEPCELL_ACCESS_TOKEN=YOUR_TOKEN\n\n"
+            "If you don't yet have a token, you can create one at\n"
+            "https://users.deepcell.org"
+        )
+
+    # Request download URL
+    headers = {"X-Api-Key": access_token}
+    logging.info("Making request to server")
+    resp = requests.post(
+        _api_endpoint, headers=headers, data={"s3_key": asset_key}
+    )
+    if not resp.ok:
+        # Error bodies are not guaranteed to be JSON (e.g. a gateway error
+        # page), so parse defensively and let raise_for_status() below report
+        # anything that is not one of the two known cases.
+        error_info = _json_or_empty(resp)
+        # Raise informative exception for the specific case when the
+        # asset_key is not found in the bucket
+        if resp.status_code == 404 and error_info.get("error") == "Key not found":
+            raise ValueError(f"Object {asset_key} not found.")
+        # Raise informative exception for the specific case when an invalid
+        # API token is provided. The token itself is deliberately left out of
+        # the message so that it cannot leak into logs or bug reports.
+        if resp.status_code == 403 and (
+            error_info.get("detail")
+            == "Authentication credentials were not provided."
+        ):
+            raise ValueError(
+                "\n\nThe API token in DEEPCELL_ACCESS_TOKEN is not valid.\n"
+                "The token may be expired - if so, create a new one at\n"
+                "https://users.deepcell.org"
+            )
+    # Handle all other non-http-200 status
+    resp.raise_for_status()
+
+    # Parse response
+    response_data = resp.json()
+    download_url = response_data["url"]
+    file_size = response_data["size"]
+    # The server reports a human-readable size; None simply leaves the
+    # progress bar without a total.
+    file_size_numerical = _parse_file_size(file_size)
+
+    logging.info(
+        f"Downloading {asset_key} with size {file_size} to {download_location}"
+    )
+    chunk_size = 4096
+    # Context-manage both the streamed response and the output file so the
+    # connection is released and the data is flushed to disk before returning.
+    with requests.get(
+        download_url, headers={"user-agent": "Wget/1.20 (linux-gnu)"}, stream=True
+    ) as data_req:
+        data_req.raise_for_status()
+        with open(fpath, "wb") as raw_fh:
+            with tqdm.wrapattr(
+                raw_fh, "write", miniters=1, total=file_size_numerical
+            ) as fh:
+                for chunk in data_req.iter_content(chunk_size=chunk_size):
+                    fh.write(chunk)
+
+    logging.info(f"🎉 Successfully downloaded file to {fpath}")
+
+    return fpath
+
+
+def extract_archive(file_path, path="."):
+    """Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats.
+
+    Args:
+        file_path: Path to the archive file.
+        path: Where to extract the archive file.
+
+    Returns:
+        True if a match was found and an archive extraction was completed,
+        False otherwise.
+    """
+    logging.basicConfig(level=logging.INFO)
+
+    file_path = os.fspath(file_path) if isinstance(file_path, os.PathLike) else file_path
+    path = os.fspath(path) if isinstance(path, os.PathLike) else path
+
+    logging.info(f'Extracting {file_path}')
+
+    status = False
+
+    if tarfile.is_tarfile(file_path):
+        with tarfile.open(file_path) as archive:
+            # Where the interpreter provides extraction filters (PEP 706),
+            # reject members that would be written outside ``path``.
+            if hasattr(tarfile, "data_filter"):
+                archive.extractall(path, filter="data")
+            else:
+                archive.extractall(path)
+        status = True
+    elif zipfile.is_zipfile(file_path):
+        with zipfile.ZipFile(file_path) as archive:
+            archive.extractall(path)
+        status = True
+
+    if status:
+        logging.info(f'Successfully extracted {file_path} into {path}')
+    else:
+        logging.info(f'Failed to extract {file_path} into {path}')
+
+    return status
diff --git a/cellSAM/model.py b/cellSAM/model.py
index 499fea4..370758b 100644
--- a/cellSAM/model.py
+++ b/cellSAM/model.py
@@ -1,9 +1,8 @@
 import torch
 import torch.nn as nn
 import numpy as np
-from tqdm import tqdm
 
-import os
+from pathlib import Path
 import yaml
 from pkg_resources import resource_filename
 
@@ -24,44 +23,35 @@
     fill_holes_and_remove_small_masks,
     subtract_boundaries,
 )
-import requests
+from ._auth import fetch_data, extract_archive
 
 
-def download_file_with_progress(url, destination):
-    response = requests.get(url, stream=True)
-    total_size_in_bytes = int(response.headers.get('content-length', 0))
-    block_size = 1024  # 1 Kibibyte
+__all__ = ["segment_cellular_image"]
 
-    progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
-
-    with open(destination, 'wb') as file:
-        for data in response.iter_content(block_size):
-            progress_bar.update(len(data))
-            file.write(data)
-    progress_bar.close()
-
-    if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
-        print("ERROR: Something went wrong")
 
 def get_model(model: nn.Module = None) -> nn.Module:
     """
     Returns a loaded CellSAM model. If model is None, downloads weights and loads the model with a progress bar.
     """
-    cellsam_assets_dir = os.path.join(os.path.expanduser("~"), ".cellsam_assets")
-    model_path = os.path.join(cellsam_assets_dir, "cellsam_base.pt")
+    cellsam_assets_dir = Path.home() / ".deepcell/models"
+    model_path = cellsam_assets_dir / "cellsam_base.pt"
     config_path = resource_filename(__name__, 'modelconfig.yaml')
     with open(config_path, 'r') as config_file:
         config = yaml.safe_load(config_file)
 
     if model is None:
-        if not os.path.exists(cellsam_assets_dir):
-            os.makedirs(cellsam_assets_dir)
-        if not os.path.isfile(model_path):
-            print("Downloading CellSAM model weights, please wait...")
-            download_file_with_progress(
-                "https://storage.googleapis.com/cellsam-data/cellsam_base.pt",
-                model_path,
-            )
+        cellsam_assets_dir.mkdir(parents=True, exist_ok=True)
+        if not model_path.exists():
+            # fetch_data returns the path of the downloaded tarball; that is
+            # the file to unpack (model_path does not exist yet).
+            archive_path = fetch_data(
+                "models/cellsam_base.tar.gz", cache_subdir="models"
+            )
+            extract_archive(archive_path, cellsam_assets_dir)
+            if not model_path.exists():
+                raise FileNotFoundError(
+                    f"{model_path} not found after extracting {archive_path}"
+                )
         model = CellSAM(config)
         model.load_state_dict(torch.load(model_path))
     return model