
Commit

Remove CNA node global vars
pedrofale committed Jun 12, 2024
1 parent 49ebd23 commit 4d3f2d1
Showing 1 changed file with 19 additions and 17 deletions.
36 changes: 19 additions & 17 deletions scatrex/models/cna/node.py
@@ -12,9 +12,6 @@
from ...utils.math_utils import *
from ...ntssb.node import *

-MIN_ALPHA = jnp.log(1e-2)
-MAX_BETA = jnp.log(1e6)

def update_params(params, params_gradient, step_size):
new_params = []
for i, param in enumerate(params):
@@ -36,6 +33,8 @@ def __init__(
factor_precision_shape=2.,
min_cnv = 1e-6,
max_cnv=6.,
+min_alpha=1e-2,
+max_beta=1e6,
**kwargs,
):
"""
@@ -50,6 +49,9 @@
self.observed_parameters = np.array(self.cnvs)
self.cnvs = jnp.array(self.cnvs)

+self.min_log_alpha = jnp.log(min_alpha)
+self.max_log_beta = jnp.log(max_beta)

self.n_genes = self.cnvs.size

# Node hyperparameters
@@ -974,8 +976,8 @@ def update_direction_params(self, direction_params_grad, direction_sample_grad,
direction_log_beta_grad = mc_grad + direction_params_entropy_grad[1]
self.variational_parameters['kernel']['direction']['log_beta'] += direction_log_beta_grad * step_size

-self.variational_parameters['kernel']['direction']['log_alpha'] = self.apply_clip(self.variational_parameters['kernel']['direction']['log_alpha'], minval=MIN_ALPHA)
-self.variational_parameters['kernel']['direction']['log_beta'] = self.apply_clip(self.variational_parameters['kernel']['direction']['log_beta'], maxval=MAX_BETA)
+self.variational_parameters['kernel']['direction']['log_alpha'] = self.apply_clip(self.variational_parameters['kernel']['direction']['log_alpha'], minval=self.min_log_alpha)
+self.variational_parameters['kernel']['direction']['log_beta'] = self.apply_clip(self.variational_parameters['kernel']['direction']['log_beta'], maxval=self.max_log_beta)

def update_state_params(self, state_params_grad, state_sample_grad, state_params_entropy_grad, step_size=0.001):
mc_grad = jnp.mean(state_params_grad[0] * state_sample_grad, axis=0)
@@ -1000,8 +1002,8 @@ def update_cell_scales_params(self, idx, local_params_grad, local_sample_grad, l
new_param = self.variational_parameters['local']['cell_scales']['log_beta'][idx] + param_grad * step_size
self.variational_parameters['local']['cell_scales']['log_beta'] = self.variational_parameters['local']['cell_scales']['log_beta'].at[idx].set(new_param)

-self.variational_parameters['local']['cell_scales']['log_alpha'] = self.apply_clip(self.variational_parameters['local']['cell_scales']['log_alpha'], minval=MIN_ALPHA)
-self.variational_parameters['local']['cell_scales']['log_beta'] = self.apply_clip(self.variational_parameters['local']['cell_scales']['log_beta'], maxval=MAX_BETA)
+self.variational_parameters['local']['cell_scales']['log_alpha'] = self.apply_clip(self.variational_parameters['local']['cell_scales']['log_alpha'], minval=self.min_log_alpha)
+self.variational_parameters['local']['cell_scales']['log_beta'] = self.apply_clip(self.variational_parameters['local']['cell_scales']['log_beta'], maxval=self.max_log_beta)

def update_obs_weights_params(self, idx, local_params_grad, local_sample_grad, local_params_entropy_grad, ent_anneal=1., step_size=0.001):
mc_grad = jnp.mean(local_params_grad[0] * local_sample_grad, axis=0)
@@ -1031,8 +1033,8 @@ def update_gene_scales_params(self, global_params_grad, global_sample_grad, glob
param_grad = mc_grad + global_params_entropy_grad[1]
self.variational_parameters['global']['gene_scales']['log_beta'] += param_grad * step_size

-self.variational_parameters['global']['gene_scales']['log_alpha'] = self.apply_clip(self.variational_parameters['global']['gene_scales']['log_alpha'], minval=MIN_ALPHA)
-self.variational_parameters['global']['gene_scales']['log_beta'] = self.apply_clip(self.variational_parameters['global']['gene_scales']['log_beta'], maxval=MAX_BETA)
+self.variational_parameters['global']['gene_scales']['log_alpha'] = self.apply_clip(self.variational_parameters['global']['gene_scales']['log_alpha'], minval=self.min_log_alpha)
+self.variational_parameters['global']['gene_scales']['log_beta'] = self.apply_clip(self.variational_parameters['global']['gene_scales']['log_beta'], maxval=self.max_log_beta)

def update_factor_weights_params(self, global_params_grad, global_sample_grad, global_params_entropy_grad, step_size=0.001):
mc_grad = jnp.mean(global_params_grad[0] * global_sample_grad, axis=0)
@@ -1052,8 +1054,8 @@ def update_factor_precisions_params(self, global_params_grad, global_sample_grad
param_grad = mc_grad + global_params_entropy_grad[1]
self.variational_parameters['global']['factor_precisions']['log_beta'] += param_grad * step_size

-self.variational_parameters['global']['factor_precisions']['log_alpha'] = self.apply_clip(self.variational_parameters['global']['factor_precisions']['log_alpha'], minval=MIN_ALPHA)
-self.variational_parameters['global']['factor_precisions']['log_beta'] = self.apply_clip(self.variational_parameters['global']['factor_precisions']['log_beta'], maxval=MAX_BETA)
+self.variational_parameters['global']['factor_precisions']['log_alpha'] = self.apply_clip(self.variational_parameters['global']['factor_precisions']['log_alpha'], minval=self.min_log_alpha)
+self.variational_parameters['global']['factor_precisions']['log_beta'] = self.apply_clip(self.variational_parameters['global']['factor_precisions']['log_beta'], maxval=self.max_log_beta)

def update_global_params(self, global_params_grad, global_sample_grad, global_params_entropy_grad, step_size=0.001,
param_names=["factor_weights", "gene_scales", "factor_precisions"], **kwargs):
@@ -1164,8 +1166,8 @@ def update_gene_scales_adaptive(self, global_params_grad, global_sample_grad, gl
state2 = (m, v)
self.variational_parameters['global']['gene_scales']['log_beta'] += step_size * mhat / (jnp.sqrt(vhat) + eps)

-self.variational_parameters['global']['gene_scales']['log_alpha'] = self.apply_clip(self.variational_parameters['global']['gene_scales']['log_alpha'], minval=MIN_ALPHA)
-self.variational_parameters['global']['gene_scales']['log_beta'] = self.apply_clip(self.variational_parameters['global']['gene_scales']['log_beta'], maxval=MAX_BETA)
+self.variational_parameters['global']['gene_scales']['log_alpha'] = self.apply_clip(self.variational_parameters['global']['gene_scales']['log_alpha'], minval=self.min_log_alpha)
+self.variational_parameters['global']['gene_scales']['log_beta'] = self.apply_clip(self.variational_parameters['global']['gene_scales']['log_beta'], maxval=self.max_log_beta)

states = (state1, state2)
return states
@@ -1194,8 +1196,8 @@ def update_factor_precisions_adaptive(self, global_params_grad, global_sample_gr
state2 = (m, v)
self.variational_parameters['global']['factor_precisions']['log_beta'] += step_size * mhat / (jnp.sqrt(vhat) + eps)

-self.variational_parameters['global']['factor_precisions']['log_alpha'] = self.apply_clip(self.variational_parameters['global']['factor_precisions']['log_alpha'], minval=MIN_ALPHA)
-self.variational_parameters['global']['factor_precisions']['log_beta'] = self.apply_clip(self.variational_parameters['global']['factor_precisions']['log_beta'], maxval=MAX_BETA)
+self.variational_parameters['global']['factor_precisions']['log_alpha'] = self.apply_clip(self.variational_parameters['global']['factor_precisions']['log_alpha'], minval=self.min_log_alpha)
+self.variational_parameters['global']['factor_precisions']['log_beta'] = self.apply_clip(self.variational_parameters['global']['factor_precisions']['log_beta'], maxval=self.max_log_beta)

states = (state1, state2)
return states
@@ -1323,8 +1325,8 @@ def update_cell_scales_adaptive(self, idx, local_params_grad, local_sample_grad,
new_param = self.variational_parameters['local']['cell_scales']['log_beta'][idx] + step_size * mhat / (jnp.sqrt(vhat) + eps)
self.variational_parameters['local']['cell_scales']['log_beta'] = self.variational_parameters['local']['cell_scales']['log_beta'].at[idx].set(new_param)

-self.variational_parameters['local']['cell_scales']['log_alpha'] = self.apply_clip(self.variational_parameters['local']['cell_scales']['log_alpha'], minval=MIN_ALPHA)
-self.variational_parameters['local']['cell_scales']['log_beta'] = self.apply_clip(self.variational_parameters['local']['cell_scales']['log_beta'], maxval=MAX_BETA)
+self.variational_parameters['local']['cell_scales']['log_alpha'] = self.apply_clip(self.variational_parameters['local']['cell_scales']['log_alpha'], minval=self.min_log_alpha)
+self.variational_parameters['local']['cell_scales']['log_beta'] = self.apply_clip(self.variational_parameters['local']['cell_scales']['log_beta'], maxval=self.max_log_beta)

states = (state1, state2)
return states
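
The pattern this commit applies, sketched below for context: the clipping bounds move from module-level constants (MIN_ALPHA, MAX_BETA) into constructor arguments stored on each node, so the bounds become configurable per instance. The class body and the apply_clip implementation here are illustrative assumptions, not the actual scatrex code; only the attribute and argument names mirror the diff.

    import jax.numpy as jnp

    class CNANode:
        def __init__(self, min_alpha=1e-2, max_beta=1e6, **kwargs):
            # Bounds are stored in log space, matching the log-parameterized
            # variational parameters they clip.
            self.min_log_alpha = jnp.log(min_alpha)
            self.max_log_beta = jnp.log(max_beta)

        def apply_clip(self, param, minval=-jnp.inf, maxval=jnp.inf):
            # Assumed behavior of apply_clip: elementwise clamp of the array.
            return jnp.clip(param, minval, maxval)

With this, CNANode(min_alpha=1e-3, max_beta=1e7) would clip log_alpha at jnp.log(1e-3) and log_beta at jnp.log(1e7), whereas the old globals fixed those bounds for every node in the process.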

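The *_adaptive updates in this diff consume bias-corrected moment estimates (mhat, vhat) and an eps term, i.e. an Adam-style step. A self-contained sketch of one such step follows; the b1, b2, and eps defaults are assumptions (the diff only shows the final update line):

    import jax.numpy as jnp

    def adam_step(param, grad, m, v, t, step_size=0.001, b1=0.9, b2=0.999, eps=1e-8):
        # Exponential moving averages of the gradient and its square.
        m = b1 * m + (1 - b1) * grad
        v = b2 * v + (1 - b2) * grad**2
        # Bias correction; t is the 1-indexed iteration count.
        mhat = m / (1 - b1**t)
        vhat = v / (1 - b2**t)
        # Gradient-ascent step of the same form as the updates above.
        return param + step_size * mhat / (jnp.sqrt(vhat) + eps), (m, v)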