From 86652be5143fa388c547c3e9ed18e544ed0854f8 Mon Sep 17 00:00:00 2001 From: Benoit Putzeys Date: Wed, 3 Jul 2024 17:17:24 +0200 Subject: [PATCH 1/2] Pass the embeddings of each model to the anndata object outside of the model classes. This is to keep things separate: input an anndata object for inference and only expect the embeddings as output. --- examples/run_benchmark.py | 10 ++++------ helical/models/geneformer/geneformer_config.py | 4 ---- helical/models/geneformer/model.py | 7 ++----- helical/models/scgpt/model.py | 11 ++++------- helical/models/scgpt/scgpt_config.py | 4 ---- helical/models/uce/model.py | 10 +++------- helical/models/uce/uce_config.py | 4 ---- 7 files changed, 13 insertions(+), 37 deletions(-) diff --git a/examples/run_benchmark.py b/examples/run_benchmark.py index c8dccf89..89f37628 100644 --- a/examples/run_benchmark.py +++ b/examples/run_benchmark.py @@ -97,12 +97,10 @@ def run_integration_example(adata: ad.AnnData, cfg: DictConfig) -> dict[str, dic adata = ad.concat([batch_0[:20], batch_1[:20]]) dataset = uce.process_data(adata) - uce.get_embeddings(dataset) - processed_adata_uce = getattr(uce, "adata") + adata.obsm["X_uce"] = uce.get_embeddings(dataset) dataset = scgpt.process_data(adata) - scgpt.get_embeddings(dataset) - processed_adata_scgpt = getattr(scgpt, "adata") + adata.obsm["X_scgpt"] = scgpt.get_embeddings(dataset) # data specific configurations cfg["data"]["batch_key"] = "batch" @@ -114,8 +112,8 @@ def run_integration_example(adata: ad.AnnData, cfg: DictConfig) -> dict[str, dic return evaluate_integration( [ - ("scgpt", processed_adata_scgpt, "X_scgpt"), - ("uce", processed_adata_uce, "X_uce"), + ("scgpt", adata, "X_scgpt"), + ("uce", adata, "X_uce"), ("scanorama", adata, "X_scanorama") ], cfg ) diff --git a/helical/models/geneformer/geneformer_config.py b/helical/models/geneformer/geneformer_config.py index 13bb0315..d4d17faf 100644 --- a/helical/models/geneformer/geneformer_config.py +++ b/helical/models/geneformer/geneformer_config.py @@ -8,8 +8,6 @@ class GeneformerConfig(): Parameters ---------- - embed_obsm_name : str, optional, default = "X_geneformer" - The name of the obsm under which the embeddings will be saved in the AnnData object batch_size : int, optional, default = 5 The batch size emb_layer : int, optional, default = -1 @@ -29,7 +27,6 @@ class GeneformerConfig(): """ def __init__( self, - embed_obsm_name: str = "X_geneformer", batch_size: int = 5, emb_layer: int = -1, emb_mode: str = "cell", @@ -50,7 +47,6 @@ def __init__( self.model_dir = Path(CACHE_DIR_HELICAL, 'geneformer') self.model_name = model_name - self.embed_obsm_name = embed_obsm_name self.batch_size = batch_size self.emb_layer = emb_layer self.emb_mode = emb_mode diff --git a/helical/models/geneformer/model.py b/helical/models/geneformer/model.py index 472da7fd..7ed27e6e 100644 --- a/helical/models/geneformer/model.py +++ b/helical/models/geneformer/model.py @@ -103,7 +103,6 @@ def process_data(self, """ self.check_data_validity(adata, gene_column_name) - self.adata = adata files_config = { "mapping_path": self.config.model_dir / "human_gene_to_ensemble_id.pkl", @@ -114,7 +113,7 @@ def process_data(self, # map gene symbols to ensemble ids if provided if gene_column_name != "ensembl_id": mappings = pkl.load(open(files_config["mapping_path"], 'rb')) - self.adata.var['ensembl_id'] = self.adata.var[gene_column_name].apply(lambda x: mappings.get(x,{"id":None})['id']) + adata.var['ensembl_id'] = adata.var[gene_column_name].apply(lambda x: mappings.get(x,{"id":None})['id']) # 
load token dictionary (Ensembl IDs:token) with open(files_config["token_path"], "rb") as f: @@ -125,7 +124,7 @@ def process_data(self, gene_median_file = files_config["gene_median_path"], token_dictionary_file = files_config["token_path"]) - tokenized_cells, cell_metadata = self.tk.tokenize_anndata(self.adata) + tokenized_cells, cell_metadata = self.tk.tokenize_anndata(adata) tokenized_dataset = self.tk.create_dataset(tokenized_cells, cell_metadata, use_generator=False) if output_path: @@ -157,8 +156,6 @@ def get_embeddings(self, dataset: Dataset) -> np.array: self.device ).cpu().detach().numpy() - # save the embeddings in the adata object - self.adata.obsm[self.config["embed_obsm_name"]] = embeddings return embeddings diff --git a/helical/models/scgpt/model.py b/helical/models/scgpt/model.py index 48be4b12..9c8d90c1 100644 --- a/helical/models/scgpt/model.py +++ b/helical/models/scgpt/model.py @@ -142,8 +142,6 @@ def get_embeddings(self, dataset: Dataset) -> np.array: # obs_df = adata.obs[obs_to_save] if obs_to_save is not None else None # return sc.AnnData(X=cell_embeddings, obs=obs_df, dtype="float32") - # save the embeddings in the adata object - self.adata.obsm[self.config["embed_obsm_name"]] = cell_embeddings return cell_embeddings def process_data(self, @@ -183,15 +181,14 @@ def process_data(self, self.check_data_validity(adata, gene_column_name, use_batch_labels) self.gene_column_name = gene_column_name - self.adata = adata if fine_tuning: # Preprocess the dataset and select `N_HVG` highly variable genes for downstream analysis. - sc.pp.normalize_total(self.adata, target_sum=1e4) - sc.pp.log1p(self.adata) + sc.pp.normalize_total(adata, target_sum=1e4) + sc.pp.log1p(adata) # highly variable genes - sc.pp.highly_variable_genes(self.adata, n_top_genes=n_top_genes, flavor=flavor) - self.adata = self.adata[:, self.adata.var['highly_variable']] + sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes, flavor=flavor) + adata = adata[:, adata.var['highly_variable']] adata.var["id_in_vocab"] = [ self.vocab[gene] if gene in self.vocab else -1 for gene in adata.var[self.gene_column_name] ] adata = adata[:, adata.var["id_in_vocab"] >= 0] diff --git a/helical/models/scgpt/scgpt_config.py b/helical/models/scgpt/scgpt_config.py index 551bf75b..d72ca2f9 100644 --- a/helical/models/scgpt/scgpt_config.py +++ b/helical/models/scgpt/scgpt_config.py @@ -8,8 +8,6 @@ class scGPTConfig(): Parameters ---------- - embed_obsm_name : str, optional, default = "X_scgpt" - The name of the obsm under which the embeddings will be saved in the AnnData object pad_token : str, optional, default = "" The padding token batch_size : int, optional, default = 24 @@ -54,7 +52,6 @@ class scGPTConfig(): def __init__( self, - embed_obsm_name: str = "X_scgpt", pad_token: str = "", batch_size: int = 24, fast_transformer: bool = True, @@ -82,7 +79,6 @@ def __init__( self.config = { "model_path": model_path, "list_of_files_to_download": list_of_files_to_download, - "embed_obsm_name": embed_obsm_name, "pad_token": pad_token, "batch_size": batch_size, "fast_transformer": fast_transformer, diff --git a/helical/models/uce/model.py b/helical/models/uce/model.py index 6c215413..b1efacd6 100644 --- a/helical/models/uce/model.py +++ b/helical/models/uce/model.py @@ -98,10 +98,9 @@ def process_data(self, """ self.check_rna_data_validity(adata, gene_column_name) - self.adata = adata if gene_column_name != "index": - self.adata.var.index = self.adata.var[gene_column_name] + adata.var.index = adata.var[gene_column_name] files_config = { 
"spec_chrom_csv_path": self.model_dir / "species_chrom.csv", @@ -110,10 +109,10 @@ def process_data(self, } if filter_genes_min_cell is not None: - sc.pp.filter_genes(self.adata, min_cells=filter_genes_min_cell) + sc.pp.filter_genes(adata, min_cells=filter_genes_min_cell) # sc.pp.filter_cells(ad, min_genes=25) ##Filtering out the Expression Data That we do not have in the protein embeddings - filtered_adata, species_to_all_gene_symbols = load_gene_embeddings_adata(adata=self.adata, + filtered_adata, species_to_all_gene_symbols = load_gene_embeddings_adata(adata=adata, species=[species], embedding_model=embedding_model, embeddings_path=Path(files_config["protein_embeddings_dir"])) @@ -211,7 +210,4 @@ def get_embeddings(self, dataset: UCEDataset) -> np.array: else: dataset_embeds.append(embedding.detach().cpu().numpy()) embeddings = np.vstack(dataset_embeds) - - # save the embeddings in the adata object - self.adata.obsm[self.config["embed_obsm_name"]] = embeddings return embeddings diff --git a/helical/models/uce/uce_config.py b/helical/models/uce/uce_config.py index 8a8fc2ad..a4157b7a 100644 --- a/helical/models/uce/uce_config.py +++ b/helical/models/uce/uce_config.py @@ -9,8 +9,6 @@ class UCEConfig(): ---------- model_name : Literal["33l_8ep_1024t_1280", "4layer_model"], optional, default = "4layer_model" The model name - embed_obsm_name : str, optional, default = "X_uce" - The name of the obsm under which the embeddings will be saved in the AnnData object batch_size : int, optional, default = 5 The batch size pad_length : int, optional, default = 1536 @@ -51,7 +49,6 @@ class UCEConfig(): """ def __init__(self, model_name: Literal["33l_8ep_1024t_1280", "4layer_model"] = "4layer_model", - embed_obsm_name: str = "X_uce", batch_size: int = 5, pad_length: int = 1536, pad_token_idx: int = 0, @@ -95,7 +92,6 @@ def __init__(self, self.config = { "model_path": model_path, "list_of_files_to_download": list_of_files_to_download, - "embed_obsm_name": embed_obsm_name, "batch_size": batch_size, "pad_length": pad_length, "pad_token_idx": pad_token_idx, From 9663ab6ae4446dc8ead177d5b762c1d88de1c719 Mon Sep 17 00:00:00 2001 From: Benoit Putzeys Date: Thu, 4 Jul 2024 16:03:43 +0200 Subject: [PATCH 2/2] Add log messages for mapping and filtering in each model. 
---
 helical/models/geneformer/model.py | 2 ++
 helical/models/scgpt/model.py      | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/helical/models/geneformer/model.py b/helical/models/geneformer/model.py
index 7ed27e6e..921dcec5 100644
--- a/helical/models/geneformer/model.py
+++ b/helical/models/geneformer/model.py
@@ -114,6 +114,8 @@ def process_data,
         if gene_column_name != "ensembl_id":
             mappings = pkl.load(open(files_config["mapping_path"], 'rb'))
             adata.var['ensembl_id'] = adata.var[gene_column_name].apply(lambda x: mappings.get(x,{"id":None})['id'])
+            non_none_mappings = adata.var['ensembl_id'].notnull().sum()
+            LOGGER.info(f"Mapped {non_none_mappings} genes to Ensembl IDs from a total of {adata.var.shape[0]} genes.")
 
         # load token dictionary (Ensembl IDs:token)
         with open(files_config["token_path"], "rb") as f:
diff --git a/helical/models/scgpt/model.py b/helical/models/scgpt/model.py
index 9c8d90c1..0a45dcca 100644
--- a/helical/models/scgpt/model.py
+++ b/helical/models/scgpt/model.py
@@ -190,7 +190,9 @@ def process_data,
             sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes, flavor=flavor)
             adata = adata[:, adata.var['highly_variable']]
 
+        # filter out genes that are not in the scGPT vocabulary
         adata.var["id_in_vocab"] = [ self.vocab[gene] if gene in self.vocab else -1 for gene in adata.var[self.gene_column_name] ]
+        LOGGER.info(f"Filtering out {np.sum(adata.var['id_in_vocab'] < 0)} genes that are not in the scGPT vocabulary; {np.sum(adata.var['id_in_vocab'] >= 0)} genes remain.")
         adata = adata[:, adata.var["id_in_vocab"] >= 0]
 
         # Binning will be applied after tokenization. A possible way to do is to use the unified way of binning in the data collator.
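For reference, the calling pattern that this series leaves in place is sketched below. It assumes `uce` and `scgpt` are already-instantiated Helical model objects, as in examples/run_benchmark.py, and that `adata` is an AnnData object carrying the gene annotations each model expects; the snippet mirrors the updated example script rather than adding anything new.

    # process_data() no longer stores a reference to adata on the model, and
    # get_embeddings() simply returns a numpy array; the caller decides where
    # the embeddings live in the AnnData object.
    dataset = uce.process_data(adata)
    adata.obsm["X_uce"] = uce.get_embeddings(dataset)        # caller picks the obsm key

    dataset = scgpt.process_data(adata)
    adata.obsm["X_scgpt"] = scgpt.get_embeddings(dataset)    # no embed_obsm_name config needed

Keeping the obsm key in user code is what allows the embed_obsm_name option to be dropped from the Geneformer, scGPT and UCE config classes in the first patch.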