Skip to content

Commit

Permalink
Use variance for input weights
Browse files Browse the repository at this point in the history
  • Loading branch information
pedrofale committed Apr 28, 2024
1 parent a1737f7 commit 5a0bdcc
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 51 deletions.
2 changes: 2 additions & 0 deletions scatrex/ntssb/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import jax.numpy as jnp

from abc import ABC, abstractmethod
from functools import reduce


class AbstractNode(ABC):
def __init__(
Expand Down
28 changes: 18 additions & 10 deletions scatrex/ntssb/ntssb.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(
max_depth=15,
fixed_weights_pivot_sampling=True,
use_weights=True,
weights_concentration=10.,
weights_variance=1e-3,
min_weight=1e-6,
verbosity=logging.INFO,
node_hyperparams=dict(),
Expand Down Expand Up @@ -117,7 +117,7 @@ def __init__(

logger.setLevel(verbosity)

self.reset_tree(use_weights=use_weights, weights_concentration=weights_concentration, min_weight=min_weight)
self.reset_tree(use_weights=use_weights, weights_variance=weights_variance, min_weight=min_weight)

self.set_pivot_priors()

Expand All @@ -126,7 +126,7 @@ def __init__(
}

# ========= Functions to initialize tree. =========
def reset_tree(self, use_weights=False, weights_concentration=10., min_weight=1e-6):
def reset_tree(self, use_weights=False, weights_variance=1e-3, min_weight=1e-6):
if use_weights and "weight" not in self.input_tree_dict["A"].keys():
raise KeyError("No weights were specified in the input tree.")

Expand Down Expand Up @@ -157,8 +157,8 @@ def descend(input_root, idx=1, depth=0):
stick = stick / sum
else:
stick = 1.0
psi_prior["alpha_psi"] = stick * (weights_concentration - 2) + 1
psi_prior["beta_psi"] = (1-stick) * (weights_concentration -2) + 1
psi_prior["alpha_psi"] = stick * weights_variance
psi_prior["beta_psi"] = (1-stick) * weights_variance
psi_priors.append(psi_prior)
sticks.append(stick)
if len(sticks) == 0:
Expand Down Expand Up @@ -209,8 +209,8 @@ def descend(input_root, idx=1, depth=0):
main = 1.0 # stop at leaf node

if use_weights:
alpha_nu = main * (weights_concentration - 2) + 1
beta_nu = (1-main) * (weights_concentration - 2) + 1
alpha_nu = main * weights_variance
beta_nu = (1-main) * weights_variance

root_dict = {
"node": tssb,
Expand Down Expand Up @@ -771,7 +771,7 @@ def descend(root):
descend(self.root)

def get_tree_data_sizes(self, normalized=False):
trees = self.get_nodes()
trees = self.get_trees()
sizes = []

for tree in trees:
Expand Down Expand Up @@ -1644,6 +1644,14 @@ def descend(root):
descend(child)
descend(self.root)

def remove_pivots(self):
    """Detach every non-root subtree from its pivot.

    Walks the tree dictionary rooted at ``self.root`` depth-first. For each
    child entry it clears the ``'pivot_node'`` slot and calls
    ``set_parent(None)`` on the root node of the child's subtree. The root
    entry itself is left untouched.
    """

    def _detach(entry):
        # The positional index of the child is never needed, so iterate the
        # children directly instead of through enumerate().
        for child in entry['children']:
            child['pivot_node'] = None
            child['node'].root['node'].set_parent(None)
            _detach(child)

    _detach(self.root)

# ========= Functions to update tree structure. =========

def prune_subtrees(self):
Expand Down Expand Up @@ -2485,7 +2493,7 @@ def get_subtree_obs(self):
return subtrees, obs

def initialize_gene_node_colormaps(
self, node_obs=None, node_avg_exp=None, gene_specific=False
self, node_obs=None, node_avg_exp=None, gene_specific=False, vmin=-0.5, vmax=0.5,
):
nodes, vals = self.get_node_unobs()
vals = np.array(vals)
Expand All @@ -2504,7 +2512,7 @@ def initialize_gene_node_colormaps(
else:
global_min, global_max = np.nanmin(vals), np.nanmax(vals)
cmap = self.exp_cmap
norm = matplotlib.colors.Normalize(vmin=-1.0, vmax=1.0)
norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
mapper = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)
self.gene_node_colormaps["unobserved"] = dict()
self.gene_node_colormaps["unobserved"]["vals"] = dict(
Expand Down
45 changes: 4 additions & 41 deletions scatrex/scatrex.py
Original file line number Diff line number Diff line change
Expand Up @@ -1540,45 +1540,8 @@ def compute_pathway_enrichments(
enrichments.append(enr)
self.enrichments = dict(zip([node for node in gene_rankings], enrichments))

def compute_pivot_likelihoods(self, clone="B", normalized=True):
    """
    For the given clone, compute the tree ELBO for each possible pivot and
    return a dictionary of pivots and their ELBOs.

    Parameters
    ----------
    clone : label of the subtree whose pivot is being evaluated. The root
        clone ``"A"`` is rejected because it has no parent subtree.
    normalized : if True, every value is divided by the sum of all values
        before being returned.
        NOTE(review): ELBOs are typically log-scale and may be negative, so
        dividing by their sum does not yield probabilities -- confirm that
        this normalisation is what callers expect.

    Returns
    -------
    dict mapping each candidate pivot's label to its (optionally
    normalized) ELBO.

    Raises
    ------
    ValueError
        If ``clone`` is the root clone ``"A"`` or no subtree carries that
        label.
    """
    if clone == "A":
        raise ValueError(
            "The root clone was selected, which by definition "
            "does not have parent nodes. Please select a non-root clone."
        )

    # Look the subtree up on the fitted tree held by this object; the local
    # name `ntssb` is reserved for the per-pivot working copies below.
    tssbs = self.ntssb.get_subtrees()
    matches = [candidate for candidate in tssbs if candidate.label == clone]
    if not matches:
        raise ValueError(f"Clone {clone} was not found in the tree.")
    tssb = matches[0]

    # Candidate pivots are all nodes of the subtree that owns this clone's
    # parent node.
    parent_tssb = tssb.root["node"].parent().tssb
    possible_pivots = parent_tssb.get_nodes()

    pivot_likelihoods = dict()
    if len(possible_pivots) == 1:
        # Nothing to compare: reuse the ELBO of the current tree as-is.
        pivot_likelihoods[possible_pivots[0].label] = self.ntssb.elbo
        logger.warning(f"Clone {clone} has only one possible parent node.")
    else:
        for pivot in possible_pivots:
            # Work on a copy so the fitted tree is never modified.
            ntssb = deepcopy(self.ntssb)
            ntssb.pivot_reattach_to(clone, pivot.label)
            ntssb.optimize_elbo()
            pivot_likelihoods[pivot.label] = ntssb.elbo

    if normalized:
        labels = list(pivot_likelihoods.keys())
        vals = np.array(list(pivot_likelihoods.values()))
        vals = vals / np.sum(vals)
        pivot_likelihoods = dict(zip(labels, vals.tolist()))

    return pivot_likelihoods

def get_cnv_exp(self, max_level=4, method="scatrex"):
cnv_levels = np.unique(self.observed_tree.adata.X)
def get_cnv_exp(self, cnv_levels=[1,2,3,4], max_level=4, method="scatrex"):
cnv_levels = np.unique(cnv_levels)
exp_levels = []
for cnv in cnv_levels:
gene_avg = []
Expand Down Expand Up @@ -2053,8 +2016,8 @@ def plot_proportions(self, dna=True, rna=True, remove_empty_nodes=True, show=Tru
dna_props = dna_props[s]

if rna:
rna_nodes, rna_props = self.ntssb.get_node_data_sizes(
normalized=True, super_only=True
rna_nodes, rna_props = self.ntssb.get_tree_data_sizes(
normalized=True,
)
nodes_labels = [node.label for node in rna_nodes]
s = np.argsort(np.array(nodes_labels))
Expand Down

0 comments on commit 5a0bdcc

Please sign in to comment.