diff --git a/scatrex/ntssb/node.py b/scatrex/ntssb/node.py index 82d5a48..0b8e2b3 100644 --- a/scatrex/ntssb/node.py +++ b/scatrex/ntssb/node.py @@ -3,6 +3,8 @@ import jax.numpy as jnp from abc import ABC, abstractmethod +from functools import reduce + class AbstractNode(ABC): def __init__( diff --git a/scatrex/ntssb/ntssb.py b/scatrex/ntssb/ntssb.py index 1e1e682..a7af1c3 100644 --- a/scatrex/ntssb/ntssb.py +++ b/scatrex/ntssb/ntssb.py @@ -67,7 +67,7 @@ def __init__( max_depth=15, fixed_weights_pivot_sampling=True, use_weights=True, - weights_concentration=10., + weights_variance=1e-3, min_weight=1e-6, verbosity=logging.INFO, node_hyperparams=dict(), @@ -117,7 +117,7 @@ def __init__( logger.setLevel(verbosity) - self.reset_tree(use_weights=use_weights, weights_concentration=weights_concentration, min_weight=min_weight) + self.reset_tree(use_weights=use_weights, weights_variance=weights_variance, min_weight=min_weight) self.set_pivot_priors() @@ -126,7 +126,7 @@ def __init__( } # ========= Functions to initialize tree. ========= - def reset_tree(self, use_weights=False, weights_concentration=10., min_weight=1e-6): + def reset_tree(self, use_weights=False, weights_variance=1e-3, min_weight=1e-6): if use_weights and "weight" not in self.input_tree_dict["A"].keys(): raise KeyError("No weights were specified in the input tree.") @@ -157,8 +157,8 @@ def descend(input_root, idx=1, depth=0): stick = stick / sum else: stick = 1.0 - psi_prior["alpha_psi"] = stick * (weights_concentration - 2) + 1 - psi_prior["beta_psi"] = (1-stick) * (weights_concentration -2) + 1 + psi_prior["alpha_psi"] = stick * weights_variance + psi_prior["beta_psi"] = (1-stick) * weights_variance psi_priors.append(psi_prior) sticks.append(stick) if len(sticks) == 0: @@ -209,8 +209,8 @@ def descend(input_root, idx=1, depth=0): main = 1.0 # stop at leaf node if use_weights: - alpha_nu = main * (weights_concentration - 2) + 1 - beta_nu = (1-main) * (weights_concentration - 2) + 1 + alpha_nu = main * weights_variance + beta_nu = (1-main) * weights_variance root_dict = { "node": tssb, @@ -771,7 +771,7 @@ def descend(root): descend(self.root) def get_tree_data_sizes(self, normalized=False): - trees = self.get_nodes() + trees = self.get_trees() sizes = [] for tree in trees: @@ -1644,6 +1644,14 @@ def descend(root): descend(child) descend(self.root) + def remove_pivots(self): + def descend(root): + for i, child in enumerate(root['children']): + child['pivot_node'] = None + child['node'].root['node'].set_parent(None) + descend(child) + descend(self.root) + # ========= Functions to update tree structure. ========= def prune_subtrees(self): @@ -2485,7 +2493,7 @@ def get_subtree_obs(self): return subtrees, obs def initialize_gene_node_colormaps( - self, node_obs=None, node_avg_exp=None, gene_specific=False + self, node_obs=None, node_avg_exp=None, gene_specific=False, vmin=-0.5, vmax=0.5, ): nodes, vals = self.get_node_unobs() vals = np.array(vals) @@ -2504,7 +2512,7 @@ def initialize_gene_node_colormaps( else: global_min, global_max = np.nanmin(vals), np.nanmax(vals) cmap = self.exp_cmap - norm = matplotlib.colors.Normalize(vmin=-1.0, vmax=1.0) + norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax) mapper = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap) self.gene_node_colormaps["unobserved"] = dict() self.gene_node_colormaps["unobserved"]["vals"] = dict( diff --git a/scatrex/scatrex.py b/scatrex/scatrex.py index 4346533..df48e35 100644 --- a/scatrex/scatrex.py +++ b/scatrex/scatrex.py @@ -1540,45 +1540,8 @@ def compute_pathway_enrichments( enrichments.append(enr) self.enrichments = dict(zip([node for node in gene_rankings], enrichments)) - def compute_pivot_likelihoods(self, clone="B", normalized=True): - """ - For the given clone, compute the tree ELBO for each possible pivot and - return a dictionary of pivots and their ELBOs. - """ - if clone == "A": - raise ValueError( - "The root clone was selected, which by definition \ - does not have parent nodes. Please select a non-root clone." - ) - - tssbs = ntssb.get_subtrees() - labels = [tssb.label for tssb in tssbs] - tssb = subtrees[np.where(np.array(labels) == clone)[0]] - - parent_tssb = tssb.root["node"].parent().tssb - possible_pivots = parent_tssb.get_nodes() - - pivot_likelihoods = dict() - for pivot in possible_pivots: - if len(possible_pivots) == 1: - possible_pivots[pivot] = self.ntssb.elbo - logger.warning(f"Clone {clone} has only one possible parent node.") - break - ntssb = deepcopy(self.ntssb) - ntssb.pivot_reattach_to(clone, pivot.label) - ntssb.optimize_elbo() - pivot_likelihoods[pivot.label] = ntssb.elbo - - if normalize: - labels = list(pivot_likelihoods.get_keys()) - vals = list(pivot_likelihoods.get_values()) - vals = np.array(vals) / np.sum(vals) - pivot_likelihoods = dict(zip(labels, vals.tolist())) - - return pivot_likelihoods - - def get_cnv_exp(self, max_level=4, method="scatrex"): - cnv_levels = np.unique(self.observed_tree.adata.X) + def get_cnv_exp(self, cnv_levels=[1,2,3,4], max_level=4, method="scatrex"): + cnv_levels = np.unique(cnv_levels) exp_levels = [] for cnv in cnv_levels: gene_avg = [] @@ -2053,8 +2016,8 @@ def plot_proportions(self, dna=True, rna=True, remove_empty_nodes=True, show=Tru dna_props = dna_props[s] if rna: - rna_nodes, rna_props = self.ntssb.get_node_data_sizes( - normalized=True, super_only=True + rna_nodes, rna_props = self.ntssb.get_tree_data_sizes( + normalized=True, ) nodes_labels = [node.label for node in rna_nodes] s = np.argsort(np.array(nodes_labels))