Move actual download step back to init functions of the models. Otherwise files are downloaded just on import of classes, which is undesirable.
bputzeys committed May 22, 2024
1 parent fcaf499 commit afc3a31
Showing 10 changed files with 60 additions and 39 deletions.
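The same pattern is applied to every model: the configuration classes used to call the downloader inside their __init__, so building a config (including the default_configurer instances used as default arguments, which are created when the model modules are imported) pulled files over the network; now the configs only record which files are needed and the model classes fetch them when they are instantiated. A minimal, self-contained sketch of the two styles, with illustrative names rather than the actual Helical API:

def download(name: str) -> None:
    # stand-in for Downloader.download_via_name(name)
    print(f"downloading {name}")

# Before: downloads happen as a side effect of building the config
class EagerConfig:
    def __init__(self):
        self.files = ["geneformer/token_dictionary.pkl"]
        for f in self.files:
            download(f)          # network I/O at config construction, e.g. on import

# After: the config only lists files; the model downloads them in its own __init__
class LazyConfig:
    def __init__(self):
        self.list_of_files_to_download = ["geneformer/token_dictionary.pkl"]

class Model:
    def __init__(self, config: LazyConfig):
        for f in config.list_of_files_to_download:
            download(f)          # deferred until a model object is actually created

Model(LazyConfig())              # only now does any downloading happen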
2 changes: 2 additions & 0 deletions helical/constants/paths.py
@@ -0,0 +1,2 @@
from pathlib import Path
CACHE_DIR_HELICAL = Path(Path.home(), '.cache', 'helical', 'models')
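For reference, the new shared constant resolves under the user's home directory (standard pathlib behaviour; the printed path below is just an example):

from pathlib import Path

CACHE_DIR_HELICAL = Path(Path.home(), '.cache', 'helical', 'models')
print(CACHE_DIR_HELICAL)   # e.g. /home/<user>/.cache/helical/models on Linux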
22 changes: 11 additions & 11 deletions helical/models/geneformer/geneformer_config.py
@@ -1,7 +1,6 @@
from typing import Optional
from helical.services.downloader import Downloader
from pathlib import Path
import os
from helical.constants.paths import CACHE_DIR_HELICAL

class GeneformerConfig():
"""Configuration class to use the Geneformer Model.
@@ -35,16 +34,17 @@ def __init__(
):

model_name = "geneformer-12L-30M"

downloader = Downloader()
downloader.download_via_name("geneformer/gene_median_dictionary.pkl")
downloader.download_via_name("geneformer/human_gene_to_ensemble_id.pkl")
downloader.download_via_name("geneformer/token_dictionary.pkl")
downloader.download_via_name(f"geneformer/{model_name}/config.json")
downloader.download_via_name(f"geneformer/{model_name}/pytorch_model.bin")
downloader.download_via_name(f"geneformer/{model_name}/training_args.bin")

self.model_dir = Path(os.path.join(downloader.CACHE_DIR_HELICAL, 'geneformer'))
self.list_of_files_to_download = [
"geneformer/gene_median_dictionary.pkl",
"geneformer/human_gene_to_ensemble_id.pkl",
"geneformer/token_dictionary.pkl",
f"geneformer/{model_name}/config.json",
f"geneformer/{model_name}/pytorch_model.bin",
f"geneformer/{model_name}/training_args.bin",
]

self.model_dir = Path(CACHE_DIR_HELICAL, 'geneformer')
self.model_name = model_name
self.batch_size = batch_size
self.emb_layer = emb_layer
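With the downloads removed, constructing the Geneformer config no longer touches the network; it only exposes the model directory and the file list for the model class to act on later. A hedged usage sketch, with attribute names taken from the diff above and constructor arguments assumed to have defaults:

from helical.models.geneformer.geneformer_config import GeneformerConfig

config = GeneformerConfig()                 # no downloads triggered here anymore
print(config.model_dir)                     # ~/.cache/helical/models/geneformer
print(config.list_of_files_to_download)     # blob-relative paths, fetched later by the model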
6 changes: 5 additions & 1 deletion helical/models/geneformer/model.py
@@ -3,7 +3,7 @@
from pathlib import Path
import numpy as np
from anndata import AnnData
import os
from helical.services.downloader import Downloader
import pickle
from transformers import BertForMaskedLM
from helical.models.geneformer.geneformer_utils import get_embs,quant_layers
@@ -53,6 +53,10 @@ def __init__(self, configurer: GeneformerConfig = default_configurer) -> None:
self.log = logging.getLogger("Geneformer-Model")
self.device = self.config.device

downloader = Downloader()
for file in self.config.list_of_files_to_download:
downloader.download_via_name(file)

self.model = BertForMaskedLM.from_pretrained(self.config.model_dir / self.config.model_name, output_hidden_states=True, output_attentions=False)
self.model.eval()#.to("cuda:0")
self.model = self.model.to(self.device)
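On the model side, the files listed by the config are fetched in __init__ right before the weights are loaded. A sketch of the resulting flow, assuming the model class defined in helical/models/geneformer/model.py is exported as Geneformer (the class name is not visible in this diff):

from helical.models.geneformer.geneformer_config import GeneformerConfig
from helical.models.geneformer.model import Geneformer

config = GeneformerConfig()             # cheap: only records paths and hyperparameters
model = Geneformer(configurer=config)   # downloads happen here, then BertForMaskedLM is loaded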
9 changes: 4 additions & 5 deletions helical/models/hyena_dna/hyena_dna_config.py
@@ -1,7 +1,6 @@
from typing import Literal
from pathlib import Path
from helical.services.downloader import Downloader
import os
from helical.constants.paths import CACHE_DIR_HELICAL
class HyenaDNAConfig():
"""
Configuration class for Hyena DNA model.
@@ -83,11 +82,11 @@ def __init__(
if model_name not in self.model_map:
raise ValueError(f"Model name {model_name} not found in available models: {self.model_map.keys()}")

downloader = Downloader()
downloader.download_via_name(f"hyena_dna/{model_name}.ckpt")
list_of_files_to_download = [f"hyena_dna/{model_name}.ckpt"]

self.config = {
"model_path": Path(os.path.join(downloader.CACHE_DIR_HELICAL, f"hyena_dna/{model_name}.ckpt")),
"model_path": Path(CACHE_DIR_HELICAL, f"hyena_dna/{model_name}.ckpt"),
"list_of_files_to_download": list_of_files_to_download,
"d_model": self.model_map[model_name]['d_model'],
"n_layer": n_layer,
"d_inner": self.model_map[model_name]['d_inner'],
7 changes: 6 additions & 1 deletion helical/models/hyena_dna/model.py
@@ -5,6 +5,7 @@
from helical.models.hyena_dna.pretrained_model import HyenaDNAPreTrainedModel
import torch
from .standalone_hyenadna import CharacterTokenizer
from helical.services.downloader import Downloader

class HyenaDNA(HelicalBaseModel):
"""HyenaDNA model."""
@@ -14,7 +15,11 @@ def __init__(self, configurer: HyenaDNAConfig = default_configurer) -> None:
super().__init__()
self.config = configurer.config
self.log = logging.getLogger("Hyena-DNA-Model")


downloader = Downloader()
for file in self.config["list_of_files_to_download"]:
downloader.download_via_name(file)

self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

self.model = HyenaDNAPreTrainedModel().from_pretrained(self.config)
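HyenaDNA, scGPT and UCE use dictionary-based configs, so the same pattern reads the file list through a key instead of an attribute. A minimal sketch using the real Downloader from the diff; the checkpoint name is illustrative, not taken from this commit:

from helical.services.downloader import Downloader

config = {"list_of_files_to_download": ["hyena_dna/hyenadna-tiny-1k-seqlen.ckpt"]}

downloader = Downloader()
for file in config["list_of_files_to_download"]:
    downloader.download_via_name(file)   # runs only when the model object is being built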
5 changes: 5 additions & 0 deletions helical/models/scgpt/model.py
@@ -8,6 +8,7 @@
from typing import Literal
from accelerate import Accelerator
from helical.models.scgpt.scgpt_utils import load_model, get_embedding
from helical.services.downloader import Downloader

os.environ['KMP_DUPLICATE_LIB_OK']='True'

@@ -49,6 +50,10 @@ def __init__(self, configurer: scGPTConfig = configurer) -> None:
self.config = configurer.config
self.log = logging.getLogger("scGPT-Model")

downloader = Downloader()
for file in self.config["list_of_files_to_download"]:
downloader.download_via_name(file)

self.model, self.vocab = load_model(self.config)

if self.config["accelerator"]:
14 changes: 8 additions & 6 deletions helical/models/scgpt/scgpt_config.py
@@ -1,7 +1,7 @@
from typing import Optional
from helical.services.downloader import Downloader
from helical.constants.paths import CACHE_DIR_HELICAL
from pathlib import Path
import os

class scGPTConfig():
"""
Configuration class to use the scGPT Model.
@@ -70,13 +70,15 @@ def __init__(
):

model_name = 'best_model' # TODO: Include more models
downloader = Downloader()
downloader.download_via_name("scgpt/scGPT_CP/vocab.json")
downloader.download_via_name(f"scgpt/scGPT_CP/{model_name}.pt")
model_path = Path(os.path.join(downloader.CACHE_DIR_HELICAL, 'scgpt/scGPT_CP', f'{model_name}.pt'))
list_of_files_to_download = [
"scgpt/scGPT_CP/vocab.json",
f"scgpt/scGPT_CP/{model_name}.pt",
]
model_path = Path(CACHE_DIR_HELICAL, 'scgpt/scGPT_CP', f'{model_name}.pt')

self.config = {
"model_path": model_path,
"list_of_files_to_download": list_of_files_to_download,
"pad_token": pad_token,
"batch_size": batch_size,
"fast_transformer": fast_transformer,
5 changes: 5 additions & 0 deletions helical/models/uce/model.py
@@ -6,6 +6,7 @@
from helical.models.helical import HelicalBaseModel
from helical.models.uce.uce_utils import get_ESM2_embeddings, load_model, process_data, get_gene_embeddings
from accelerate import Accelerator
from helical.services.downloader import Downloader

class UCE(HelicalBaseModel):
"""Universal Cell Embedding Model. This model reads in single-cell RNA-seq data and outputs gene embeddings.
@@ -42,6 +43,10 @@ def __init__(self, configurer: UCEConfig = default_configurer) -> None:
self.config = configurer.config
self.log = logging.getLogger("UCE-Model")

downloader = Downloader()
for file in self.config["list_of_files_to_download"]:
downloader.download_via_name(file)

self.model_dir = self.config['model_path'].parent

self.embeddings = get_ESM2_embeddings(self.config["token_file_path"], self.config["token_dim"])
21 changes: 11 additions & 10 deletions helical/models/uce/uce_config.py
@@ -1,7 +1,6 @@
from typing import Optional
from typing import Literal
from helical.services.downloader import Downloader
import os
from helical.constants.paths import CACHE_DIR_HELICAL
from pathlib import Path
class UCEConfig():
"""Configuration class to use the Universal Cell-Embedding Model.
@@ -77,17 +76,19 @@ def __init__(self,
if model_name not in self.model_map:
raise ValueError(f"Model name {model_name} not found in available models: {self.model_map.keys()}.")

downloader = Downloader()
downloader.download_via_name("uce/all_tokens.torch")
downloader.download_via_name(f"uce/{model_name}.torch")
downloader.download_via_name("uce/species_chrom.csv")
downloader.download_via_name("uce/species_offsets.pkl")
downloader.download_via_name("uce/protein_embeddings/Homo_sapiens.GRCh38.gene_symbol_to_embedding_ESM2.pt")
downloader.download_via_name("uce/protein_embeddings/Macaca_fascicularis.Macaca_fascicularis_6.0.gene_symbol_to_embedding_ESM2.pt")
model_path = Path(os.path.join(downloader.CACHE_DIR_HELICAL, 'uce', f"{model_name}.torch"))
list_of_files_to_download = [
"uce/all_tokens.torch",
f"uce/{model_name}.torch",
"uce/species_chrom.csv",
"uce/species_offsets.pkl",
"uce/protein_embeddings/Homo_sapiens.GRCh38.gene_symbol_to_embedding_ESM2.pt",
"uce/protein_embeddings/Macaca_fascicularis.Macaca_fascicularis_6.0.gene_symbol_to_embedding_ESM2.pt",
]
model_path = Path(CACHE_DIR_HELICAL, 'uce', f"{model_name}.torch")

self.config = {
"model_path": model_path,
"list_of_files_to_download": list_of_files_to_download,
"batch_size": batch_size,
"pad_length": pad_length,
"pad_token_idx": pad_token_idx,
8 changes: 3 additions & 5 deletions helical/services/downloader.py
@@ -12,6 +12,7 @@
from azure.storage.blob import BlobClient
from azure.core.pipeline.transport import RequestsTransport
from git import Repo
from helical.constants.paths import CACHE_DIR_HELICAL

INTERVAL = 1000 # interval to get gene mappings
CHUNK_SIZE = 1024 * 1024 * 10 #8192 # size of individual chunks to download
@@ -20,7 +21,6 @@ class Downloader(Logger):
def __init__(self, loging_type = LoggingType.CONSOLE, level = LoggingLevel.INFO) -> None:
super().__init__(loging_type, level)
self.log = logging.getLogger("Downloader")
self.CACHE_DIR_HELICAL = os.path.join(str(Path.home()),'.cache/helical/models')
self.display = True

# manually create a requests session
@@ -145,8 +145,7 @@ def download_via_name_v0(self, name: str) -> None:
link: URL to download the file from.
'''
main_link = "https://helicalpackage.blob.core.windows.net/helicalpackage/data"
CACHE_DIR_HELICAL = Path(self.CACHE_DIR_HELICAL)
output = os.path.join(CACHE_DIR_HELICAL,name)
output = os.path.join(CACHE_DIR_HELICAL, name)

link = f"{main_link}/{name}"
if not os.path.exists(os.path.dirname(output)):
@@ -192,8 +191,7 @@ def download_via_name(self, name: str) -> None:
'''

main_link = "https://helicalpackage.blob.core.windows.net/helicalpackage/data"
CACHE_DIR_HELICAL = Path(self.CACHE_DIR_HELICAL)
output = os.path.join(CACHE_DIR_HELICAL,name)
output = os.path.join(CACHE_DIR_HELICAL, name)

blob_url = f"{main_link}/{name}"

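Since CACHE_DIR_HELICAL is now a module-level constant rather than an instance attribute, the target path of a named blob can be derived without constructing a Downloader at all; download_via_name keeps joining it with the blob-relative name. An illustrative check:

import os
from helical.constants.paths import CACHE_DIR_HELICAL

name = "geneformer/token_dictionary.pkl"
output = os.path.join(CACHE_DIR_HELICAL, name)
print(output)   # <home>/.cache/helical/models/geneformer/token_dictionary.pkl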
