diff --git a/examples/run_hyena_dna.py b/examples/run_hyena_dna.py index d6ba91d9..f015b2c4 100644 --- a/examples/run_hyena_dna.py +++ b/examples/run_hyena_dna.py @@ -3,6 +3,6 @@ hyena_config = HyenaDNAConfig(model_name = "hyenadna-tiny-1k-seqlen-d256") model = HyenaDNA(configurer = hyena_config) sequence = 'ACTG' * int(1024/4) -data = model.process_data(sequence) -embeddings = model.get_embeddings(data) +tokenized_sequence = model.process_data(sequence) +embeddings = model.get_embeddings(tokenized_sequence) print(embeddings.shape) diff --git a/helical/models/geneformer/model.py b/helical/models/geneformer/model.py index bb98efeb..8ed42190 100644 --- a/helical/models/geneformer/model.py +++ b/helical/models/geneformer/model.py @@ -29,8 +29,6 @@ class Geneformer(HelicalBaseModel): >>> dataset = geneformer.process_data(ann_data[:100]) >>> embeddings = geneformer.get_embeddings(dataset) - - Parameters ---------- configurer : GeneformerConfig, optional, default = default_configurer diff --git a/helical/models/hyena_dna/hyena_dna_config.py b/helical/models/hyena_dna/hyena_dna_config.py index eeb4cffd..87f7558a 100644 --- a/helical/models/hyena_dna/hyena_dna_config.py +++ b/helical/models/hyena_dna/hyena_dna_config.py @@ -75,7 +75,7 @@ def __init__( "hyenadna-tiny-1k-seqlen-d256": { 'd_model': 256, 'd_inner': 1024, - 'max_length': 1024, # TODO double check this + 'max_length': 1024, # TODO double check this and include more models } } diff --git a/helical/models/hyena_dna/model.py b/helical/models/hyena_dna/model.py index 648a89f6..02cf2565 100644 --- a/helical/models/hyena_dna/model.py +++ b/helical/models/hyena_dna/model.py @@ -9,7 +9,35 @@ LOGGER = logging.getLogger(__name__) class HyenaDNA(HelicalBaseModel): - """HyenaDNA model.""" + """HyenaDNA model. + This class represents the HyenaDNA model, which is a long-range genomic foundation model pretrained on context lengths of up to 1 million tokens at single nucleotide resolution. 
+ + Example + ------- + >>> from helical.models.hyena_dna.model import HyenaDNA, HyenaDNAConfig + >>> hyena_config = HyenaDNAConfig(model_name = "hyenadna-tiny-1k-seqlen-d256") + >>> model = HyenaDNA(configurer = hyena_config) + >>> sequence = 'ACTG' * int(1024/4) + >>> tokenized_sequence = model.process_data(sequence) + >>> embeddings = model.get_embeddings(tokenized_sequence) + >>> print(embeddings.shape) + + Parameters + ---------- + configurer : HyenaDNAConfig, optional, default = default_configurer + The model configuration. + + Returns + ------- + None + + Notes + ----- + The link to the paper can be found `here <https://arxiv.org/abs/2306.15794>`_. + We use the implementation from the `hyena-dna <https://github.com/HazyResearch/hyena-dna>`_ repository. + + """ + default_configurer = HyenaDNAConfig() def __init__(self, configurer: HyenaDNAConfig = default_configurer) -> None: @@ -37,8 +65,20 @@ def __init__(self, configurer: HyenaDNAConfig = default_configurer) -> None: self.model.eval() LOGGER.info(f"Model finished initializing.") - def process_data(self, sequence): + def process_data(self, sequence: str) -> torch.Tensor: + """Process the input DNA sequence. + + Parameters + ---------- + sequence: str + The input DNA sequence to be processed. + Returns + ------- + torch.Tensor + The processed tokenized sequence. + + """ tok_seq = self.tokenizer(sequence) tok_seq = tok_seq["input_ids"] # grab ids @@ -47,7 +87,17 @@ def process_data(self, sequence): tok_seq = tok_seq.to(self.device) return tok_seq - def get_embeddings(self, tok_seq): + def get_embeddings(self, tok_seq: torch.Tensor) -> torch.Tensor: + """Get the embeddings for the tokenized sequence. + + Args: + tok_seq: torch.Tensor + The tokenized sequence. + + Returns: + torch.Tensor: The embeddings for the tokenized sequence. + + """ LOGGER.info(f"Inference started") with torch.inference_mode(): return self.model(tok_seq) \ No newline at end of file