Skip to content

Commit

Permalink
Merge pull request #587 from Sichao25/sim
Browse files Browse the repository at this point in the history
Debug the dyn.pl.kinetic_curves()
  • Loading branch information
Xiaojieqiu authored Oct 11, 2023
2 parents cc31f65 + 3f066cb commit 92e3687
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 19 deletions.
3 changes: 2 additions & 1 deletion dynamo/plot/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def kinetic_curves(
import seaborn as sns

if mode == "pseudotime" and tkey == "potential" and "potential" not in adata.obs_keys():
ddhodge(adata)
ddhodge(adata, basis=basis)
tkey = basis + "_ddhodge_potential"

exprs, valid_genes, time = fetch_exprs(
adata,
Expand Down
27 changes: 15 additions & 12 deletions dynamo/prediction/fate.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,24 +162,27 @@ def fate(

elif basis == "umap" and inverse_transform:
# this requires umap 0.4; reverse project to PCA space.
if prediction.ndim == 1:
prediction = prediction[None, :]
exprs = adata.uns["umap_fit"]["fit"].inverse_transform(prediction)
if hasattr(prediction, "ndim"):
if prediction.ndim == 1:
prediction = prediction[None, :]

# further reverse project back to raw expression space
umap_fit = adata.uns["umap_fit"]["fit"]
PCs = adata.uns["PCs"].T
if PCs.shape[0] == exprs.shape[1]:
exprs = np.expm1(exprs @ PCs + adata.uns["pca_mean"])

ndim = adata.uns["umap_fit"]["fit"]._raw_data.shape[1]
exprs = []

if "X" in adata.obsm_keys():
if ndim == adata.obsm[DKM.X_PCA].shape[1]: # lift the dimension up again
exprs = adata.uns["pca_fit"].inverse_transform(prediction)
for cur_pred in prediction:
expr = umap_fit.inverse_transform(cur_pred.T)

if adata.var.use_for_dynamics.sum() == exprs.shape[1]:
# further reverse project back to raw expression space
if PCs.shape[0] == expr.shape[1]:
expr = np.expm1(expr @ PCs + adata.uns["pca_mean"])

exprs.append(expr)

if adata.var.use_for_dynamics.sum() == exprs[0].shape[1]:
valid_genes = adata.var_names[adata.var.use_for_dynamics]
elif adata.var.use_for_transition.sum() == exprs.shape[1]:
elif adata.var.use_for_transition.sum() == exprs[0].shape[1]:
valid_genes = adata.var_names[adata.var.use_for_transition]
else:
raise Exception(
Expand Down
12 changes: 6 additions & 6 deletions dynamo/prediction/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ def fetch_exprs(adata, basis, layer, genes, time, mode, project_back_to_high_dim
time = time[traj_ind]

if mode.lower() not in ["vector_field", "lap"]:
valid_genes = list(set(genes).intersection(adata.var.index))
valid_genes = list(sorted(set(genes).intersection(adata.var.index), key=genes.index))

if layer == "X":
exprs = adata[np.isfinite(time), :][:, valid_genes].X
Expand All @@ -536,27 +536,27 @@ def fetch_exprs(adata, basis, layer, genes, time, mode, project_back_to_high_dim
raise Exception(f"The {layer} you passed in is not existed in the adata object.")
else:
fate_genes = adata.uns[traj_key]["genes"]
valid_genes = list(set(genes).intersection(fate_genes))
valid_genes = list(sorted(set(genes).intersection(fate_genes), key=genes.index))

if basis is not None:
if project_back_to_high_dim:
exprs = adata.uns[traj_key]["exprs"]
if type(exprs) == list:
exprs = exprs[traj_ind]
exprs = exprs[np.isfinite(time), :][:, pd.Series(fate_genes).isin(valid_genes)]
exprs = exprs[np.isfinite(time), :][:, fate_genes.get_indexer(valid_genes)]
else:
exprs = adata.uns[traj_key]["prediction"]
if type(exprs) == list:
exprs = exprs[traj_ind]
exprs = exprs[np.isfinite(time), :]
exprs = exprs.T[np.isfinite(time), :]
valid_genes = [basis + "_" + str(i) for i in np.arange(exprs.shape[1])]
else:
exprs = adata.uns[traj_key]["prediction"]
if type(exprs) == list:
exprs = exprs[traj_ind]
exprs = exprs[np.isfinite(time), pd.Series(fate_genes).isin(valid_genes)]
exprs = exprs.T[np.isfinite(time), adata.var.index.get_indexer(valid_genes)]

time = time[np.isfinite(time)]
time = np.array(time)[np.isfinite(time)]

return exprs, valid_genes, time

Expand Down

0 comments on commit 92e3687

Please sign in to comment.