Skip to content

Commit

Permalink
Merge pull request #580 from Sichao25/debug_l
Browse files Browse the repository at this point in the history
Debug h5ad saving
  • Loading branch information
Xiaojieqiu authored Nov 16, 2023
2 parents 773a43c + 8f3cc0f commit 8456e3b
Show file tree
Hide file tree
Showing 15 changed files with 297 additions and 60 deletions.
68 changes: 68 additions & 0 deletions dynamo/data_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
from functools import reduce

import pandas as pd
from anndata import (
AnnData,
read,
Expand All @@ -19,6 +20,7 @@
from tqdm import tqdm

from .dynamo_logger import main_info
from .tools.Markov import KernelMarkovChain


def make_dir(path: str, can_exist=True):
Expand Down Expand Up @@ -334,3 +336,69 @@ def export_rank_xlsx(adata, path="rank_info.xlsx", ext="excel", rank_prefix="ran
if key[: len(rank_prefix)] == rank_prefix:
main_info("saving sheet: " + str(key))
adata.uns[key].to_excel(writer, sheet_name=str(key))


def export_kmc(adata: AnnData) -> None:
    """Swap the kmc object stored in ``adata.uns`` for a dict of its parameters.

    The listed attributes of ``adata.uns["kmc"]`` are copied into a new
    ``adata.uns["kmc_params"]`` dict, then the ``"kmc"`` entry is removed.

    Args:
        adata: object whose ``uns`` mapping contains a ``"kmc"`` entry.
    """
    markov_chain = adata.uns["kmc"]
    # Each attribute name is reused verbatim as the key in the saved dict.
    saved_fields = ("P", "Idx", "eignum", "D", "U", "W", "W_inv", "Kd")
    adata.uns["kmc_params"] = {
        field: getattr(markov_chain, field) for field in saved_fields
    }
    adata.uns.pop("kmc")


def import_kmc(adata: AnnData) -> None:
    """Rebuild the kmc object from ``adata.uns["kmc_params"]``.

    A ``KernelMarkovChain`` is created from the saved ``P`` and ``Idx``, the
    remaining saved values are assigned as attributes, the object is stored
    under ``adata.uns["kmc"]`` and the ``"kmc_params"`` entry is removed.

    Args:
        adata: object whose ``uns`` mapping contains a ``"kmc_params"`` entry.
    """
    params = adata.uns["kmc_params"]
    restored = KernelMarkovChain(P=params["P"], Idx=params["Idx"])
    # These values are not constructor arguments here; set them afterwards.
    for attr_name in ("eignum", "D", "U", "W", "W_inv", "Kd"):
        setattr(restored, attr_name, params[attr_name])
    adata.uns["kmc"] = restored
    adata.uns.pop("kmc_params")


def export_h5ad(adata: AnnData, path: str = "data/processed_data.h5ad") -> None:
    """Export the anndata object to h5ad.

    Before writing, ``adata.uns`` is rewritten in place:

    * a ``"kmc"`` entry is replaced by its parameters via ``export_kmc``;
    * for every key starting with ``"fate"``, the ``"prediction"`` and ``"t"``
      sequences are converted to dicts keyed by the stringified position.

    ``adata`` is not restored after writing. Values that are already dicts
    are left untouched, so calling this function more than once on the same
    object does not corrupt the stored arrays.

    Args:
        adata: the object to export.
        path: destination passed to ``adata.write_h5ad``.
    """
    if "kmc" in adata.uns.keys():
        export_kmc(adata)

    fate_keys = [key for key in adata.uns_keys() if key.startswith("fate")]
    for key in fate_keys:
        fate_info = adata.uns[key]
        for field in ("prediction", "t"):
            if field not in fate_info.keys():
                continue
            values = fate_info[field]
            # Already converted by an earlier export: enumerating a dict
            # would replace every array with its own key string.
            if isinstance(values, dict):
                continue
            fate_info[field] = {str(index): array for index, array in enumerate(values)}

    adata.write_h5ad(path)


def import_h5ad(path: str = "data/processed_data.h5ad") -> AnnData:
    """Import a Dynamo h5ad object into anndata.

    After reading the file, ``adata.uns`` is rewritten in place:

    * a ``"kmc_params"`` entry is turned back into a kmc object via
      ``import_kmc``;
    * for every key starting with ``"fate"``, ``"prediction"`` and ``"t"``
      dicts are converted back to lists.

    Args:
        path: file passed to ``read_h5ad``.

    Returns:
        The loaded object with the entries above restored.
    """

    def _as_ordered_list(mapping):
        # Keys are stringified positions ("0", "1", ...). The mapping may
        # iterate them in name order ("0", "1", "10", "2", ...), so sort them
        # numerically to recover the original sequence order.
        keys = list(mapping)
        try:
            keys = sorted(keys, key=int)
        except (TypeError, ValueError):
            # Non-numeric keys: fall back to the order as read.
            pass
        return [mapping[key] for key in keys]

    adata = read_h5ad(path)
    if "kmc_params" in adata.uns.keys():
        import_kmc(adata)

    fate_keys = [key for key in adata.uns_keys() if key.startswith("fate")]
    for key in fate_keys:
        fate_info = adata.uns[key]
        for field in ("prediction", "t"):
            # Only dicts need converting back; anything else is left as stored.
            if field in fate_info.keys() and isinstance(fate_info[field], dict):
                fate_info[field] = _as_ordered_list(fate_info[field])

    return adata

7 changes: 5 additions & 2 deletions dynamo/plot/topography.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def plot_nullclines(
NCx, NCy = None, None

# if nullcline is not previously calculated, calculate and plot them
if vecfld_dict is None or "nullcline" not in vecfld_dict.keys():
if vecfld_dict is None or "NCx" not in vecfld_dict.keys() or "NCy" not in vecfld_dict.keys():
if vecfld_dict is not None:
X_basis = vecfld_dict["X"][:, :2]
min_, max_ = X_basis.min(0), X_basis.max(0)
Expand All @@ -268,7 +268,10 @@ def plot_nullclines(

NCx, NCy = vecfld2d.NCx, vecfld.NCy
else:
NCx, NCy = vecfld_dict["nullcline"][0], vecfld_dict["nullcline"][1]
NCx, NCy = (
[vecfld_dict["NCx"][index] for index in vecfld_dict["NCx"]],
[vecfld_dict["NCy"][index] for index in vecfld_dict["NCy"]],
)

if ax is None:
ax = plt.gca()
Expand Down
25 changes: 20 additions & 5 deletions dynamo/prediction/fate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
main_info_insert_adata,
main_warning,
)
from ..tools.connectivity import correct_hnsw_neighbors, k_nearest_neighbors
from ..utils import pca_to_expr
from ..tools.connectivity import construct_mapper_umap, correct_hnsw_neighbors, k_nearest_neighbors
from ..tools.utils import fetch_states, getTseq
from ..vectorfield import vector_field_function
from ..vectorfield.utils import vecfld_from_adata, vector_transformation
Expand Down Expand Up @@ -167,9 +168,23 @@ def fate(
if prediction.ndim == 1:
prediction = prediction[None, :]

umap_fit = adata.uns["umap_fit"]["fit"]
PCs = adata.uns["PCs"].T
params = adata.uns["umap_fit"]
umap_fit = construct_mapper_umap(
params["X_data"],
n_components=params["umap_kwargs"]["n_components"],
metric=params["umap_kwargs"]["metric"],
min_dist=params["umap_kwargs"]["min_dist"],
spread=params["umap_kwargs"]["spread"],
max_iter=params["umap_kwargs"]["max_iter"],
alpha=params["umap_kwargs"]["alpha"],
gamma=params["umap_kwargs"]["gamma"],
negative_sample_rate=params["umap_kwargs"]["negative_sample_rate"],
init_pos=params["umap_kwargs"]["init_pos"],
random_state=params["umap_kwargs"]["random_state"],
umap_kwargs=params["umap_kwargs"],
)

PCs = adata.uns["PCs"].T
exprs = []

for cur_pred in prediction:
Expand All @@ -193,12 +208,12 @@ def fate(

adata.uns[fate_key] = {
"init_states": init_states,
"init_cells": init_cells,
"init_cells": list(init_cells),
"average": average,
"t": t,
"prediction": prediction,
# "VecFld": VecFld,
"VecFld_true": VecFld_true,
# "VecFld_true": VecFld_true,
"genes": valid_genes,
}
if exprs is not None:
Expand Down
5 changes: 3 additions & 2 deletions dynamo/prediction/least_action_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@

from ..dynamo_logger import LoggerManager
from ..tools.utils import fetch_states, nearest_neighbors
from ..utils import pca_to_expr
from ..vectorfield import SvcVectorField
from ..vectorfield.utils import (
vecfld_from_adata,
vector_field_function_transformation,
vector_transformation,
)
from .trajectory import GeneTrajectory, Trajectory
from .utils import arclength_sampling_n, find_elbow, pca_to_expr
from .utils import arclength_sampling_n, find_elbow


class LeastActionPath(Trajectory):
Expand Down Expand Up @@ -601,7 +602,7 @@ def least_action(

adata.uns[LAP_key] = {
"init_states": init_states,
"init_cells": init_cells,
"init_cells": list(init_cells),
"t": t,
"mftp": mftp,
"prediction": prediction,
Expand Down
3 changes: 2 additions & 1 deletion dynamo/prediction/perturbation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from ..dynamo_logger import LoggerManager
from ..tools.cell_velocities import cell_velocities
from ..utils import expr_to_pca, pca_to_expr
from ..vectorfield import SvcVectorField
from ..vectorfield.scVectorField import KOVectorField, vector_field_function_knockout
from ..vectorfield.vector_calculus import (
Expand All @@ -15,7 +16,7 @@
vector_transformation,
)
from ..vectorfield.rank_vf import rank_cell_groups, rank_cells, rank_genes
from .utils import expr_to_pca, pca_to_expr, z_score, z_score_inv
from .utils import z_score, z_score_inv


def KO(
Expand Down
3 changes: 2 additions & 1 deletion dynamo/prediction/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

from ..dynamo_logger import LoggerManager
from ..tools.utils import flatten
from ..utils import expr_to_pca, pca_to_expr
from ..vectorfield.scVectorField import DifferentiableVectorField
from ..vectorfield.utils import angle, normalize_vectors
from .utils import arclength_sampling_n, expr_to_pca, pca_to_expr
from .utils import arclength_sampling_n


class Trajectory:
Expand Down
24 changes: 0 additions & 24 deletions dynamo/prediction/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,30 +495,6 @@ def arclength_sampling_n(X, num, t=None):
return X_, arclen[-1]


# ---------------------------------------------------------------------------------------------------
# trajectory related
def pca_to_expr(X, PCs, mean=0, func=None):
    """Project coordinates from PCA space back to expression space.

    Computes ``X @ PCs.T + mean`` and, when ``func`` is given, returns
    ``func`` applied to that result.

    Args:
        X: array whose second dimension must equal ``PCs.shape[1]``.
        PCs: loading matrix.
        mean: value added after the projection.
        func: optional callable applied to the projected result.

    Returns:
        The projected (and optionally transformed) array.

    Raises:
        Exception: if ``PCs.shape[1]`` differs from ``X.shape[1]``.
    """
    n_pcs, n_dims = PCs.shape[1], X.shape[1]
    if n_pcs != n_dims:
        raise Exception("PCs dim 1 (%d) does not match X dim 1 (%d)." % (n_pcs, n_dims))

    reconstructed = X @ PCs.T + mean
    return reconstructed if func is None else func(reconstructed)


def expr_to_pca(expr, PCs, mean=0, func=None):
    """Project data from expression space into PCA space.

    Computes ``(expr - mean) @ PCs`` and, when ``func`` is given, returns
    ``func`` applied to that result.

    Args:
        expr: array whose second dimension must equal ``PCs.shape[0]``.
        PCs: loading matrix.
        mean: value subtracted from ``expr`` before the projection.
        func: optional callable applied to the projected result.

    Returns:
        The projected (and optionally transformed) array.

    Raises:
        Exception: if ``PCs.shape[0]`` differs from ``expr.shape[1]``.
    """
    n_rows, n_features = PCs.shape[0], expr.shape[1]
    if n_rows != n_features:
        # Name the dimensions that are actually compared (PCs axis 0 against
        # expr axis 1) so the message matches the reported numbers.
        raise Exception("PCs dim 0 (%d) does not match expr dim 1 (%d)." % (n_rows, n_features))

    X = (expr - mean) @ PCs
    if func is not None:
        X = func(X)
    return X


# ---------------------------------------------------------------------------------------------------
# fate related
def fetch_exprs(adata, basis, layer, genes, time, mode, project_back_to_high_dim, traj_ind):
Expand Down
3 changes: 2 additions & 1 deletion dynamo/preprocessing/cell_cycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,8 @@ def get_cell_phase(
else:
cell_phase_genes = gene_list

adata.uns["cell_phase_genes"] = cell_phase_genes
adata.uns["cell_phase_order"] = [key for key in cell_phase_genes]
adata.uns["cell_phase_genes"] = dict(cell_phase_genes)
# score each cell cycle phase and Z-normalize
phase_scores = pd.DataFrame(batch_group_score(adata, layer, cell_phase_genes))
normalized_phase_scores = phase_scores.sub(phase_scores.mean(axis=1), axis=0).div(phase_scores.std(axis=1), axis=0)
Expand Down
32 changes: 21 additions & 11 deletions dynamo/tools/cell_velocities.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

from ..configuration import DKM
from ..dynamo_logger import LoggerManager, main_info, main_warning
from ..utils import areinstance
from .connectivity import _gen_neighbor_keys, adj_to_knn, check_and_recompute_neighbors
from ..utils import areinstance, expr_to_pca
from .connectivity import _gen_neighbor_keys, adj_to_knn, check_and_recompute_neighbors, construct_mapper_umap
from .dimension_reduction import reduceDimension
from .graph_calculus import calc_gaussian_weight, fp_operator, graphize_velocity
from .Markov import ContinuousTimeMarkovChain, KernelMarkovChain, velocity_on_grid
Expand Down Expand Up @@ -519,23 +519,33 @@ def cell_velocities(
adata.obsp["discrete_vector_field"] = E

elif method == "transform":
umap_trans, n_pca_components = (
adata.uns["umap_fit"]["fit"],
adata.uns["umap_fit"]["n_pca_components"],
params = adata.uns["umap_fit"]
umap_trans = construct_mapper_umap(
params["X_data"],
n_components=params["umap_kwargs"]["n_components"],
metric=params["umap_kwargs"]["metric"],
min_dist=params["umap_kwargs"]["min_dist"],
spread=params["umap_kwargs"]["spread"],
max_iter=params["umap_kwargs"]["max_iter"],
alpha=params["umap_kwargs"]["alpha"],
gamma=params["umap_kwargs"]["gamma"],
negative_sample_rate=params["umap_kwargs"]["negative_sample_rate"],
init_pos=params["umap_kwargs"]["init_pos"],
random_state=params["umap_kwargs"]["random_state"],
umap_kwargs=params["umap_kwargs"],
)

if "pca_fit" not in adata.uns_keys() or type(adata.uns["pca_fit"]) == str:
CM = adata.X[:, adata.var.use_for_dynamics.values]
CM = adata.X[:, adata.var.use_for_dynamics.values]
if "PCs" not in adata.uns_keys():
from ..preprocessing.pca import pca

adata, pca_fit, X_pca = pca(adata, CM, n_pca_components, "X", return_all=True)
adata.uns["pca_fit"] = pca_fit
adata, pca_fit, X_pca = pca(adata, CM, params["n_pca_components"], "X", return_all=True)

X_pca, pca_fit = adata.obsm[DKM.X_PCA], adata.uns["pca_fit"]
X_pca, pca_PCs = adata.obsm[DKM.X_PCA], adata.uns["PCs"]
V = adata[:, adata.var.use_for_dynamics.values].layers[vkey] if vkey in adata.layers.keys() else None
CM, V = CM.A if sp.issparse(CM) else CM, V.A if sp.issparse(V) else V
V[np.isnan(V)] = 0
Y_pca = pca_fit.transform(CM + V)
Y_pca = expr_to_pca(CM + V, PCs=pca_PCs, mean=(CM + V).mean(0))

Y = umap_trans.transform(Y_pca)

Expand Down
Loading

0 comments on commit 8456e3b

Please sign in to comment.