Load embeddings outside models #42

Closed · wants to merge 5 commits
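This PR decouples embedding storage from the model classes: `process_data` no longer caches the input AnnData on the model as `self.adata`, `get_embeddings` now simply returns the embedding matrix, and the now-unused `embed_obsm_name` option is dropped from the Geneformer, scGPT, and UCE configs. Callers store the returned embeddings themselves (e.g. in `adata.obsm["X_uce"]`), as the updated benchmark example shows. Logging is also added to report how many genes are mapped to Ensembl IDs (Geneformer) and how many genes survive vocabulary filtering (scGPT).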
examples/run_benchmark.py (10 changes: 4 additions & 6 deletions)

```diff
@@ -97,12 +97,10 @@ def run_integration_example(adata: ad.AnnData, cfg: DictConfig) -> dict[str, dic
     adata = ad.concat([batch_0[:20], batch_1[:20]])
 
     dataset = uce.process_data(adata)
-    uce.get_embeddings(dataset)
-    processed_adata_uce = getattr(uce, "adata")
+    adata.obsm["X_uce"] = uce.get_embeddings(dataset)
 
     dataset = scgpt.process_data(adata)
-    scgpt.get_embeddings(dataset)
-    processed_adata_scgpt = getattr(scgpt, "adata")
+    adata.obsm["X_scgpt"] = scgpt.get_embeddings(dataset)
 
     # data specific configurations
     cfg["data"]["batch_key"] = "batch"
@@ -114,8 +112,8 @@ def run_integration_example(adata: ad.AnnData, cfg: DictConfig) -> dict[str, dic
 
     return evaluate_integration(
        [
-           ("scgpt", processed_adata_scgpt, "X_scgpt"),
-           ("uce", processed_adata_uce, "X_uce"),
+           ("scgpt", adata, "X_scgpt"),
+           ("uce", adata, "X_uce"),
            ("scanorama", adata, "X_scanorama")
        ], cfg
    )
```
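With this change, `get_embeddings` no longer writes into a model-held copy of the AnnData; the caller owns the storage. A minimal sketch of the resulting calling pattern (the `embed_into` helper is illustrative, not part of this PR):

```python
import anndata as ad

def embed_into(adata: ad.AnnData, model, obsm_key: str) -> ad.AnnData:
    """Run a helical model and store its embeddings under the given obsm key."""
    dataset = model.process_data(adata)         # model-specific preprocessing
    embeddings = model.get_embeddings(dataset)  # now returns a plain array, no side effects
    adata.obsm[obsm_key] = embeddings           # caller decides where embeddings live
    return adata

# e.g. embed_into(adata, uce, "X_uce"); embed_into(adata, scgpt, "X_scgpt")
```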
helical/models/geneformer/geneformer_config.py (4 changes: 0 additions & 4 deletions)

```diff
@@ -8,8 +8,6 @@ class GeneformerConfig():
 
     Parameters
     ----------
-    embed_obsm_name : str, optional, default = "X_geneformer"
-        The name of the obsm under which the embeddings will be saved in the AnnData object
     batch_size : int, optional, default = 5
         The batch size
     emb_layer : int, optional, default = -1
@@ -29,7 +27,6 @@
     """
     def __init__(
         self,
-        embed_obsm_name: str = "X_geneformer",
         batch_size: int = 5,
         emb_layer: int = -1,
         emb_mode: str = "cell",
@@ -50,7 +47,6 @@ def __init__(
 
         self.model_dir = Path(CACHE_DIR_HELICAL, 'geneformer')
         self.model_name = model_name
-        self.embed_obsm_name = embed_obsm_name
         self.batch_size = batch_size
         self.emb_layer = emb_layer
         self.emb_mode = emb_mode
```
helical/models/geneformer/model.py (9 changes: 4 additions & 5 deletions)

```diff
@@ -103,7 +103,6 @@ def process_data(self,
 
         """
         self.check_data_validity(adata, gene_column_name)
-        self.adata = adata
 
         files_config = {
             "mapping_path": self.config.model_dir / "human_gene_to_ensemble_id.pkl",
@@ -114,7 +113,9 @@
         # map gene symbols to ensemble ids if provided
         if gene_column_name != "ensembl_id":
             mappings = pkl.load(open(files_config["mapping_path"], 'rb'))
-            self.adata.var['ensembl_id'] = self.adata.var[gene_column_name].apply(lambda x: mappings.get(x,{"id":None})['id'])
+            adata.var['ensembl_id'] = adata.var[gene_column_name].apply(lambda x: mappings.get(x,{"id":None})['id'])
+            non_none_mappings = adata.var['ensembl_id'].notnull().sum()
+            LOGGER.info(f"Mapped {non_none_mappings} genes to Ensembl IDs from a total of {adata.var.shape[0]} genes.")
 
         # load token dictionary (Ensembl IDs:token)
         with open(files_config["token_path"], "rb") as f:
@@ -125,7 +126,7 @@
             gene_median_file = files_config["gene_median_path"],
             token_dictionary_file = files_config["token_path"])
 
-        tokenized_cells, cell_metadata = self.tk.tokenize_anndata(self.adata)
+        tokenized_cells, cell_metadata = self.tk.tokenize_anndata(adata)
         tokenized_dataset = self.tk.create_dataset(tokenized_cells, cell_metadata, use_generator=False)
 
         if output_path:
@@ -157,8 +158,6 @@ def get_embeddings(self, dataset: Dataset) -> np.array:
             self.device
         ).cpu().detach().numpy()
 
-        # save the embeddings in the adata object
-        self.adata.obsm[self.config.embed_obsm_name] = embeddings
         return embeddings
```
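The new Geneformer mapping block also reports its hit rate. In isolation the pattern looks like this sketch, with a toy dict standing in for the pickled symbol-to-Ensembl mapping (which stores `{"id": ...}` records, as the diff shows):

```python
import logging

import anndata as ad
import numpy as np
import pandas as pd

logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger(__name__)

# Toy stand-in for the pickled mapping: gene symbol -> {"id": <Ensembl ID>}.
mappings = {"TP53": {"id": "ENSG00000141510"}, "BRCA1": {"id": "ENSG00000012048"}}

adata = ad.AnnData(
    X=np.zeros((2, 3)),
    var=pd.DataFrame({"gene_name": ["TP53", "BRCA1", "GENE_X"]}, index=["g0", "g1", "g2"]),
)

# Unmapped symbols fall back to {"id": None}; notnull() then counts the hits.
adata.var["ensembl_id"] = adata.var["gene_name"].apply(lambda x: mappings.get(x, {"id": None})["id"])
LOGGER.info(f"Mapped {adata.var['ensembl_id'].notnull().sum()} of {adata.var.shape[0]} genes to Ensembl IDs.")
```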
helical/models/scgpt/model.py (13 changes: 6 additions & 7 deletions)

```diff
@@ -142,8 +142,6 @@ def get_embeddings(self, dataset: Dataset) -> np.array:
         # obs_df = adata.obs[obs_to_save] if obs_to_save is not None else None
         # return sc.AnnData(X=cell_embeddings, obs=obs_df, dtype="float32")
 
-        # save the embeddings in the adata object
-        self.adata.obsm[self.config["embed_obsm_name"]] = cell_embeddings
         return cell_embeddings
 
     def process_data(self,
@@ -183,17 +181,18 @@ def process_data(self,
 
         self.check_data_validity(adata, gene_column_name, use_batch_labels)
         self.gene_column_name = gene_column_name
-        self.adata = adata
         if fine_tuning:
             # Preprocess the dataset and select `N_HVG` highly variable genes for downstream analysis.
-            sc.pp.normalize_total(self.adata, target_sum=1e4)
-            sc.pp.log1p(self.adata)
+            sc.pp.normalize_total(adata, target_sum=1e4)
+            sc.pp.log1p(adata)
 
             # highly variable genes
-            sc.pp.highly_variable_genes(self.adata, n_top_genes=n_top_genes, flavor=flavor)
-            self.adata = self.adata[:, self.adata.var['highly_variable']]
+            sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes, flavor=flavor)
+            adata = adata[:, adata.var['highly_variable']]
 
         # filtering
+        adata.var["id_in_vocab"] = [ self.vocab[gene] if gene in self.vocab else -1 for gene in adata.var[self.gene_column_name] ]
+        LOGGER.info(f"Filtering out {np.sum(adata.var['id_in_vocab'] < 0)} genes to a total of {np.sum(adata.var['id_in_vocab'] >= 0)} genes with an id in the scGPT vocabulary.")
         adata = adata[:, adata.var["id_in_vocab"] >= 0]
 
         # Binning will be applied after tokenization. A possible way to do is to use the unified way of binning in the data collator.
```
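The added scGPT filtering follows the same shape: assign each gene a vocabulary id (or -1), log the split, then subset the AnnData. A self-contained sketch with a plain dict standing in for `self.vocab` (the real object is the model's gene-token vocabulary):

```python
import anndata as ad
import numpy as np
import pandas as pd

# Toy stand-in for the scGPT gene vocabulary (gene -> token id).
vocab = {"TP53": 10, "BRCA1": 11}

adata = ad.AnnData(
    X=np.ones((2, 3)),
    var=pd.DataFrame({"gene_name": ["TP53", "GENE_X", "BRCA1"]}, index=["g0", "g1", "g2"]),
)

# Genes absent from the vocabulary get id -1 and are dropped by the subsetting step.
adata.var["id_in_vocab"] = [vocab[g] if g in vocab else -1 for g in adata.var["gene_name"]]
print(f"Filtering out {np.sum(adata.var['id_in_vocab'] < 0)} genes, "
      f"keeping {np.sum(adata.var['id_in_vocab'] >= 0)} with an id in the vocabulary.")
adata = adata[:, adata.var["id_in_vocab"] >= 0]
```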
helical/models/scgpt/scgpt_config.py (4 changes: 0 additions & 4 deletions)

```diff
@@ -8,8 +8,6 @@ class scGPTConfig():
 
     Parameters
     ----------
-    embed_obsm_name : str, optional, default = "X_scgpt"
-        The name of the obsm under which the embeddings will be saved in the AnnData object
     pad_token : str, optional, default = "<pad>"
         The padding token
     batch_size : int, optional, default = 24
@@ -54,7 +52,6 @@ class scGPTConfig():
 
     def __init__(
         self,
-        embed_obsm_name: str = "X_scgpt",
         pad_token: str = "<pad>",
         batch_size: int = 24,
         fast_transformer: bool = True,
@@ -82,7 +79,6 @@ def __init__(
         self.config = {
             "model_path": model_path,
             "list_of_files_to_download": list_of_files_to_download,
-            "embed_obsm_name": embed_obsm_name,
             "pad_token": pad_token,
             "batch_size": batch_size,
             "fast_transformer": fast_transformer,
```
helical/models/uce/model.py (10 changes: 3 additions & 7 deletions)

```diff
@@ -98,10 +98,9 @@ def process_data(self,
         """
 
         self.check_rna_data_validity(adata, gene_column_name)
-        self.adata = adata
 
         if gene_column_name != "index":
-            self.adata.var.index = self.adata.var[gene_column_name]
+            adata.var.index = adata.var[gene_column_name]
 
         files_config = {
             "spec_chrom_csv_path": self.model_dir / "species_chrom.csv",
@@ -110,10 +109,10 @@
         }
 
         if filter_genes_min_cell is not None:
-            sc.pp.filter_genes(self.adata, min_cells=filter_genes_min_cell)
+            sc.pp.filter_genes(adata, min_cells=filter_genes_min_cell)
             # sc.pp.filter_cells(ad, min_genes=25)
         ##Filtering out the Expression Data That we do not have in the protein embeddings
-        filtered_adata, species_to_all_gene_symbols = load_gene_embeddings_adata(adata=self.adata,
+        filtered_adata, species_to_all_gene_symbols = load_gene_embeddings_adata(adata=adata,
                                                                                  species=[species],
                                                                                  embedding_model=embedding_model,
                                                                                  embeddings_path=Path(files_config["protein_embeddings_dir"]))
@@ -211,7 +210,4 @@ def get_embeddings(self, dataset: UCEDataset) -> np.array:
         else:
             dataset_embeds.append(embedding.detach().cpu().numpy())
         embeddings = np.vstack(dataset_embeds)
-
-        # save the embeddings in the adata object
-        self.adata.obsm[self.config["embed_obsm_name"]] = embeddings
         return embeddings
```
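With the obsm write gone, UCE's `get_embeddings` reduces to accumulating per-batch outputs and stacking them; the caller then assigns the result, e.g. `adata.obsm["X_uce"] = embeddings`. The accumulate-and-stack pattern in isolation (shapes are illustrative):

```python
import numpy as np

# Stand-ins for per-batch model outputs, e.g. (batch_size, embedding_dim) arrays
# already moved off the GPU with .detach().cpu().numpy() as in the diff.
batches = [np.random.rand(5, 1280) for _ in range(3)]

dataset_embeds = []
for embedding in batches:
    dataset_embeds.append(embedding)

embeddings = np.vstack(dataset_embeds)  # one row per cell: (15, 1280)
assert embeddings.shape == (15, 1280)
```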
helical/models/uce/uce_config.py (4 changes: 0 additions & 4 deletions)

```diff
@@ -9,8 +9,6 @@ class UCEConfig():
     ----------
     model_name : Literal["33l_8ep_1024t_1280", "4layer_model"], optional, default = "4layer_model"
         The model name
-    embed_obsm_name : str, optional, default = "X_uce"
-        The name of the obsm under which the embeddings will be saved in the AnnData object
     batch_size : int, optional, default = 5
         The batch size
     pad_length : int, optional, default = 1536
@@ -51,7 +49,6 @@
     """
     def __init__(self,
                  model_name: Literal["33l_8ep_1024t_1280", "4layer_model"] = "4layer_model",
-                 embed_obsm_name: str = "X_uce",
                  batch_size: int = 5,
                  pad_length: int = 1536,
                  pad_token_idx: int = 0,
@@ -95,7 +92,6 @@ def __init__(self,
         self.config = {
             "model_path": model_path,
             "list_of_files_to_download": list_of_files_to_download,
-            "embed_obsm_name": embed_obsm_name,
             "batch_size": batch_size,
             "pad_length": pad_length,
             "pad_token_idx": pad_token_idx,
```