Commit
Merge pull request #52 from helicalAI/ensure-counts-are-ints
Ensure counts are ints
bputzeys authored Jul 15, 2024
2 parents 297dafe + eb0ebbe commit 232c7c5
Showing 10 changed files with 66 additions and 46 deletions.
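In practical terms, the RNA models no longer require callers to add an `n_counts` column by hand: the validity check now derives `adata.obs["total_counts"]` from `adata.X` itself and rejects non-integer counts. A minimal sketch of the new expectation, assuming an AnnData object of raw counts (the file name below is illustrative, not from this commit):

```python
import anndata as ad

adata = ad.read_h5ad("my_raw_counts.h5ad")  # illustrative file name

# Previously expected by the models (and now removed from the example notebook):
# adata.obs["n_counts"] = adata.X.sum(axis=1)

# After this change, ensure_data_validity() computes adata.obs["total_counts"]
# on its own and raises a ValueError if adata.X does not hold integer counts,
# so the only remaining preparation is making sure X really is raw counts.
assert (adata.X.sum(axis=1) % 1 == 0).all()
```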
Binary file modified ci/tests/data/cell_type_sample.h5ad
Binary file not shown.
23 changes: 3 additions & 20 deletions ci/tests/test_geneformer/test_geneformer_model.py
@@ -12,7 +12,6 @@ class TestGeneformerModel:
# Create a dummy AnnData object
data = AnnData()
data.var['gene_symbols'] = ['SAMD11', 'PLEKHN1', 'HES4']
data.obs["n_counts"] = [1]
data.obs["cell_type"] = ["CD4 T cells"]
data.X = [[1, 2, 5]]
tokenized_dataset = geneformer.process_data(data, gene_names='gene_symbols')
@@ -29,24 +28,8 @@ def test_process_data_padding_and_masking_ids(self):
assert self.geneformer.gene_token_dict.get("<pad>") == 0
assert self.geneformer.gene_token_dict.get("<mask>") == 1

miss_n_counts = AnnData()
miss_n_counts.var["gene_symbols"] = [1]

miss_gene_symbols = AnnData()
miss_gene_symbols.obs["n_counts"] = [1]

miss_ensembl_id = AnnData()
miss_ensembl_id.obs["n_counts"] = [1]
miss_ensembl_id.var["gene_symbols"] = [1]

@pytest.mark.parametrize("data, use_gene_symbols",
[
(miss_n_counts, True),
(miss_gene_symbols, True),
(miss_ensembl_id, False)
]
)
def test_check_data_validity(self, data, use_gene_symbols):
def test_ensure_data_validity_raising_error_with_missing_ensembl_id_column(self):
del self.data.var['ensembl_id']
with pytest.raises(KeyError):
self.geneformer.check_data_validity(data, use_gene_symbols)
self.geneformer.ensure_data_validity(self.data, "ensembl_id")

2 changes: 1 addition & 1 deletion ci/tests/test_geneformer/test_geneformer_tokenizer.py
@@ -92,9 +92,9 @@ def test_tokenize_anndata(
number_of_obs = 5
data = AnnData()
data.var["ensembl_id"] = ensembl_ids
data.obs["n_counts"] = [1] * number_of_obs
data.obs["key_from_file"] = ["CD4 T cells"] * number_of_obs
data.X = [x_data_count] * number_of_obs
data.obs["total_counts"] = data.X.sum(axis=1)

self.tokenizer = TranscriptomeTokenizer(
custom_attr_name_dict={"key_from_file": "desired_key_in_dataset"},
44 changes: 39 additions & 5 deletions ci/tests/test_scgpt/test_scgpt_model.py
@@ -2,13 +2,16 @@
from anndata import AnnData
from helical.models.scgpt.tokenizer import GeneVocab
import pytest
import anndata as ad
import numpy as np
from scipy.sparse import csr_matrix

class TestSCGPTModel:
scgpt = scGPT()

# Create a dummy AnnData object
data = AnnData()
data.var["gene_names"] = ['SAMD11', 'PLEKHN1', "NOT_IN_VOCAB", "<pad>", 'HES4']
data.obs["n_counts"] = [1]
data.obs["cell_type"] = ["CD4 T cells"]

vocab = {
@@ -60,8 +63,7 @@ def test_get_embeddings(self):
embeddings = self.scgpt.get_embeddings(dataset)
assert embeddings.shape == (1, 512)

dummy_data = AnnData()
dummy_data.var.index = ['gene1', 'gene2', 'gene3']
dummy_data = ad.read_h5ad("ci/tests/data/cell_type_sample.h5ad")
@pytest.mark.parametrize("data, gene_names, batch_labels",
[
# missing gene_names in data.var
@@ -70,6 +72,38 @@ def test_get_embeddings(self):
(dummy_data, "index", True),
]
)
def test_check_data_validity(self, data, gene_names, batch_labels):
def test_ensure_data_validity__key_error(self, data, gene_names, batch_labels):
with pytest.raises(KeyError):
self.scgpt.check_data_validity(data, gene_names, batch_labels)
self.scgpt.ensure_data_validity(data, gene_names, batch_labels)

err_np_arr_data = ad.read_h5ad("ci/tests/data/cell_type_sample.h5ad")
err_np_arr_data.X.dtype=float
err_np_arr_data.X[0,0] = 0.5

err_csr_data = ad.read_h5ad("ci/tests/data/cell_type_sample.h5ad")
err_csr_data.X = csr_matrix(np.random.rand(100, 5), dtype=np.float32)
@pytest.mark.parametrize("data",
[
(err_np_arr_data),
(err_csr_data),
]
)
def test_ensure_data_validity__value_error(self, data):
'''The data in X must be ints. Test an error is raised for both np.ndarray and csr_matrix.'''
with pytest.raises(ValueError):
self.scgpt.ensure_data_validity(data, "index", False)
assert "total_counts" in data.obs

np_arr_data = ad.read_h5ad("ci/tests/data/cell_type_sample.h5ad")
csr_data = ad.read_h5ad("ci/tests/data/cell_type_sample.h5ad")
csr_data.X = csr_matrix(np.random.poisson(1, size=(100, 5)), dtype=np.float32)
@pytest.mark.parametrize("data",
[
(np_arr_data),
(csr_data),
]
)
def test_ensure_data_validity__no_error(self, data):
'''The data in X must be ints. Test no error is raised for both np.ndarray and csr_matrix.'''
self.scgpt.ensure_data_validity(data, "index", False)
assert "total_counts" in data.obs
2 changes: 0 additions & 2 deletions examples/notebooks/Cell-Type-Annotation.ipynb
@@ -1040,7 +1040,6 @@
],
"source": [
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"adata.obs['n_counts'] = adata.X.sum(axis=1)\n",
"adata.obs['cell_type'] = adata.obs['celltype']\n",
"adata.var[\"ensembl_id\"] = adata.var[\"index_column\"]\n",
"\n",
@@ -1293,7 +1292,6 @@
"source": [
"adata_unseen = sc.read(\"ms_default.h5ad\")\n",
"adata_unseen.var[\"ensembl_id\"] = adata_unseen.var[\"index_column\"]\n",
"adata_unseen.obs['n_counts'] = adata_unseen.X.sum(axis=1)\n",
"adata_unseen.obs['cell_type'] = adata_unseen.obs['celltype']\n",
"data_unseen_geneformer = geneformer.process_data(adata_unseen, gene_names = \"ensembl_id\")\n",
"x_unseen_geneformer = geneformer.get_embeddings(data_unseen_geneformer)\n",
13 changes: 11 additions & 2 deletions helical/models/base_models.py
@@ -46,8 +46,9 @@ def get_embeddings():
pass

class HelicalRNAModel(HelicalBaseFoundationModel):
def check_rna_data_validity(self, adata: AnnData, gene_names: str) -> None:
"""Checks if the data is contains the gene_names, which is needed for all Helical RNA models.
def ensure_rna_data_validity(self, adata: AnnData, gene_names: str) -> None:
"""Ensures that the data contains the gene_names and has integer counts for adata.X which is saved
in 'total_counts'.
Parameters
----------
@@ -77,6 +78,14 @@ def check_rna_data_validity(self, adata: AnnData, gene_names: str) -> None:
message = f"Data must have the provided key '{gene_names}' in its 'var' section to be processed by the Helical RNA model."
LOGGER.error(message)
raise KeyError(message)

# verify that the data in X are integers
adata.obs["total_counts"] = adata.X.sum(axis=1)
if not (adata.obs["total_counts"] % 1 == 0).all():
message = "The data in X must be integers."
LOGGER.error(message)
raise ValueError(message)


class HelicalDNAModel(HelicalBaseFoundationModel):
def check_dna_data_validity(self) -> None:
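A standalone sketch of the integer-count check introduced above, using plain NumPy and AnnData rather than the Helical classes (toy values; the helper name is ours):

```python
import numpy as np
from anndata import AnnData

def check_integer_counts(adata: AnnData) -> None:
    # Same idea as ensure_rna_data_validity: store the per-cell sums in
    # 'total_counts' and require every sum to be a whole number.
    adata.obs["total_counts"] = adata.X.sum(axis=1)
    if not (adata.obs["total_counts"] % 1 == 0).all():
        raise ValueError("The data in X must be integers.")

raw = AnnData(X=np.array([[1, 2, 5], [0, 3, 4]], dtype=np.float32))
check_integer_counts(raw)  # passes: per-cell sums 8.0 and 7.0 are whole numbers

normalised = AnnData(X=np.array([[0.5, 1.2, 0.4]], dtype=np.float32))
# check_integer_counts(normalised)  # would raise ValueError (per-cell sum is 2.1)
```

Note that, like the check above, this validates the per-cell sums rather than every individual entry of X.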
6 changes: 3 additions & 3 deletions helical/models/geneformer/geneformer_tokenizer.py
@@ -5,7 +5,7 @@
| *Required format:* raw counts scRNAseq data without feature selection as .loom or anndata file.
| *Required row (gene) attribute:* "ensembl_id"; Ensembl ID for each gene.
| *Required col (cell) attribute:* "n_counts"; total read counts in that cell.
| *Required col (cell) attribute:* "total_counts"; total read counts in that cell.
| *Optional col (cell) attribute:* "filter_pass"; binary indicator of whether cell should be tokenized based on user-defined filtering criteria.
| *Optional col (cell) attributes:* any other cell metadata can be passed on to the tokenized dataset as a custom attribute dictionary as shown below.
@@ -26,7 +26,7 @@
| The discussion below references the .loom file format, but the analagous labels are required for .h5ad files, just that they will be column instead of row attributes and vice versa due to the transposed format of the two file types.
| Genes should be labeled with Ensembl IDs (loom row attribute "ensembl_id"), which provide a unique identifer for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart.
Cells should be labeled with the total read count in the cell (loom column attribute "n_counts") to be used for normalization.
Cells should be labeled with the total read count in the cell (loom column attribute "total_counts") to be used for normalization.
| No cell metadata is required, but custom cell attributes may be passed onto the tokenized dataset by providing a dictionary of custom attributes to be added, which is formatted as loom_col_attr_name : desired_dataset_col_attr_name.
For example, if the original .loom dataset has column attributes "cell_type" and "organ_major" and one would like to retain these attributes as labels in the tokenized dataset with the new names "cell_type" and "organ", respectively,
@@ -228,7 +228,7 @@ def tokenize_anndata(self, adata: AnnData, target_sum=10_000):
for i in range(0, len(filter_pass_loc), self.chunk_size):
idx = filter_pass_loc[i : i + self.chunk_size]

n_counts = adata[idx].obs["n_counts"].values[:, None]
n_counts = adata[idx].obs["total_counts"].values[:, None]
X_view = adata[idx, coding_miRNA_loc].X
X_norm = X_view / n_counts * target_sum / norm_factor_vector
X_norm = sp.csr_matrix(X_norm)
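For reference, a hedged sketch of an AnnData object that meets the tokenizer requirements described above (the counts, Ensembl IDs and metadata values are placeholders):

```python
import anndata as ad
import numpy as np

# Toy raw-count matrix: 5 cells x 3 genes.
adata = ad.AnnData(X=np.random.poisson(1.0, size=(5, 3)).astype(np.int32))

# Required gene (var) attribute: an Ensembl ID per gene.
adata.var["ensembl_id"] = ["ENSG00000187634", "ENSG00000187583", "ENSG00000188290"]

# Required cell (obs) attribute after this change: total read counts per cell.
adata.obs["total_counts"] = adata.X.sum(axis=1)

# Optional cell metadata, which can be carried into the tokenized dataset via
# custom_attr_name_dict={"cell_type": "cell_type"} as in the test above.
adata.obs["cell_type"] = ["CD4 T cells"] * 5
```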
14 changes: 5 additions & 9 deletions helical/models/geneformer/model.py
@@ -108,7 +108,7 @@ def process_data(self,
The tokenized dataset in the form of a Hugginface Dataset object.
"""
self.check_data_validity(adata, gene_names)
self.ensure_data_validity(adata, gene_names)

files_config = {
"gene_median_path": self.config.model_dir / "gene_median_dictionary.pkl",
@@ -164,8 +164,9 @@ def get_embeddings(self, dataset: Dataset) -> np.array:
return embeddings


def check_data_validity(self, adata: AnnData, gene_names: str) -> None:
"""Checks if the data is eligible for processing by the Geneformer model
def ensure_data_validity(self, adata: AnnData, gene_names: str) -> None:
"""Ensure that the data is eligible for processing by the Geneformer model. This checks
if the data contains the gene_names, and sets the total_counts column in adata.obs.
Parameters
----------
@@ -179,9 +180,4 @@ def check_data_validity(self, adata: AnnData, gene_names: str) -> None:
KeyError
If the data is missing column names.
"""
self.check_rna_data_validity(adata, gene_names)

if not 'n_counts' in adata.obs.columns.to_list():
message = f"Data must have the 'obs' keys 'n_counts' to be processed by the Geneformer model."
LOGGER.error(message)
raise KeyError(message)
self.ensure_rna_data_validity(adata, gene_names)
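A hedged usage sketch of the renamed Geneformer check, mirroring the updated test above (the model class name and import path are assumptions, not shown in this diff):

```python
import anndata as ad
import pytest
from helical.models.geneformer.model import Geneformer  # assumed import path

geneformer = Geneformer()
adata = ad.read_h5ad("ci/tests/data/cell_type_sample.h5ad")

# Dropping the required var column makes the check raise a KeyError,
# exactly as asserted in test_geneformer_model.py above.
del adata.var["ensembl_id"]
with pytest.raises(KeyError):
    geneformer.ensure_data_validity(adata, "ensembl_id")
```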
6 changes: 3 additions & 3 deletions helical/models/scgpt/model.py
@@ -179,7 +179,7 @@ def process_data(self,
The processed dataset.
"""

self.check_data_validity(adata, gene_names, use_batch_labels)
self.ensure_data_validity(adata, gene_names, use_batch_labels)
self.gene_names = gene_names
if fine_tuning:
# Preprocess the dataset and select `N_HVG` highly variable genes for downstream analysis.
@@ -216,7 +216,7 @@ def process_data(self,
return dataset


def check_data_validity(self, adata: AnnData, gene_names: str, use_batch_labels: bool) -> None:
def ensure_data_validity(self, adata: AnnData, gene_names: str, use_batch_labels: bool) -> None:
"""Checks if the data is eligible for processing by the scGPT model
Parameters
@@ -233,7 +233,7 @@ def check_data_validity(self, adata: AnnData, gene_names: str, use_batch_labels:
KeyError
If the data is missing column names.
"""
self.check_rna_data_validity(adata, gene_names)
self.ensure_rna_data_validity(adata, gene_names)

if use_batch_labels:
if not "batch_id" in adata.obs:
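A hedged sketch of the batch-label path checked above: with `use_batch_labels` set, `ensure_data_validity` also expects a `batch_id` column in `adata.obs` (the scGPT import path and the single-batch assignment are assumptions):

```python
import anndata as ad
from helical.models.scgpt.model import scGPT  # assumed import path

scgpt = scGPT()  # constructed as in the tests above
adata = ad.read_h5ad("ci/tests/data/cell_type_sample.h5ad")  # integer-count sample used in the tests

# When batch labels are used, a "batch_id" column must be present in adata.obs,
# otherwise a KeyError is raised.
adata.obs["batch_id"] = 0  # illustrative: all cells assigned to a single batch

scgpt.ensure_data_validity(adata, gene_names="index", use_batch_labels=True)
```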
2 changes: 1 addition & 1 deletion helical/models/uce/model.py
@@ -97,7 +97,7 @@ def process_data(self,
Inherits from Dataset class.
"""

self.check_rna_data_validity(adata, gene_names)
self.ensure_rna_data_validity(adata, gene_names)

if gene_names != "index":
adata.var.index = adata.var[gene_names]
