Skip to content

Commit

Permalink
Prepare for HyenaDNA model
Browse files Browse the repository at this point in the history
  • Loading branch information
bputzeys committed May 20, 2024
1 parent 3e22994 commit a46da23
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 0 deletions.
4 changes: 4 additions & 0 deletions examples/run_hyena_dna.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Minimal smoke-test script: instantiate a HyenaDNA model from a named config.
from helical.models.hyena_dna.model import HyenaDNA, HyenaDNAConfig

hyena_config = HyenaDNAConfig(model_name="hyenadna-tiny-1k-seqlen-d256")
hyena_model = HyenaDNA(model_config=hyena_config)
print("Done")
2 changes: 2 additions & 0 deletions helical/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .models.uce.model import UCEConfig, UCE
from .models.geneformer.model import Geneformer,GeneformerConfig
from .models.scgpt.model import scGPT, scGPTConfig
from .models.hyena_dna.model import HyenaDNA, HyenaDNAConfig

61 changes: 61 additions & 0 deletions helical/models/hyena_dna/hyena_dna_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from typing import Literal, Optional
class HyenaDNAConfig():
    """Configuration container for the HyenaDNA model.

    Maps a published model name to its architecture-specific dimensions
    (``d_model``, ``d_inner``) and collects all hyper-parameters into a
    single ``self.config`` dict consumed by :class:`HyenaDNA`.

    Parameters
    ----------
    model_name:
        One of the supported checkpoint names; determines ``d_model`` and
        ``d_inner``.
    n_layer, vocab_size, resid_dropout, embed_dropout, fused_mlp,
    fused_dropout_add_ln, residual_in_fp32, pad_vocab_size_multiple,
    return_hidden_state:
        Passed through unchanged into ``self.config``.
    layer:
        Hyena operator settings. ``None`` (the default) selects the standard
        settings below. A fresh dict is built per call — a ``dict`` literal
        as a default argument would be shared across all instances.

    Raises
    ------
    ValueError
        If ``model_name`` is not a known model.
    """

    def __init__(
            self,
            model_name: Literal["hyenadna-tiny-1k-seqlen", "hyenadna-tiny-1k-seqlen-d256"] = "hyenadna-tiny-1k-seqlen",
            n_layer: int = 2,
            vocab_size: int = 12,
            resid_dropout: float = 0.0,
            embed_dropout: float = 0.1,
            fused_mlp: bool = False,
            fused_dropout_add_ln: bool = True,
            residual_in_fp32: bool = True,
            pad_vocab_size_multiple: int = 8,
            return_hidden_state: bool = True,
            layer: Optional[dict] = None,
            ):

        # Build the default layer settings per instance (avoids the
        # mutable-default-argument pitfall: callers mutating
        # config["layer"] must not affect other instances).
        if layer is None:
            layer = {
                "_name_": "hyena",
                "emb_dim": 5,
                "filter_order": 64,
                "local_order": 3,
                "l_max": 1026,
                "modulate": True,
                "w": 10,
                "lr": 6e-4,
                "wd": 0.0,
                "lr_pos_emb": 0.0,
            }

        # Model-specific architecture dimensions keyed by checkpoint name.
        self.model_map = {
            "hyenadna-tiny-1k-seqlen": {
                'd_model': 128,
                'd_inner': 512,
            },
            "hyenadna-tiny-1k-seqlen-d256": {
                'd_model': 256,
                'd_inner': 1024,
            }
        }

        if model_name not in self.model_map:
            raise ValueError(f"Model name {model_name} not found in available models: {self.model_map.keys()}")

        # Flat dict consumed by the HyenaDNA model wrapper.
        self.config = {
            "model_name": model_name,
            "d_model": self.model_map[model_name]['d_model'],
            "n_layer": n_layer,
            "d_inner": self.model_map[model_name]['d_inner'],
            "vocab_size": vocab_size,
            "resid_dropout": resid_dropout,
            "embed_dropout": embed_dropout,
            "fused_mlp": fused_mlp,
            "fused_dropout_add_ln": fused_dropout_add_ln,
            "residual_in_fp32": residual_in_fp32,
            "pad_vocab_size_multiple": pad_vocab_size_multiple,
            "return_hidden_state": return_hidden_state,
            "layer": layer
        }



37 changes: 37 additions & 0 deletions helical/models/hyena_dna/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import logging
import numpy as np
from anndata import AnnData
from torch.utils.data import DataLoader
import os
from pathlib import Path
from helical.models.hyena_dna.hyena_dna_config import HyenaDNAConfig
from helical.models.helical import HelicalBaseModel
from helical.models.uce.uce_utils import get_ESM2_embeddings, load_model, process_data, get_gene_embeddings
from accelerate import Accelerator
from helical.services.downloader import Downloader
from typing import Optional

class HyenaDNA(HelicalBaseModel):
    """HyenaDNA model wrapper.

    Resolves the path to the model checkpoint: either inside a
    user-provided directory, or downloaded into the Helical cache.
    """

    def __init__(self, model_dir: Optional[str] = None, model_config: Optional[HyenaDNAConfig] = None) -> None:
        """Locate (and download if necessary) the HyenaDNA checkpoint.

        Parameters
        ----------
        model_dir:
            Directory that already contains ``<model_name>.ckpt``. When
            ``None``, the checkpoint is downloaded into the Helical cache.
        model_config:
            Configuration to use. ``None`` builds a fresh default
            ``HyenaDNAConfig()`` per instance — previously a single config
            object was created at class-definition time and shared by every
            instance, so mutating one model's config leaked into all others.
        """
        super().__init__()
        if model_config is None:
            model_config = HyenaDNAConfig()
        self.model_config = model_config.config
        self.log = logging.getLogger("Hyena-DNA-Model")

        # Checkpoint file name is derived from the configured model name.
        checkpoint_name = f"{self.model_config['model_name']}.ckpt"
        if model_dir is None:
            # No local directory given: fetch into the Helical cache dir.
            self.downloader = Downloader()
            model_path = f"hyena_dna/{checkpoint_name}"
            self.downloader.download_via_name(model_path)
            self.model_path = Path(os.path.join(self.downloader.CACHE_DIR_HELICAL, model_path))
        else:
            self.model_path = Path(os.path.join(model_dir, checkpoint_name))

    def process_data(self):
        # Not yet implemented for HyenaDNA.
        pass

    def get_embeddings(self):
        # Not yet implemented for HyenaDNA.
        pass

0 comments on commit a46da23

Please sign in to comment.