-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor notebook to handle the case of more than 2 classes. Include a GitHub Action to get the data and generate the embeddings with the Helical HyenaDNA model.
- Loading branch information
Showing
3 changed files
with
295 additions
and
150 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
name: CI Pipeline

# Runs only when triggered manually from the Actions tab (or via the API);
# there are no push / pull_request triggers.
on:
  workflow_dispatch:

jobs:
  build:

    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        # v2 runs on a retired Node runtime; v4 is the maintained major version.
        uses: actions/checkout@v4

      - name: setup python
        uses: actions/setup-python@v5
        with:
          # Quoted so the version is always read as a string, never a number.
          python-version: "3.11.8"
          cache: 'pip' # caching pip dependencies

      - name: Install dependencies
        run: |
          pip install .
      # Download first: later steps make use of the downloaded files
      - name: Download all files
        run: |
          python ci/download_all.py
      - name: Execute script to get embeddings
        run: |
          python ci/get_all_data_embeddings.py
      - name: Upload numpy embeddings data
        # upload-artifact v1/v2 are deprecated and rejected by GitHub-hosted
        # runners; v4 accepts the same `name` / `path` inputs.
        uses: actions/upload-artifact@v4
        with:
          name: upload-data
          path: data/
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import numpy as np | ||
from datasets import DatasetDict | ||
from tqdm import tqdm | ||
from helical.models.hyena_dna.model import HyenaDNA | ||
from helical.models.hyena_dna.hyena_dna_config import HyenaDNAConfig | ||
from datasets import get_dataset_config_names | ||
from datasets import load_dataset | ||
import os | ||
|
||
# Checkpoint used for every embedding produced by this script.
# NOTE(review): the name suggests a 1k-token context and 256-dim embeddings
# (the output files are suffixed "_256") -- confirm against the Helical docs.
_MODEL_NAME = "hyenadna-tiny-1k-seqlen-d256"

# Built once at import time and shared by all calls to get_model_inputs().
configurer = HyenaDNAConfig(model_name=_MODEL_NAME)
hyena_model = HyenaDNA(configurer=configurer)
|
||
|
||
def get_model_inputs(dataset: DatasetDict, ratio: float = 1.0):
    """Embed the leading fraction of a dataset split and standardize the result.

    Each sequence is run through the module-level ``hyena_model``; the first
    element of the returned embeddings is averaged over axis 0 to give one
    vector per sequence. The stacked vectors are then standardized per feature
    (zero mean, unit variance).

    Args:
        dataset: Object exposing ``len()`` plus ``"sequence"`` and ``"label"``
            columns. Despite the annotation, the callers in this file pass a
            single split (e.g. ``dataset["train"]``), not the whole dict.
        ratio: Fraction of the dataset to use, counted from the start;
            ``int(len(dataset) * ratio)`` rows are embedded.

    Returns:
        Tuple ``(x, labels)``: ``x`` is a float array with one standardized
        row per embedded sequence, ``labels`` the matching ``"label"`` values.
        When no rows are selected, ``x`` has shape ``(0, d_model)`` and
        ``labels`` is an empty integer array.
    """
    import logging

    length = int(len(dataset) * ratio)

    # Read each column once; indexing dataset["sequence"] inside the loop
    # would fetch the column again on every iteration.
    sequences = dataset["sequence"][:length]
    labels = np.array(dataset["label"][:length])

    # Silence logging while embedding so the progress bar stays readable, then
    # put the previous level back so the rest of the process is not muted.
    previous_disable_level = logging.root.manager.disable
    logging.disable(logging.CRITICAL)
    try:
        pooled = []
        # use tqdm for a progress bar
        for sequence in tqdm(sequences):
            tokenized_sequence = hyena_model.process_data(sequence)
            embeddings = hyena_model.get_embeddings(tokenized_sequence)

            numpy_array = embeddings[0].detach().numpy()
            pooled.append(numpy_array.mean(axis=0))
    finally:
        logging.disable(previous_disable_level)

    if not pooled:
        # Nothing selected (e.g. a tiny ratio on a small split): return
        # correctly shaped empty arrays instead of normalizing nothing.
        return (
            np.empty((0, configurer.config['d_model'])),
            np.empty((0,), dtype=int),
        )

    # Stack once instead of np.append per row (which copies the whole array
    # each time); float64 matches the dtype the original accumulator had.
    x = np.array(pooled, dtype=np.float64)

    # normalize the data; a feature with zero spread (always the case when a
    # single row is selected) is mapped to 0 rather than NaN from 0/0.
    std = x.std(axis=0)
    std[std == 0] = 1.0
    x = (x - x.mean(axis=0)) / std
    return x, labels
|
||
# Hugging Face dataset repository holding the downstream tasks; each
# configuration name of the repository is treated as one task.
_TASKS_REPO = "InstaDeepAI/nucleotide_transformer_downstream_tasks"
# Fraction of every split that gets embedded.
_SAMPLE_RATIO = 0.001

labels = get_dataset_config_names(_TASKS_REPO)

# Create the output directory once up front; exist_ok avoids the
# check-then-create race of os.path.exists() followed by os.makedirs().
os.makedirs("data", exist_ok=True)

for label in labels:
    dataset = load_dataset(_TASKS_REPO, label)

    # Embeddings and labels of the training split.
    x, y = get_model_inputs(dataset["train"], _SAMPLE_RATIO)
    np.save(f"data/x_{label}_norm_256", x)
    np.save(f"data/y_{label}_norm_256", y)

    # Same for the held-out test split, stored with an "unseen" prefix.
    X_unseen, y_unseen = get_model_inputs(dataset["test"], _SAMPLE_RATIO)
    np.save(f"data/x_unseen_{label}_norm_256", X_unseen)
    np.save(f"data/y_unseen_{label}_norm_256", y_unseen)

    # Only the first configuration is processed: the loop exits after one
    # pass. NOTE(review): presumably to keep the CI run short -- confirm
    # before removing.
    break
|
||
|
Oops, something went wrong.