Skip to content

Commit

Permalink
Refactor notebook to handle the case of more than
Browse files Browse the repository at this point in the history
2 classes. Include Github Action to get the data
and generate the embeddings with the Hyena Helical
model.
  • Loading branch information
bputzeys committed Jun 3, 2024
1 parent e1fd342 commit f89a6aa
Show file tree
Hide file tree
Showing 3 changed files with 295 additions and 150 deletions.
39 changes: 39 additions & 0 deletions .github/workflows/get_embeddings.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
name: CI Pipeline

on:
workflow_dispatch:

jobs:
build:

runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v2

- name: setup python
uses: actions/setup-python@v5
with:
python-version: 3.11.8
cache: 'pip' # caching pip dependencies

- name: Install dependencies
run: |
pip install .
# First download before tests as they make use of the downloaded files
- name: Download all files
run: |
python ci/download_all.py
- name: Execute script to get embeddings
run: |
python ci/get_all_data_embeddings.py
- name: Upload numpy embeddings data
uses: actions/upload-artifact@v2
with:
name: upload-data
path: data/


58 changes: 58 additions & 0 deletions ci/get_all_data_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import numpy as np
from datasets import DatasetDict
from tqdm import tqdm
from helical.models.hyena_dna.model import HyenaDNA
from helical.models.hyena_dna.hyena_dna_config import HyenaDNAConfig
from datasets import get_dataset_config_names
from datasets import load_dataset
import os

configurer = HyenaDNAConfig(model_name="hyenadna-tiny-1k-seqlen-d256")
hyena_model = HyenaDNA(configurer=configurer)


def get_model_inputs(dataset: DatasetDict, ratio: float = 1.0):

x = np.empty((0, configurer.config['d_model']))
labels = np.empty((0,), dtype=int)

# disable logging to avoid cluttering the output
import logging
logging.disable(logging.CRITICAL)

# use tqdm for a progress bar
length = int(len(dataset)*ratio)
for i in tqdm(range(length)):
sequence = dataset["sequence"][i]

tokenized_sequence = hyena_model.process_data(sequence)
embeddings = hyena_model.get_embeddings(tokenized_sequence)

numpy_array = embeddings[0].detach().numpy()
mean_array = numpy_array.mean(axis=0)
x = np.append(x, [mean_array], axis=0)

# normalize the data
x = (x - np.mean(x, axis=0)) / np.std(x, axis=0)
labels = np.array(dataset["label"][:length])
return x, labels

labels = get_dataset_config_names("InstaDeepAI/nucleotide_transformer_downstream_tasks")

for label in labels:
dataset = load_dataset("InstaDeepAI/nucleotide_transformer_downstream_tasks", label)

x, y = get_model_inputs(dataset["train"], 0.001)

if not os.path.exists("data"):
os.makedirs("data")

np.save(f"data/x_{label}_norm_256", x)
np.save(f"data/y_{label}_norm_256", y)

X_unseen, y_unseen = get_model_inputs(dataset["test"], 0.001)
np.save(f"data/x_unseen_{label}_norm_256", X_unseen)
np.save(f"data/y_unseen_{label}_norm_256", y_unseen)
break


Loading

0 comments on commit f89a6aa

Please sign in to comment.