trVAE training on query data hangs indefinitely at "Instantiating Dataset" #258

Open
li-xuyang28 opened this issue Nov 11, 2024 · 0 comments

li-xuyang28 commented Nov 11, 2024

Dear developers/maintainers,

I used trVAE a while ago, before it became part of scArches, and I am now trying to use the scArches implementation to annotate a new unlabeled dataset. Training on the reference/source data (a bit over 2M cells) succeeded: it took a bit more than 13 h, plateaued, and stopped after 226 iterations. However, when I train on the query dataset (~230k cells), it hangs at "Instantiating Dataset" indefinitely (>24 h). Could you kindly advise?

The code up to the training on query set:

########### Loading packages #################
import scanpy as sc
import torch
import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
import matplotlib.pyplot as plt
import numpy as np
import gdown
import pickle
print("Loaded packages...")
import warnings
warnings.filterwarnings('ignore')
import psutil
mem = psutil.virtual_memory()
print(f"Available memory: {mem.available / 1024**3:.2f} GB")

########### Train on reference ###############
print("Read data...")
ref = sc.read_h5ad(REF_DATA_PATH)
early_stopping_kwargs = {
    "early_stopping_metric": "val_unweighted_loss",
    "threshold": 0,
    "patience": 20,
    "reduce_lr": True,
    "lr_patience": 13,
    "lr_factor": 0.1,
}

condition_key = 'donor_id'
source_conditions = ref.obs[condition_key].unique().tolist()
cell_type_key = 'supercluster_term'

trvae = sca.models.TRVAE(
    adata=ref,
    condition_key=condition_key,
    conditions=source_conditions,
    hidden_layer_sizes=[128, 128, 128],
    recon_loss="zinb",
)

trvae.train(
    n_epochs=500,
    alpha_epoch_anneal=200,
    early_stopping_kwargs=early_stopping_kwargs
)
trvae.save(REF_MODEL_PATH, overwrite=True)

########### Train on query ###############
adata = sc.read_h5ad(QUERY_DATA_PATH)

new_trvae = sca.models.TRVAE.load_query_data(adata=adata, reference_model=REF_MODEL_PATH)
new_trvae.train(
    n_epochs=500,
    alpha_epoch_anneal=200,
    early_stopping_kwargs=early_stopping_kwargs,
    weight_decay=0
)
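
To see where the query training call is actually blocked, one option is Python's standard-library faulthandler module, which can dump the stack of every thread at a fixed interval while the call runs. The following is a minimal sketch, not part of scArches: the helper name watch_for_hang and the 600-second default interval are illustrative assumptions.

```python
import faulthandler
import sys
from contextlib import contextmanager


@contextmanager
def watch_for_hang(timeout_s=600, file=sys.stderr):
    """Dump all thread stacks every `timeout_s` seconds while the block runs.

    `file` must be a real file object with a file descriptor
    (for example sys.stderr or an opened log file).
    """
    # Arm a watchdog that writes tracebacks repeatedly until cancelled.
    faulthandler.dump_traceback_later(timeout_s, repeat=True, file=file)
    try:
        yield
    finally:
        # Always disarm the watchdog, even if the block raises.
        faulthandler.cancel_dump_traceback_later()


# Hypothetical usage around the call from the script above:
# with watch_for_hang(timeout_s=600):
#     new_trvae.train(n_epochs=500, alpha_epoch_anneal=200,
#                     early_stopping_kwargs=early_stopping_kwargs,
#                     weight_decay=0)
```

If the same stack frame shows up in every dump, that frame is the likely place where "Instantiating Dataset" is spending its time.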

Session info:

gdown 5.2.0
matplotlib 3.9.2
numpy 1.26.1
psutil 6.1.0
scanpy 1.9.6
scarches 0.6.1
session_info 1.0.0
torch 2.5.1+cu124
----
Python 3.9.20 (main, Oct 3 2024, 07:27:41) [GCC 11.2.0]
Linux-5.14.0-427.42.1.el9_4.x86_64-with-glibc2.34

I checked that memory usage stayed well below 50% (the node has 1500 GB of memory; gres = 6x Nvidia Tesla T4). I did use MMD and am aware that it slows the process down considerably, but I would expect at least the progress bar to show up. Any help would be greatly appreciated!
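
As a rough cross-check on the memory observation, the footprint of the query matrix if it were held densely can be estimated from its shape alone. This is a back-of-envelope sketch: the function name dense_size_gib is hypothetical, float32 storage is an assumption, and the gene count used in the example is a placeholder because the issue does not state it.

```python
def dense_size_gib(n_cells, n_genes, bytes_per_value=4):
    """Return the size in GiB of a dense n_cells x n_genes matrix.

    bytes_per_value=4 assumes float32 storage (an assumption, not
    something confirmed for this pipeline).
    """
    return n_cells * n_genes * bytes_per_value / 1024**3


# ~230k query cells as reported above; 2,000 genes is a placeholder value.
query_gib = dense_size_gib(230_000, 2_000)
print(f"Estimated dense query matrix: {query_gib:.2f} GiB")
```

For an actual dataset, the two shape values can be read from adata.shape after loading the query file.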
