Refactor fine-tuning #112

Merged · 10 commits · Oct 8, 2024
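In summary, as read from the diff below: every fine-tuning model is now constructed directly from a config object (`GeneformerFineTuningModel(geneformer_config=...)`, `scGPTFineTuningModel(scGPT_config=...)`, `UCEFineTuningModel(uce_config=...)`, `HyenaDNAFineTuningModel(hyena_config=...)`) instead of wrapping a separately instantiated base model, and `process_data` is now called on the fine-tuning model itself. A minimal sketch of the old versus new usage, assuming the `helical` exports shown in the updated docs; the file path and output size are placeholders:

```python
import anndata as ad
from helical import GeneformerConfig, GeneformerFineTuningModel

# Old pattern (removed by this PR): instantiate a base model, then wrap it:
#   geneformer = Geneformer(configurer=model_config)
#   fine_tune = GeneformerFineTuningModel(geneformer_model=geneformer, ...)

# New pattern: the fine-tuning model takes the config directly...
model_config = GeneformerConfig(model_name="gf-12L-95M-i4096", batch_size=10)
fine_tune = GeneformerFineTuningModel(
    geneformer_config=model_config,
    fine_tuning_head="classification",
    output_size=5,  # placeholder: the number of label classes
)

# ...and owns data processing, so no standalone base model is needed.
ann_data = ad.read_h5ad("dataset.h5ad")  # placeholder path
dataset = fine_tune.process_data(ann_data)
```

Note that the updated HyenaDNA model card below leaves `load_dataset("dataset", "task")` and `number_unique_outputs` as reader-supplied placeholders.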
8 changes: 4 additions & 4 deletions .github/workflows/main.yml
@@ -91,10 +91,10 @@ jobs:
        run: |
          sed -i 's/train\[:65%\]/train\[:5%\]/g' ./examples/notebooks/Cell-Type-Annotation.ipynb
          sed -i 's/train\[70%:\]/train\[5%:7%\]/g' ./examples/notebooks/Cell-Type-Annotation.ipynb
-         sed -i 's/get_anndata_from_hf_dataset(ds\[\\"train\\"\])/get_anndata_from_hf_dataset(ds\[\\"train\\"\])[:10]/g' ./examples/notebooks/Cell-Type-Classification-Fine-Tuning.ipynb
-         sed -i 's/get_anndata_from_hf_dataset(ds\[\\"test\\"\])/get_anndata_from_hf_dataset(ds\[\\"test\\"\])[:2]/g' ./examples/notebooks/Cell-Type-Classification-Fine-Tuning.ipynb
-         sed -i 's/list(np.array(train_dataset.obs\[\\"LVL1\\"].tolist()))/list(np.array(train_dataset.obs\[\\"LVL1\\"].tolist()))[:10]/g' ./examples/notebooks/Cell-Type-Classification-Fine-Tuning.ipynb
-         sed -i 's/list(np.array(test_dataset.obs\[\\"LVL1\\"].tolist()))/list(np.array(test_dataset.obs\[\\"LVL1\\"].tolist()))[:2]/g' ./examples/notebooks/Cell-Type-Classification-Fine-Tuning.ipynb
+         sed -i 's/get_anndata_from_hf_dataset(ds\[\\"train\\"\])/get_anndata_from_hf_dataset(ds\[\\"train\\"\])[:100]/g' ./examples/notebooks/Cell-Type-Classification-Fine-Tuning.ipynb
+         sed -i 's/get_anndata_from_hf_dataset(ds\[\\"test\\"\])/get_anndata_from_hf_dataset(ds\[\\"test\\"\])[:10]/g' ./examples/notebooks/Cell-Type-Classification-Fine-Tuning.ipynb
+         sed -i 's/list(np.array(train_dataset.obs\[\\"LVL1\\"].tolist()))/list(np.array(train_dataset.obs\[\\"LVL1\\"].tolist()))[:100]/g' ./examples/notebooks/Cell-Type-Classification-Fine-Tuning.ipynb
+         sed -i 's/list(np.array(test_dataset.obs\[\\"LVL1\\"].tolist()))/list(np.array(test_dataset.obs\[\\"LVL1\\"].tolist()))[:10]/g' ./examples/notebooks/Cell-Type-Classification-Fine-Tuning.ipynb

      - name: Run Notebooks
        run: |
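An aside on what these `sed` rewrites do (inferred from the patterns alone): CI truncates the dataset slices the example notebooks load so they finish quickly, and this commit enlarges the truncation from 10 train / 2 test rows to 100 / 10. A toy illustration of the appended slice, with a plain list standing in for the real AnnData object:

```python
# Stand-in for the object returned by get_anndata_from_hf_dataset(ds["train"]);
# the real one is an AnnData, which supports the same [:n] row slicing.
train_cells = list(range(1000))

# Before this commit CI kept only the first 10 rows; now it keeps 100.
assert len(train_cells[:10]) == 10
assert len(train_cells[:100]) == 100
```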
8 changes: 5 additions & 3 deletions ci/tests/test_geneformer/test_geneformer_model.py
@@ -177,10 +177,12 @@ def test_cls_eos_tokens_presence(self, geneformer, mock_data):
    def test_model_input_size(self, geneformer):
        assert geneformer.config["input_size"] == geneformer.configurer.model_map[geneformer.config["model_name"]]['input_size']

-   def test_fine_tune_classifier_returns_correct_shape(self, geneformer, mock_data, fine_tune_mock_data):
-       tokenized_dataset = geneformer.process_data(mock_data, gene_names='gene_symbols')
+   def test_fine_tune_classifier_returns_correct_shape(self, mock_data, fine_tune_mock_data):
+       device = "cuda" if torch.cuda.is_available() else "cpu"
+       fine_tuned_model = GeneformerFineTuningModel(GeneformerConfig(device=device), fine_tuning_head="classification", output_size=1)
+       tokenized_dataset = fine_tuned_model.process_data(mock_data, gene_names='gene_symbols')
        tokenized_dataset = tokenized_dataset.add_column('labels', fine_tune_mock_data)
-       fine_tuned_model = GeneformerFineTuningModel(geneformer, fine_tuning_head="classification", output_size=1)

        fine_tuned_model.train(train_dataset=tokenized_dataset, label='labels')
        assert fine_tuned_model is not None
        outputs = fine_tuned_model.get_outputs(tokenized_dataset)
17 changes: 8 additions & 9 deletions ci/tests/test_hyena_dna/test_hyena_dna_fine_tuning.py
@@ -1,24 +1,23 @@
import pytest
import torch
-from helical import HyenaDNA, HyenaDNAConfig, HyenaDNAFineTuningModel
+from helical import HyenaDNAConfig, HyenaDNAFineTuningModel

class TestHyenaDNAFineTuning:
    @pytest.fixture(params=["hyenadna-tiny-1k-seqlen", "hyenadna-tiny-1k-seqlen-d256"])
-   def hyenaDNA(self, request):
+   def hyenaDNAFineTune(self, request):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        config = HyenaDNAConfig(model_name=request.param, batch_size=1, device=self.device)
-       return HyenaDNA(config)
+       return HyenaDNAFineTuningModel(hyena_config=config, fine_tuning_head="classification", output_size=1)

    @pytest.fixture
-   def mock_data(self, hyenaDNA):
+   def mock_data(self, hyenaDNAFineTune):
        input_sequences = ["AAAA", "CCCC", "TTTT", "ACGT", "ACGN", "BHIK", "ANNT"]
        labels = [0, 0, 0, 0, 0, 0, 0]
-       tokenized_sequences = hyenaDNA.process_data(input_sequences)
+       tokenized_sequences = hyenaDNAFineTune.process_data(input_sequences)
        return tokenized_sequences, labels

-   def test_output_dimensionality_of_fine_tuned_model(self, hyenaDNA, mock_data):
+   def test_output_dimensionality_of_fine_tuned_model(self, hyenaDNAFineTune, mock_data):
        input_sequences, labels = mock_data
-       hyena_dna_fine_tune = HyenaDNAFineTuningModel(hyena_model=hyenaDNA, fine_tuning_head="classification", output_size=1)
-       hyena_dna_fine_tune.train(train_input_data=input_sequences, train_labels=labels, validation_input_data=input_sequences, validation_labels=labels)
-       outputs = hyena_dna_fine_tune.get_outputs(input_sequences)
+       hyenaDNAFineTune.train(train_input_data=input_sequences, train_labels=labels, validation_input_data=input_sequences, validation_labels=labels)
+       outputs = hyenaDNAFineTune.get_outputs(input_sequences)
        assert outputs.shape == (len(input_sequences), 1)
6 changes: 3 additions & 3 deletions ci/tests/test_scgpt/test_scgpt_model.py
@@ -1,4 +1,4 @@
-from helical.models.scgpt.model import scGPT
+from helical.models.scgpt.model import scGPT, scGPTConfig
from helical.models.scgpt.fine_tuning_model import scGPTFineTuningModel
from anndata import AnnData
from helical.models.scgpt.tokenizer import GeneVocab
@@ -110,9 +110,9 @@ def test_ensure_data_validity__no_error(self, data):
        assert "total_counts" in data.obs

    def test_fine_tune_classification_returns_correct_shape(self):
-       tokenized_dataset = self.scgpt.process_data(self.data)
        labels = list([0])
-       fine_tuned_model = scGPTFineTuningModel(self.scgpt, fine_tuning_head="classification", output_size=1)
+       fine_tuned_model = scGPTFineTuningModel(scGPTConfig(), fine_tuning_head="classification", output_size=1)
+       tokenized_dataset = fine_tuned_model.process_data(self.data)
        fine_tuned_model.train(train_input_data=tokenized_dataset, train_labels=labels)
        assert fine_tuned_model is not None
        outputs = fine_tuned_model.get_outputs(tokenized_dataset)
27 changes: 15 additions & 12 deletions docs/model_cards/geneformer.md
@@ -167,7 +167,7 @@ Key improvements in v2.0:

**Example Usage:**
```python
-from helical.models.geneformer.model import Geneformer,GeneformerConfig
+from helical import Geneformer, GeneformerConfig
import anndata as ad

# Example configuration
@@ -197,18 +197,11 @@ print("Cancer-tuned model embeddings shape:", cancer_embeddings.shape)
## How To Fine-Tune

```python
-from helical.models.geneformer.geneformer_config import Geneformer,GeneformerConfig
-from helical.models.geneformer.fine_tuning_model import GeneformerFineTuningModel
+from helical import GeneformerConfig, GeneformerFineTuningModel

-# Create the Geneformer model with relevant configs
-model_config = GeneformerConfig(model_name="gf-12L-95M-i4096", batch_size=10)
-geneformer = Geneformer(configurer = model_config)

# Prepare the data
ann_data = ad.read_h5ad("dataset.h5ad")

-# Process the data for training
-dataset = geneformer.process_data(ann_data)
# Get the desired label class
cell_types = list(ann_data.obs.cell_type)

@@ -223,10 +216,20 @@ for i in range(len(cell_types)):
dataset = dataset.add_column('cell_types', cell_types)

# Create the fine-tuning model
-geneformer_fine_tune = GeneformerFineTuningModel(geneformer_model=geneformer, fine_tuning_head="classification", label="cell_types", output_size=len(label_set))
+model_config = GeneformerConfig(model_name="gf-12L-95M-i4096", batch_size=10)
+geneformer_fine_tune = GeneformerFineTuningModel(geneformer_config=model_config, fine_tuning_head="classification", label="cell_types", output_size=len(label_set))

+# Process the data for training
+dataset = geneformer_fine_tune.process_data(ann_data)

# Fine-tune
-geneformer_fine_tune.train(train_dataset=dataset["train"])
+geneformer_fine_tune.train(train_dataset=dataset)

+# Get outputs of the fine-tuned model
+outputs = geneformer_fine_tune.get_outputs(dataset)

+# Get the embeddings of the fine-tuned model
+embeddings = geneformer_fine_tune.get_embeddings(dataset)

```

27 changes: 26 additions & 1 deletion docs/model_cards/hyenadna.md
@@ -97,7 +97,7 @@

**Example Usage:**
```python
-from helical.models.hyena_dna.model import HyenaDNA, HyenaDNAConfig
+from helical import HyenaDNA, HyenaDNAConfig

hyena_config = HyenaDNAConfig(model_name = "hyenadna-tiny-1k-seqlen-d256")
model = HyenaDNA(configurer = hyena_config)
@@ -108,6 +108,31 @@ embeddings = model.get_embeddings(tokenized_sequence)
print(embeddings.shape)
```

+## How to Fine-Tune
+```python
+from datasets import load_dataset
+from helical import HyenaDNAConfig, HyenaDNAFineTuningModel
+import torch
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Load a Hugging Face dataset and task type
+ds = load_dataset("dataset", "task")
+
+# Define the desired configs
+config = HyenaDNAConfig(device=device, batch_size=10)
+
+# Define the fine-tuning model with the configs we instantiated above
+hyena_fine_tune = HyenaDNAFineTuningModel(config, "classification", number_unique_outputs)
+
+# Prepare the sequences for input to the model
+input_dataset = hyena_fine_tune.process_data(ds["train"]["sequence"])
+
+# train the fine-tuning model on some downstream task
+hyena_fine_tune.train(input_dataset, ds["train"]["label"])
+
+```

## Citation

@article{nguyen2023hyenadna,
25 changes: 12 additions & 13 deletions docs/model_cards/scgpt.md
@@ -107,31 +107,30 @@ print(embeddings.shape)
## How To Fine-Tune

```python
-from helical.models.scgpt.fine_tuning_model import scGPTFineTuningModel
-from helical.models.scgpt.model import scGPT,scGPTConfig
+from helical import scGPTFineTuningModel, scGPTConfig

-# Create the Geneformer model with relevant configs
-scgpt_config=scGPTConfig(batch_size=10)
-scgpt = scGPT(configurer=scgpt_config)

# Load the desired dataset
adata = ad.read_h5ad("dataset.h5ad")

-# Process the data for training
-data = scgpt.process_data(adata)

# Get the desired label class
cell_types = list(ann_data.obs.cell_type)

-# Create a dictionary mapping the classes to unique integers for training
+# Get unique labels
label_set = set(cell_types)

+# Create the fine-tuning model with the relevant configs
+scgpt_config=scGPTConfig(batch_size=10)
+scgpt_fine_tune = scGPTFineTuningModel(scGPT_config=scgpt_config, fine_tuning_head="classification", output_size=len(label_set))

+# Process the data for training
+data = scgpt_fine_tune.process_data(adata)

+# Create a dictionary mapping the classes to unique integers for training
class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))

for i in range(len(cell_types)):
    cell_types[i] = class_id_dict[cell_types[i]]

-# Create the fine-tuning model
-scgpt_fine_tune = scGPTFineTuningModel(scGPT_model=scgpt, fine_tuning_head="classification", output_size=len(label_set))

# Fine-tune
scgpt_fine_tune.train(train_input_data=dataset, train_labels=cell_types)
```
20 changes: 10 additions & 10 deletions docs/model_cards/uce.md
@@ -107,31 +107,31 @@ print(embeddings.shape)
## How To Fine-Tune

```python
-from helical.models.uce.model import UCE, UCEConfig
-from helical.models.uce.fine_tuning_model import UCEFineTuningModel
+from helical import UCEConfig, UCEFineTuningModel
import anndata as ad

-configurer=UCEConfig(batch_size=10)
-uce = UCE(configurer=configurer)

# Load the data
ann_data = ad.read_h5ad("dataset.h5ad")

+# Get unique output labels
+label_set = set(cell_types)

+# Create the fine-tuning model with the desired configs
+configurer=UCEConfig(batch_size=10)
+uce_fine_tune = UCEFineTuningModel(uce_config=configurer, fine_tuning_head="classification", output_size=len(label_set))

# Process the data for training
-dataset = uce.process_data(ann_data)
+dataset = uce_fine_tune.process_data(ann_data)

# Get the desired label class
cell_types = list(ann_data.obs.cell_type)

# Create a dictionary mapping the classes to unique integers for training
-label_set = set(cell_types)
class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))

for i in range(len(cell_types)):
    cell_types[i] = class_id_dict[cell_types[i]]

-# Create the fine-tuning model
-uce_fine_tune = UCEFineTuningModel(uce_model=uce, fine_tuning_head="classification", output_size=len(label_set))

# Fine-tune
uce_fine_tune.train(train_input_data=dataset, train_labels=cell_types)

14 changes: 7 additions & 7 deletions examples/fine_tune_models/fine_tune_UCE.py
@@ -1,26 +1,26 @@
-from helical import UCEConfig, UCE, UCEFineTuningModel
+from helical import UCEConfig, UCEFineTuningModel
from helical.utils import get_anndata_from_hf_dataset
from datasets import load_dataset
from omegaconf import DictConfig
import hydra

@hydra.main(version_base=None, config_path="../run_models/configs", config_name="uce_config")
def run_fine_tuning(cfg: DictConfig):
-   uce_config=UCEConfig(**cfg)
-   uce = UCE(configurer=uce_config)

    hf_dataset = load_dataset("helical-ai/yolksac_human",split="train[:5%]", trust_remote_code=True, download_mode="reuse_cache_if_exists")
    ann_data = get_anndata_from_hf_dataset(hf_dataset)

-   dataset = uce.process_data(ann_data[:10], name="train")

    cell_types = ann_data.obs["LVL1"][:10].tolist()

    label_set = set(cell_types)

+   uce_config=UCEConfig(**cfg)
+   uce_fine_tune = UCEFineTuningModel(uce_config=uce_config, fine_tuning_head="classification", output_size=len(label_set))

+   dataset = uce_fine_tune.process_data(ann_data[:10], name="train")

    class_id_dict = {label: i for i, label in enumerate(label_set)}
    cell_types = [class_id_dict[cell] for cell in cell_types]

-   uce_fine_tune = UCEFineTuningModel(uce_model=uce, fine_tuning_head="classification", output_size=len(label_set))
    uce_fine_tune.train(train_input_data=dataset, train_labels=cell_types)

if __name__ == "__main__":
12 changes: 6 additions & 6 deletions examples/fine_tune_models/fine_tune_geneformer.py
@@ -1,23 +1,24 @@
-from helical import GeneformerConfig, Geneformer, GeneformerFineTuningModel
+from helical import GeneformerConfig, GeneformerFineTuningModel
from helical.utils import get_anndata_from_hf_dataset
from datasets import load_dataset
import hydra
from omegaconf import DictConfig

@hydra.main(version_base=None, config_path="../run_models/configs", config_name="geneformer_config")
def run_fine_tuning(cfg: DictConfig):
-   geneformer_config = GeneformerConfig(**cfg)
-   geneformer = Geneformer(configurer = geneformer_config)

    hf_dataset = load_dataset("helical-ai/yolksac_human",split="train[:5%]", trust_remote_code=True, download_mode="reuse_cache_if_exists")
    ann_data = get_anndata_from_hf_dataset(hf_dataset)

    cell_types = list(ann_data.obs["LVL1"][:10])
    label_set = set(cell_types)

+   geneformer_config = GeneformerConfig(**cfg)
+   geneformer_fine_tune = GeneformerFineTuningModel(geneformer_config=geneformer_config, fine_tuning_head="classification", output_size=len(label_set))

-   dataset = geneformer.process_data(ann_data[:10])
+   dataset = geneformer_fine_tune.process_data(ann_data[:10])

    dataset = dataset.add_column('cell_types', cell_types)
    label_set = set(dataset["cell_types"])
    class_id_dict = dict(zip(label_set, [i for i in range(len(label_set))]))

    def classes_to_ids(example):
@@ -26,7 +27,6 @@ def classes_to_ids(example):

    dataset = dataset.map(classes_to_ids, num_proc=1)

-   geneformer_fine_tune = GeneformerFineTuningModel(geneformer_model=geneformer, fine_tuning_head="classification", output_size=len(label_set))
    geneformer_fine_tune.train(train_dataset=dataset)

if __name__ == "__main__":
15 changes: 7 additions & 8 deletions examples/fine_tune_models/fine_tune_scgpt.py
@@ -1,4 +1,4 @@
-from helical import scGPTConfig, scGPT, scGPTFineTuningModel
+from helical import scGPTConfig, scGPTFineTuningModel
from helical.utils import get_anndata_from_hf_dataset
from datasets import load_dataset
from omegaconf import DictConfig
@@ -7,21 +7,20 @@

@hydra.main(version_base=None, config_path="../run_models/configs", config_name="scgpt_config")
def run_fine_tuning(cfg: DictConfig):
-   scgpt_config=scGPTConfig(**cfg)
-   scgpt = scGPT(configurer=scgpt_config)

    hf_dataset = load_dataset("helical-ai/yolksac_human",split="train[:5%]", trust_remote_code=True, download_mode="reuse_cache_if_exists")
    ann_data = get_anndata_from_hf_dataset(hf_dataset)

-   dataset = scgpt.process_data(ann_data[:10])

    cell_types = ann_data.obs["LVL1"][:10].tolist()

    label_set = set(cell_types)

+   scgpt_config=scGPTConfig(**cfg)
+   scgpt_fine_tune = scGPTFineTuningModel(scGPT_config=scgpt_config, fine_tuning_head="classification", output_size=len(label_set))

+   dataset = scgpt_fine_tune.process_data(ann_data[:10])

    class_id_dict = {label: i for i, label in enumerate(label_set)}
    cell_types = [class_id_dict[cell] for cell in cell_types]

-   scgpt_fine_tune = scGPTFineTuningModel(scGPT_model=scgpt, fine_tuning_head="classification", output_size=len(label_set))
    scgpt_fine_tune.train(train_input_data=dataset, train_labels=cell_types)

if __name__ == "__main__":