Add deepcell-users auth layer for weights.
rossbar committed Apr 5, 2024
1 parent 7c88e67 commit 110e01e
Showing 2 changed files with 184 additions and 27 deletions.
173 changes: 173 additions & 0 deletions cellSAM/_auth.py
@@ -0,0 +1,173 @@
"""User interface to authentication layer for data/models."""

import os
import requests
from pathlib import Path
from hashlib import md5
from tqdm import tqdm
import logging
import tarfile
import zipfile


_api_endpoint = "https://users.deepcell.org/api/getData/"
_asset_location = Path.home() / ".deepcell"


def fetch_data(asset_key: str, cache_subdir=None, file_hash=None):
"""Fetch assets through users.deepcell.org authentication system.
Download assets from the deepcell suite of datasets and models which
require user-authentication.
.. note::
You must have a Deepcell Access Token set as an environment variable
with the name ``DEEPCELL_ACCESS_TOKEN`` in order to access assets.
Access tokens can be created at <https://users.deepcell.org>_
Args:
:param asset_key: Key of the file to download.
The list of available assets can be found on the users.deepcell.org
homepage.
:param cache_subdir: `str` indicating directory relative to
`~/.deepcell` where downloaded data will be cached. The default is
`None`, which means cache the data in `~/.deepcell`.
:param file_hash: `str` represented the md5 checksum of datafile. The
checksum is used to perform data caching. If no checksum is provided or
the checksum differs from that found in the data cache, the data will
be (re)-downloaded.
"""
logging.basicConfig(level=logging.INFO)

download_location = _asset_location
if cache_subdir is not None:
download_location /= cache_subdir
download_location.mkdir(exist_ok=True, parents=True)

# Extract the filename from the asset_key, which can be a full path
fname = os.path.split(asset_key)[-1]
fpath = download_location / fname

# Check for cached data
if file_hash is not None:
logging.info('Checking for cached data')
try:
with open(fpath, "rb") as fh:
hasher = md5(fh.read())
logging.info(f"Checking {fname} against provided file_hash...")
md5sum = hasher.hexdigest()
if md5sum == file_hash:
logging.info(
f"{fname} with hash {file_hash} already available."
)
return fpath
logging.info(
f"{fname} with hash {file_hash} not found in {download_location}"
)
except FileNotFoundError:
pass

# Check for access token
access_token = os.environ.get("DEEPCELL_ACCESS_TOKEN")
if access_token is None:
raise ValueError(
"\nDEEPCELL_ACCESS_TOKEN not found.\n"
"Please set your access token to the DEEPCELL_ACCESS_TOKEN\n"
"environment variable.\n"
"For example:\n\n"
"\texport DEEPCELL_ACCESS_TOKEN=<your-token>.\n\n"
"If you don't yet have a token, you can create one at\n"
"https://users.deepcell.org"
)

# Request download URL
headers = {"X-Api-Key": access_token}
logging.info("Making request to server")
resp = requests.post(
_api_endpoint, headers=headers, data={"s3_key": asset_key}
)
# Raise informative exception for the specific case when the asset_key is
# not found in the bucket
if resp.status_code == 404 and resp.json().get("error") == "Key not found":
raise ValueError(f"Object {asset_key} not found.")
# Raise informative exception for the specific case when an invalid
# API token is provided.
if resp.status_code == 403 and (
resp.json().get("detail") == "Authentication credentials were not provided."
):
raise ValueError(
f"\n\nAPI token {access_token} is not valid.\n"
"The token may be expired - if so, create a new one at\n"
"https://users.deepcell.org"
)
# Handle all other non-http-200 status
resp.raise_for_status()

# Parse response
response_data = resp.json()
download_url = response_data["url"]
file_size = response_data["size"]
# Parse file_size (TODO: would be more convenient if it were numerical, i.e. always bytes)
val, suff = file_size.split(" ")
# TODO: Case statement would be awesome here, but need to support all the
# way back to Python 3.8
suffix_mapping = {"KB": 1e3, "MB": 1e6, "B": 1, "GB": 1e9}
file_size_numerical = int(float(val) * suffix_mapping[suff])

logging.info(
f"Downloading {asset_key} with size {file_size} to {download_location}"
)
data_req = requests.get(
download_url, headers={"user-agent": "Wget/1.20 (linux-gnu)"}, stream=True
)
data_req.raise_for_status()

chunk_size = 4096
with tqdm.wrapattr(
open(fpath, "wb"), "write", miniters=1, total=file_size_numerical
) as fh:
for chunk in data_req.iter_content(chunk_size=chunk_size):
fh.write(chunk)

logging.info(f"🎉 Successfully downloaded file to {fpath}")

return fpath


def extract_archive(file_path, path="."):
"""Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats.
Args:
file_path: Path to the archive file.
path: Where to extract the archive file.
Returns:
True if a match was found and an archive extraction was completed,
False otherwise.
"""
logging.basicConfig(level=logging.INFO)

file_path = os.fspath(file_path) if isinstance(file_path, os.PathLike) else file_path
path = os.fspath(path) if isinstance(path, os.PathLike) else path

logging.info(f'Extracting {file_path}')

status = False

if tarfile.is_tarfile(file_path):
with tarfile.open(file_path) as archive:
archive.extractall(path)
status = True
elif zipfile.is_zipfile(file_path):
with zipfile.ZipFile(file_path) as archive:
archive.extractall(path)
status = True

if status:
logging.info(f'Successfully extracted {file_path} into {path}')
else:
logging.info(f'Failed to extract {file_path} into {path}')
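A minimal usage sketch of the new auth layer; the asset key and md5 checksum below are hypothetical placeholders, not real assets:

import os

from cellSAM._auth import fetch_data, extract_archive

# A personal access token from https://users.deepcell.org must be set first.
os.environ["DEEPCELL_ACCESS_TOKEN"] = "<your-token>"

# fetch_data caches under ~/.deepcell/<cache_subdir> and returns the local path.
archive_path = fetch_data(
    "models/example_model.tar.gz",  # hypothetical asset key
    cache_subdir="models",
    file_hash="d41d8cd98f00b204e9800998ecf8427e",  # hypothetical md5
)
extract_archive(archive_path, archive_path.parent)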
38 changes: 11 additions & 27 deletions cellSAM/model.py
@@ -1,9 +1,8 @@
 import torch
 import torch.nn as nn
 import numpy as np
-from tqdm import tqdm
-
 import os
+from pathlib import Path
 import yaml
 from pkg_resources import resource_filename
 
@@ -24,44 +23,29 @@
     fill_holes_and_remove_small_masks,
     subtract_boundaries,
 )
-import requests
+from ._auth import fetch_data, extract_archive


-def download_file_with_progress(url, destination):
-    response = requests.get(url, stream=True)
-    total_size_in_bytes = int(response.headers.get('content-length', 0))
-    block_size = 1024  # 1 Kibibyte
-
-    progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
-
-    with open(destination, 'wb') as file:
-        for data in response.iter_content(block_size):
-            progress_bar.update(len(data))
-            file.write(data)
-    progress_bar.close()
-
-    if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
-        print("ERROR: Something went wrong")
+__all__ = ["segment_cellular_image"]


 def get_model(model: nn.Module = None) -> nn.Module:
     """
     Returns a loaded CellSAM model. If model is None, downloads weights and loads the model with a progress bar.
     """
-    cellsam_assets_dir = os.path.join(os.path.expanduser("~"), ".cellsam_assets")
-    model_path = os.path.join(cellsam_assets_dir, "cellsam_base.pt")
+    cellsam_assets_dir = Path.home() / ".deepcell/models"
+    model_path = cellsam_assets_dir / "cellsam_base.pt"
     config_path = resource_filename(__name__, 'modelconfig.yaml')
     with open(config_path, 'r') as config_file:
         config = yaml.safe_load(config_file)
 
     if model is None:
-        if not os.path.exists(cellsam_assets_dir):
-            os.makedirs(cellsam_assets_dir)
-        if not os.path.isfile(model_path):
-            print("Downloading CellSAM model weights, please wait...")
-            download_file_with_progress(
-                "https://storage.googleapis.com/cellsam-data/cellsam_base.pt",
-                model_path,
-            )
+        if not cellsam_assets_dir.exists():
+            cellsam_assets_dir.mkdir(parents=True, exist_ok=True)
+        if not model_path.exists():
+            archive_path = fetch_data("models/cellsam_base.tar.gz", cache_subdir="models")
+            extract_archive(archive_path, cellsam_assets_dir)
+        assert model_path.exists()
     model = CellSAM(config)
     model.load_state_dict(torch.load(model_path))
     return model