Skip to content

Commit

Permalink
Include docstrings for HyenaDNA
Browse files Browse the repository at this point in the history
  • Loading branch information
bputzeys committed May 22, 2024
1 parent e2b86d7 commit d212ab8
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 8 deletions.
4 changes: 2 additions & 2 deletions examples/run_hyena_dna.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
hyena_config = HyenaDNAConfig(model_name = "hyenadna-tiny-1k-seqlen-d256")
model = HyenaDNA(configurer = hyena_config)
sequence = 'ACTG' * int(1024/4)
data = model.process_data(sequence)
embeddings = model.get_embeddings(data)
tokenized_sequence = model.process_data(sequence)
embeddings = model.get_embeddings(tokenized_sequence)
print(embeddings.shape)
2 changes: 0 additions & 2 deletions helical/models/geneformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ class Geneformer(HelicalBaseModel):
>>> dataset = geneformer.process_data(ann_data[:100])
>>> embeddings = geneformer.get_embeddings(dataset)
Parameters
----------
configurer : GeneformerConfig, optional, default = default_configurer
Expand Down
2 changes: 1 addition & 1 deletion helical/models/hyena_dna/hyena_dna_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(
"hyenadna-tiny-1k-seqlen-d256": {
'd_model': 256,
'd_inner': 1024,
'max_length': 1024, # TODO double check this
'max_length': 1024, # TODO double check this and include more models
}
}

Expand Down
56 changes: 53 additions & 3 deletions helical/models/hyena_dna/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,35 @@
LOGGER = logging.getLogger(__name__)

class HyenaDNA(HelicalBaseModel):
"""HyenaDNA model."""
"""HyenaDNA model.
This class represents the HyenaDNA model, which is a long-range genomic foundation model pretrained on context lengths of up to 1 million tokens at single nucleotide resolution.
Example
-------
>>> from helical.models.hyena_dna.model import HyenaDNA, HyenaDNAConfig
>>> hyena_config = HyenaDNAConfig(model_name = "hyenadna-tiny-1k-seqlen-d256")
>>> model = HyenaDNA(configurer = hyena_config)
>>> sequence = 'ACTG' * int(1024/4)
>>> tokenized_sequence = model.process_data(sequence)
>>> embeddings = model.get_embeddings(tokenized_sequence)
>>> print(embeddings.shape)
Parameters
----------
default_configurer : HyenaDNAConfig, optional, default = default_configurer
The model configuration.
Returns
-------
None
Notes
-----
The link to the paper can be found `here <https://arxiv.org/abs/2306.15794>`_.
We use the implementation from the `hyena-dna <https://github.com/HazyResearch/hyena-dna>`_ repository.
"""

default_configurer = HyenaDNAConfig()

def __init__(self, configurer: HyenaDNAConfig = default_configurer) -> None:
Expand Down Expand Up @@ -37,8 +65,20 @@ def __init__(self, configurer: HyenaDNAConfig = default_configurer) -> None:
self.model.eval()
LOGGER.info(f"Model finished initializing.")

def process_data(self, sequence):
def process_data(self, sequence: str) -> torch.Tensor:
"""Process the input DNA sequence.
Parameters
----------
sequence: str
The input DNA sequence to be processed.
Returns
-------
torch.Tensor
The processed tokenized sequence.
"""
tok_seq = self.tokenizer(sequence)
tok_seq = tok_seq["input_ids"] # grab ids

Expand All @@ -47,7 +87,17 @@ def process_data(self, sequence):
tok_seq = tok_seq.to(self.device)
return tok_seq

def get_embeddings(self, tok_seq):
def get_embeddings(self, tok_seq: torch.Tensor) -> torch.Tensor:
"""Get the embeddings for the tokenized sequence.
Args:
tok_seq: torch.Tensor
The tokenized sequence.
Returns:
torch.Tensor: The embeddings for the tokenized sequence.
"""
LOGGER.info(f"Inference started")
with torch.inference_mode():
return self.model(tok_seq)

0 comments on commit d212ab8

Please sign in to comment.