diff --git a/ci/tests/test_geneformer/test_geneformer_model.py b/ci/tests/test_geneformer/test_geneformer_model.py index d029aa9b..fd6a2e9b 100644 --- a/ci/tests/test_geneformer/test_geneformer_model.py +++ b/ci/tests/test_geneformer/test_geneformer_model.py @@ -52,7 +52,8 @@ def test_ensure_data_validity_raising_error_with_missing_ensembl_id_column(self, @pytest.mark.parametrize("gene_symbols, raises_error", [ - (['ENSGSAMD11', 'ENSGPLEKHN1', 'ENSGHES4'], True), + (['ENSGSAMD11', 'ENSGPLEKHN1', 'ENSGHES4'], True), # humans + (['ENSMUSG00000021033', 'ENSMUSG00000021033', 'ENSMUSG00000021033'], True), # mice (['SAMD11', 'None', 'HES4'], True), (['SAMD11', 'PLEKHN1', 'HES4'], False), ] diff --git a/helical/models/geneformer/model.py b/helical/models/geneformer/model.py index 5ca23f78..b7139b78 100644 --- a/helical/models/geneformer/model.py +++ b/helical/models/geneformer/model.py @@ -145,7 +145,7 @@ def process_data(self, # map gene symbols to ensemble ids if provided if gene_names != "ensembl_id": - if (adata.var[gene_names].str.startswith("ENSG").all()) or (adata.var[gene_names].str.startswith("None").any()): + if (adata.var[gene_names].str.startswith("ENS").all()) or (adata.var[gene_names].str.startswith("None").any()): message = "It seems an anndata with 'ensemble ids' and/or 'None' was passed. " \ "Please set gene_names='ensembl_id' and remove 'None's to skip mapping." LOGGER.info(message)