From 4d3f2d13255adc1217a2404ab8b616185d272a45 Mon Sep 17 00:00:00 2001
From: pedrofale
Date: Wed, 12 Jun 2024 10:06:09 +0100
Subject: [PATCH] Remove CNA node global vars

---
 scatrex/models/cna/node.py | 36 +++++++++++++++++++-----------------
 1 file changed, 19 insertions(+), 17 deletions(-)

diff --git a/scatrex/models/cna/node.py b/scatrex/models/cna/node.py
index 134ed86..720b49e 100644
--- a/scatrex/models/cna/node.py
+++ b/scatrex/models/cna/node.py
@@ -12,9 +12,6 @@
 from ...utils.math_utils import *
 from ...ntssb.node import *
 
-MIN_ALPHA = jnp.log(1e-2)
-MAX_BETA = jnp.log(1e6)
-
 def update_params(params, params_gradient, step_size):
     new_params = []
     for i, param in enumerate(params):
@@ -36,6 +33,8 @@ def __init__(
         factor_precision_shape=2.,
         min_cnv = 1e-6,
         max_cnv=6.,
+        min_alpha=1e-2,
+        max_beta=1e6,
         **kwargs,
     ):
         """
@@ -50,6 +49,9 @@
         self.observed_parameters = np.array(self.cnvs)
         self.cnvs = jnp.array(self.cnvs)
 
+        self.min_log_alpha = jnp.log(min_alpha)
+        self.max_log_beta = jnp.log(max_beta)
+
         self.n_genes = self.cnvs.size
 
         # Node hyperparameters
@@ -974,8 +976,8 @@ def update_direction_params(self, direction_params_grad, direction_sample_grad,
         direction_log_beta_grad = mc_grad + direction_params_entropy_grad[1]
         self.variational_parameters['kernel']['direction']['log_beta'] += direction_log_beta_grad * step_size
 
-        self.variational_parameters['kernel']['direction']['log_alpha'] = self.apply_clip(self.variational_parameters['kernel']['direction']['log_alpha'], minval=MIN_ALPHA)
-        self.variational_parameters['kernel']['direction']['log_beta'] = self.apply_clip(self.variational_parameters['kernel']['direction']['log_beta'], maxval=MAX_BETA)
+        self.variational_parameters['kernel']['direction']['log_alpha'] = self.apply_clip(self.variational_parameters['kernel']['direction']['log_alpha'], minval=self.min_log_alpha)
+        self.variational_parameters['kernel']['direction']['log_beta'] = self.apply_clip(self.variational_parameters['kernel']['direction']['log_beta'], maxval=self.max_log_beta)
 
     def update_state_params(self, state_params_grad, state_sample_grad, state_params_entropy_grad, step_size=0.001):
         mc_grad = jnp.mean(state_params_grad[0] * state_sample_grad, axis=0)
@@ -1000,8 +1002,8 @@ def update_cell_scales_params(self, idx, local_params_grad, local_sample_grad, l
         new_param = self.variational_parameters['local']['cell_scales']['log_beta'][idx] + param_grad * step_size
         self.variational_parameters['local']['cell_scales']['log_beta'] = self.variational_parameters['local']['cell_scales']['log_beta'].at[idx].set(new_param)
 
-        self.variational_parameters['local']['cell_scales']['log_alpha'] = self.apply_clip(self.variational_parameters['local']['cell_scales']['log_alpha'], minval=MIN_ALPHA)
-        self.variational_parameters['local']['cell_scales']['log_beta'] = self.apply_clip(self.variational_parameters['local']['cell_scales']['log_beta'], maxval=MAX_BETA)
+        self.variational_parameters['local']['cell_scales']['log_alpha'] = self.apply_clip(self.variational_parameters['local']['cell_scales']['log_alpha'], minval=self.min_log_alpha)
+        self.variational_parameters['local']['cell_scales']['log_beta'] = self.apply_clip(self.variational_parameters['local']['cell_scales']['log_beta'], maxval=self.max_log_beta)
 
     def update_obs_weights_params(self, idx, local_params_grad, local_sample_grad, local_params_entropy_grad, ent_anneal=1., step_size=0.001):
         mc_grad = jnp.mean(local_params_grad[0] * local_sample_grad, axis=0)
@@ -1031,8 +1033,8 @@ def update_gene_scales_params(self, global_params_grad, global_sample_grad, glob
         param_grad = mc_grad + global_params_entropy_grad[1]
         self.variational_parameters['global']['gene_scales']['log_beta'] += param_grad * step_size
 
-        self.variational_parameters['global']['gene_scales']['log_alpha'] = self.apply_clip(self.variational_parameters['global']['gene_scales']['log_alpha'], minval=MIN_ALPHA)
-        self.variational_parameters['global']['gene_scales']['log_beta'] = self.apply_clip(self.variational_parameters['global']['gene_scales']['log_beta'], maxval=MAX_BETA)
+        self.variational_parameters['global']['gene_scales']['log_alpha'] = self.apply_clip(self.variational_parameters['global']['gene_scales']['log_alpha'], minval=self.min_log_alpha)
+        self.variational_parameters['global']['gene_scales']['log_beta'] = self.apply_clip(self.variational_parameters['global']['gene_scales']['log_beta'], maxval=self.max_log_beta)
 
     def update_factor_weights_params(self, global_params_grad, global_sample_grad, global_params_entropy_grad, step_size=0.001):
         mc_grad = jnp.mean(global_params_grad[0] * global_sample_grad, axis=0)
@@ -1052,8 +1054,8 @@ def update_factor_precisions_params(self, global_params_grad, global_sample_grad
         param_grad = mc_grad + global_params_entropy_grad[1]
         self.variational_parameters['global']['factor_precisions']['log_beta'] += param_grad * step_size
 
-        self.variational_parameters['global']['factor_precisions']['log_alpha'] = self.apply_clip(self.variational_parameters['global']['factor_precisions']['log_alpha'], minval=MIN_ALPHA)
-        self.variational_parameters['global']['factor_precisions']['log_beta'] = self.apply_clip(self.variational_parameters['global']['factor_precisions']['log_beta'], maxval=MAX_BETA)
+        self.variational_parameters['global']['factor_precisions']['log_alpha'] = self.apply_clip(self.variational_parameters['global']['factor_precisions']['log_alpha'], minval=self.min_log_alpha)
+        self.variational_parameters['global']['factor_precisions']['log_beta'] = self.apply_clip(self.variational_parameters['global']['factor_precisions']['log_beta'], maxval=self.max_log_beta)
 
     def update_global_params(self, global_params_grad, global_sample_grad, global_params_entropy_grad, step_size=0.001,
                              param_names=["factor_weights", "gene_scales", "factor_precisions"], **kwargs):
@@ -1164,8 +1166,8 @@ def update_gene_scales_adaptive(self, global_params_grad, global_sample_grad, gl
         state2 = (m, v)
         self.variational_parameters['global']['gene_scales']['log_beta'] += step_size * mhat / (jnp.sqrt(vhat) + eps)
 
-        self.variational_parameters['global']['gene_scales']['log_alpha'] = self.apply_clip(self.variational_parameters['global']['gene_scales']['log_alpha'], minval=MIN_ALPHA)
-        self.variational_parameters['global']['gene_scales']['log_beta'] = self.apply_clip(self.variational_parameters['global']['gene_scales']['log_beta'], maxval=MAX_BETA)
+        self.variational_parameters['global']['gene_scales']['log_alpha'] = self.apply_clip(self.variational_parameters['global']['gene_scales']['log_alpha'], minval=self.min_log_alpha)
+        self.variational_parameters['global']['gene_scales']['log_beta'] = self.apply_clip(self.variational_parameters['global']['gene_scales']['log_beta'], maxval=self.max_log_beta)
 
         states = (state1, state2)
         return states
@@ -1194,8 +1196,8 @@ def update_factor_precisions_adaptive(self, global_params_grad, global_sample_gr
         state2 = (m, v)
         self.variational_parameters['global']['factor_precisions']['log_beta'] += step_size * mhat / (jnp.sqrt(vhat) + eps)
 
-        self.variational_parameters['global']['factor_precisions']['log_alpha'] = self.apply_clip(self.variational_parameters['global']['factor_precisions']['log_alpha'], minval=MIN_ALPHA)
-        self.variational_parameters['global']['factor_precisions']['log_beta'] = self.apply_clip(self.variational_parameters['global']['factor_precisions']['log_beta'], maxval=MAX_BETA)
+        self.variational_parameters['global']['factor_precisions']['log_alpha'] = self.apply_clip(self.variational_parameters['global']['factor_precisions']['log_alpha'], minval=self.min_log_alpha)
+        self.variational_parameters['global']['factor_precisions']['log_beta'] = self.apply_clip(self.variational_parameters['global']['factor_precisions']['log_beta'], maxval=self.max_log_beta)
 
         states = (state1, state2)
         return states
@@ -1323,8 +1325,8 @@ def update_cell_scales_adaptive(self, idx, local_params_grad, local_sample_grad,
         new_param = self.variational_parameters['local']['cell_scales']['log_beta'][idx] + step_size * mhat / (jnp.sqrt(vhat) + eps)
         self.variational_parameters['local']['cell_scales']['log_beta'] = self.variational_parameters['local']['cell_scales']['log_beta'].at[idx].set(new_param)
 
-        self.variational_parameters['local']['cell_scales']['log_alpha'] = self.apply_clip(self.variational_parameters['local']['cell_scales']['log_alpha'], minval=MIN_ALPHA)
-        self.variational_parameters['local']['cell_scales']['log_beta'] = self.apply_clip(self.variational_parameters['local']['cell_scales']['log_beta'], maxval=MAX_BETA)
+        self.variational_parameters['local']['cell_scales']['log_alpha'] = self.apply_clip(self.variational_parameters['local']['cell_scales']['log_alpha'], minval=self.min_log_alpha)
+        self.variational_parameters['local']['cell_scales']['log_beta'] = self.apply_clip(self.variational_parameters['local']['cell_scales']['log_beta'], maxval=self.max_log_beta)
 
         states = (state1, state2)
         return states
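The pattern the patch applies is uniform across all ten hunks: the module-level clip constants MIN_ALPHA and MAX_BETA become per-instance attributes (min_log_alpha, max_log_beta) derived from new constructor arguments, so each node carries its own bounds instead of sharing mutable module state. The sketch below illustrates that pattern in a minimal, self-contained form; NodeSketch and its apply_clip helper are illustrative stand-ins for this note only, not the actual scatrex API:

    import jax.numpy as jnp

    class NodeSketch:
        """Stand-in for the CNA node; clip bounds live on the instance."""

        def __init__(self, min_alpha=1e-2, max_beta=1e6):
            # Mirror the patched __init__: store the bounds in log space,
            # since they clip log-scale variational parameters.
            self.min_log_alpha = jnp.log(min_alpha)
            self.max_log_beta = jnp.log(max_beta)

        def apply_clip(self, x, minval=-jnp.inf, maxval=jnp.inf):
            # Same role as the node's apply_clip: constrain x to [minval, maxval].
            return jnp.clip(x, minval, maxval)

    # Two nodes with different bounds no longer interfere via module globals.
    loose = NodeSketch()                              # defaults: log(1e-2), log(1e6)
    tight = NodeSketch(min_alpha=1e-1, max_beta=1e3)  # per-instance override

    log_alpha = jnp.array([-10.0, 0.0, 3.0])
    print(loose.apply_clip(log_alpha, minval=loose.min_log_alpha))
    print(tight.apply_clip(log_alpha, minval=tight.min_log_alpha))

Keeping the bounds in log space matches the parameters they clip (the 'log_alpha' and 'log_beta' entries of variational_parameters), so the comparison happens on the same scale the gradient updates operate on.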