Merge pull request #98 from helicalAI/fix-scGPT-fine-tuning
Fix Weight Scrambling in scGPT
mattwoodx authored Oct 2, 2024
2 parents 68ff34f + 91a4589 commit 8d277fe
Showing 9 changed files with 306 additions and 191 deletions.
File renamed without changes.
52 changes: 52 additions & 0 deletions ci/tests/test_utils/test_converter.py
@@ -0,0 +1,52 @@
import pytest
import numpy as np
from datasets import Dataset, Features, Value, Sequence
from helical.utils import get_anndata_from_hf_dataset

def create_mock_dataset(gene_names: str):
    data = {
        'raw_counts': [
            [1, 2, 3, 2],
            [98],
            [72, 19]
        ],
        'rows': [
            [0, 1, 2, 3],
            [0],
            [1, 3]
        ],
        'obs1': [10, 20, 30],
        'obs2': [40, 50, 60],
        'size': [4, 4, 4]
    }
    features = Features({
        'raw_counts': Sequence(Value('uint32'), -1, gene_names),
        'rows': Sequence(Value('uint32')),
        'obs1': Value('int64'),
        'obs2': Value('int64'),
        'size': Value('uint32')
    })
    dataset = Dataset.from_dict(data, features=features)
    return dataset

def test_get_anndata_from_hf_dataset():
    dataset = create_mock_dataset("gene1,gene2,gene3,gene4")
    ann_data = get_anndata_from_hf_dataset(dataset)

    assert ann_data.shape == (3, 4)

    # assert that observation names are correct (i.e. no 'rows', 'raw_counts', or 'size')
    assert list(ann_data.obs.columns) == ['obs1', 'obs2']

    # assert that gene names are converted to uppercase
    assert list(ann_data.var_names) == ['GENE1', 'GENE2', 'GENE3', 'GENE4']
    assert list(ann_data.var['gene_name']) == ['GENE1', 'GENE2', 'GENE3', 'GENE4']

    # assert that counts are placed in the correct positions
    assert np.array_equal(ann_data.X.toarray(), np.array([[1, 2, 3, 2], [98, 0, 0, 0], [0, 72, 0, 19]]))

def test_get_anndata_from_hf_dataset_mismatched_gene_names():
    dataset = create_mock_dataset("gene1,gene2")

    with pytest.raises(ValueError):
        get_anndata_from_hf_dataset(dataset)
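
For context, here is a minimal usage sketch of the converter exercised by these tests; the column names and values below mirror the mock dataset above and are illustrative only:

from datasets import Dataset, Features, Value, Sequence
from helical.utils import get_anndata_from_hf_dataset

# Sparse row layout: 'raw_counts' holds the non-zero values, 'rows' the gene
# indices they belong to, and the Sequence id string carries the gene names.
features = Features({
    'raw_counts': Sequence(Value('uint32'), -1, "gene1,gene2,gene3,gene4"),
    'rows': Sequence(Value('uint32')),
    'obs1': Value('int64'),
    'size': Value('uint32')
})
dataset = Dataset.from_dict({
    'raw_counts': [[5, 1], [7]],
    'rows': [[0, 2], [3]],
    'obs1': [10, 20],
    'size': [4, 4]
}, features=features)

adata = get_anndata_from_hf_dataset(dataset)
print(adata.shape)              # (2, 4)
print(list(adata.obs.columns))  # ['obs1']; 'raw_counts', 'rows' and 'size' are excluded
print(list(adata.var_names))    # ['GENE1', 'GENE2', 'GENE3', 'GENE4']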
File renamed without changes.
328 changes: 191 additions & 137 deletions examples/notebooks/Cell-Type-Classification-Fine-Tuning.ipynb

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions helical/models/base_models.py
@@ -163,4 +163,8 @@ def forward():

    @abstractmethod
    def train():
        pass

    @abstractmethod
    def get_outputs():
        pass
2 changes: 1 addition & 1 deletion helical/models/geneformer/fine_tuning_model.py
@@ -102,7 +102,7 @@ def train(
The column in the dataset containing the training labels. These should be stored as unique per class integers.
epochs : int, optional, default = 10
The number of epochs to train the model
freeze_layers : int, optional, default = 0
freeze_layers : int, optional, default = 2
The number of layers to freeze.
validation_dataset : Dataset, default = None
A helical processed dataset for per epoch validation. If this is not specified, no validation will be performed.
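For readers unfamiliar with the freeze_layers parameter documented above: freezing the first N layers usually means disabling their gradients before fine-tuning. A generic PyTorch sketch of that idea (not Helical's internal implementation; the layer container is a placeholder):

import torch.nn as nn

def freeze_first_layers(layers: nn.ModuleList, freeze_layers: int = 2) -> None:
    """Disable gradient updates for the first `freeze_layers` transformer blocks."""
    for layer in layers[:freeze_layers]:
        for param in layer.parameters():
            param.requires_grad = False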
87 changes: 43 additions & 44 deletions helical/models/scgpt/fine_tuning_model.py
@@ -121,8 +121,6 @@ def train(
The loss function to be used.
epochs : int, optional, default = 10
The number of epochs to train the model
freeze_layers : int, optional, default = 0
The number of layers to freeze.
lr_scheduler_params : dict, default = None
The learning rate scheduler parameters for the transformers get_scheduler method. The optimizer will be taken from the optimizer input and should not be included in the learning scheduler parameters. If not specified, no scheduler will be used.
e.g. lr_scheduler_params = { 'name': 'linear', 'num_warmup_steps': 0, 'num_training_steps': 5 }
@@ -171,57 +169,57 @@ def train(
            )

        self.to(device)

        self.scgpt_model.train()
        self.fine_tuning_head.train()
        optimizer = optimizer(self.parameters(), **optimizer_params)

        lr_scheduler = None
        if lr_scheduler_params is not None:
            lr_scheduler = get_scheduler(optimizer=optimizer, **lr_scheduler_params)

        logger.info("Starting Fine-Tuning")
        with torch.cuda.amp.autocast(enabled=True): #torch.autocast(device_type=str(device),enabled=True): # torch.cuda.amp.autocast(enabled=True):
            for j in range(epochs):
                batch_count = 0
                batch_loss = 0.0
                batches_processed = 0
                training_loop = tqdm(data_loader)
                for data_dict in training_loop:
                    input_gene_ids = data_dict["gene"].to(device)
                    src_key_padding_mask = input_gene_ids.eq(
                        self.vocab[self.config["pad_token"]]
                    )
                    output = self._forward(input_gene_ids, data_dict, src_key_padding_mask, use_batch_labels, device)
                    labels = torch.tensor(train_labels[batch_count: batch_count + self.config["batch_size"]], device=device)
                    batch_count += self.config["batch_size"]
                    loss = loss_function(output, labels)
                    loss.backward()
                    batch_loss += loss.item()
                    batches_processed += 1
                    optimizer.step()
                    optimizer.zero_grad()
        for j in range(epochs):
            batch_count = 0
            batch_loss = 0.0
            batches_processed = 0
            training_loop = tqdm(data_loader)
            for data_dict in training_loop:
                input_gene_ids = data_dict["gene"].to(device)
                src_key_padding_mask = input_gene_ids.eq(
                    self.vocab[self.config["pad_token"]]
                )
                output = self._forward(input_gene_ids, data_dict, src_key_padding_mask, use_batch_labels, device)
                labels = torch.tensor(train_labels[batch_count: batch_count + self.config["batch_size"]], device=device)
                batch_count += self.config["batch_size"]
                loss = loss_function(output, labels)
                loss.backward()
                batch_loss += loss.item()
                batches_processed += 1
                optimizer.step()
                optimizer.zero_grad()

                    training_loop.set_postfix({"loss": batch_loss/batches_processed})
                    training_loop.set_description(f"Fine-Tuning: epoch {j+1}/{epochs}")
                training_loop.set_postfix({"loss": batch_loss/batches_processed})
                training_loop.set_description(f"Fine-Tuning: epoch {j+1}/{epochs}")

                if lr_scheduler is not None:
                    lr_scheduler.step()
            if lr_scheduler is not None:
                lr_scheduler.step()

                if validation_input_data is not None:
                    testing_loop = tqdm(validation_data_loader, desc="Fine-Tuning Validation")
                    accuracy = 0.0
                    count = 0.0
                    validation_batch_count = 0
                    for validation_data_dict in testing_loop:
                        input_gene_ids = validation_data_dict["gene"].to(device)
                        src_key_padding_mask = input_gene_ids.eq(
                            self.vocab[self.config["pad_token"]]
                        )
                        output = self._forward(input_gene_ids, validation_data_dict, src_key_padding_mask, use_batch_labels, device)
                        val_labels = torch.tensor(validation_labels[validation_batch_count: validation_batch_count + self.config["batch_size"]], device=device)
                        validation_batch_count += self.config["batch_size"]
                        accuracy += accuracy_score(val_labels.cpu(), torch.argmax(output, dim=1).cpu())
                        count += 1.0
                        testing_loop.set_postfix({"accuracy": accuracy/count})
            if validation_input_data is not None:
                testing_loop = tqdm(validation_data_loader, desc="Fine-Tuning Validation")
                accuracy = 0.0
                count = 0.0
                validation_batch_count = 0
                for validation_data_dict in testing_loop:
                    input_gene_ids = validation_data_dict["gene"].to(device)
                    src_key_padding_mask = input_gene_ids.eq(
                        self.vocab[self.config["pad_token"]]
                    )
                    output = self._forward(input_gene_ids, validation_data_dict, src_key_padding_mask, use_batch_labels, device)
                    val_labels = torch.tensor(validation_labels[validation_batch_count: validation_batch_count + self.config["batch_size"]], device=device)
                    validation_batch_count += self.config["batch_size"]
                    accuracy += accuracy_score(val_labels.cpu(), torch.argmax(output, dim=1).cpu())
                    count += 1.0
                    testing_loop.set_postfix({"accuracy": accuracy/count})
        logger.info(f"Fine-Tuning Complete. Epochs: {epochs}")

    def get_outputs(
@@ -242,7 +240,8 @@ def get_outputs(
        """
        device = next(self.scgpt_model.parameters()).device
        self.to(device)

        self.scgpt_model.eval()
        self.fine_tuning_head.eval()
        try:
            use_batch_labels = dataset.batch_ids is not None
        except:
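One plausible reading of the hunk above: the old loop ran loss.backward() and optimizer.step() inside torch.cuda.amp.autocast without a GradScaler, which can destabilise fp16 gradients and, over enough steps, the weights themselves; the fix simply drops the autocast wrapper and trains in full precision. For contrast, the standard mixed-precision recipe looks like the generic sketch below (placeholder model and data, assumes a CUDA device; this is not Helical code):

import torch

model = torch.nn.Linear(8, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(4, 8, device="cuda")
y = torch.randint(0, 2, (4,), device="cuda")

with torch.cuda.amp.autocast():   # autocast wraps only the forward pass and loss
    output = model(x)
    loss = loss_fn(output, y)

scaler.scale(loss).backward()     # backward and step run outside autocast, with scaling
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()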
21 changes: 13 additions & 8 deletions helical/models/uce/fine_tuning_model.py
@@ -111,8 +111,6 @@ def train(
The loss function to be used.
epochs : int, optional, default = 10
The number of epochs to train the model
freeze_layers : int, optional, default = 0
The number of layers to freeze.
lr_scheduler_params : dict, default = None
The learning rate scheduler parameters for the transformers get_scheduler method. The optimizer will be taken from the optimizer input and should not be included in the learning scheduler parameters. If not specified, no scheduler will be used.
e.g. lr_scheduler_params = { 'name': 'linear', 'num_warmup_steps': 0, 'num_training_steps': 5 }
@@ -141,16 +139,18 @@ def train(
        if validation_input_data is not None:
            validation_dataloader = self.accelerator.prepare(validation_dataloader)

        self.uce_model.train()
        self.fine_tuning_head.train()

        # disable progress bar if not the main process
        # if self.accelerator is not None:
        #     pbar = tqdm(dataloader, disable=not self.accelerator.is_local_main_process)
        # else:
        #     pbar = tqdm(dataloader)

        model = self.to(self.device)
        self.to(self.device)

        optimizer = optimizer(model.parameters(), **optimizer_params)
        optimizer = optimizer(self.parameters(), **optimizer_params)

        lr_scheduler = None
        if lr_scheduler_params is not None:
@@ -170,7 +170,7 @@ def train(
                else:
                    batch_sentences = self.uce_model.pe_embedding(batch_sentences.long())
                batch_sentences = torch.nn.functional.normalize(batch_sentences, dim=2) # normalize token outputs
                output = model._forward(batch_sentences, mask=mask)
                output = self._forward(batch_sentences, mask=mask)
                labels = torch.tensor(train_labels[batch_count: batch_count + self.config["batch_size"]], device=self.device)
                batch_count += self.config["batch_size"]
                loss = loss_function(output, labels)
@@ -200,13 +200,15 @@ def train(
                else:
                    batch_sentences = self.uce_model.pe_embedding(batch_sentences.long())
                batch_sentences = torch.nn.functional.normalize(batch_sentences, dim=2) # normalize token outputs
                output = model._forward(batch_sentences, mask=mask)
                output = self._forward(batch_sentences, mask=mask)
                val_labels = torch.tensor(validation_labels[validation_batch_count: validation_batch_count + self.config["batch_size"]], device=self.device)
                validation_batch_count += self.config["batch_size"]
                accuracy += accuracy_score(val_labels.cpu(), torch.argmax(output, dim=1).cpu())
                count += 1.0
                testing_loop.set_postfix({"accuracy": accuracy/count})
        logger.info(f"Fine-Tuning Complete. Epochs: {epochs}")
        self.uce_model.eval()
        self.fine_tuning_head.eval()

    def get_outputs(
        self,
@@ -225,7 +227,7 @@ def get_outputs(
        np.ndarray
            The outputs of the model.
        """
        model = self.to(self.device)
        self.to(self.device)

        batch_size = self.config["batch_size"]
        dataloader = DataLoader(dataset,
@@ -237,6 +239,9 @@ def get_outputs(

        if self.accelerator is not None:
            dataloader = self.accelerator.prepare(dataloader)

        self.uce_model.eval()
        self.fine_tuning_head.eval()

        testing_loop = tqdm(dataloader, desc="Fine-Tuning Validation")
        outputs = []
@@ -248,7 +253,7 @@ def get_outputs(
            else:
                batch_sentences = self.uce_model.pe_embedding(batch_sentences.long())
            batch_sentences = torch.nn.functional.normalize(batch_sentences, dim=2) # normalize token outputs
            output = model._forward(batch_sentences, mask=mask)
            output = self._forward(batch_sentences, mask=mask)
            outputs.append(output.detach().cpu().numpy())

        return np.vstack(outputs)
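
Two small facts the UCE changes above appear to rely on, shown as a generic PyTorch sketch with a placeholder module (not Helical code): Module.to() returns the module itself, so a bare self.to(self.device) is equivalent to model = self.to(self.device), and .eval() makes stochastic layers such as Dropout deterministic at inference time.

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
assert net.to("cpu") is net        # .to() returns the same module object

x = torch.ones(1, 4)
net.eval()                         # dropout becomes a no-op
out1, out2 = net(x), net(x)
assert torch.equal(out1, out2)     # deterministic outputs in eval mode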
3 changes: 2 additions & 1 deletion helical/utils/converter.py
@@ -28,7 +28,8 @@ def get_anndata_from_hf_dataset(dataset: Dataset) -> ad.AnnData:
        An AnnData object containing the data from the input dataset
    """
    # obs
    observation_names = [obs for obs in list(dataset.features.keys()) if not obs == 'raw_counts' and not obs == "rows"]
    excluded_features = ['raw_counts', 'rows', 'size']
    observation_names = [obs for obs in dataset.features.keys() if obs not in excluded_features]
    obs_data = pd.DataFrame(dataset.select_columns(observation_names).data.to_pandas(),columns=observation_names)

    # raw counts
