diff --git a/batchglm/train/tf2/__init__.py b/batchglm/train/tf2/__init__.py new file mode 100644 index 00000000..9170f2ff --- /dev/null +++ b/batchglm/train/tf2/__init__.py @@ -0,0 +1,3 @@ +from . import glm_nb as nb +from . import glm_norm as norm +from . import glm_beta as beta \ No newline at end of file diff --git a/batchglm/train/tf2/base/__init__.py b/batchglm/train/tf2/base/__init__.py new file mode 100644 index 00000000..9b75ab32 --- /dev/null +++ b/batchglm/train/tf2/base/__init__.py @@ -0,0 +1,3 @@ +from .estimator import TFEstimator +from .model import ProcessModelBase, ModelBase, LossBase +from .optim import OptimizerBase diff --git a/batchglm/train/tf2/base/estimator.py b/batchglm/train/tf2/base/estimator.py new file mode 100644 index 00000000..15fc0906 --- /dev/null +++ b/batchglm/train/tf2/base/estimator.py @@ -0,0 +1,32 @@ +from .external import pkg_constants, TrainingStrategies +from .model import ModelBase, LossBase + +import numpy as np +import tensorflow as tf + + +class TFEstimator: + model: ModelBase + loss: LossBase + + def __init__(self, input_data, dtype): + + self._input_data = input_data + self.dtype = dtype + + def _train( + self, + batched_model: bool, + batch_size: int, + optimizer_object: tf.keras.optimizers.Optimizer, + optimizer_enum: TrainingStrategies, + convergence_criteria: str, + stopping_criteria: int, + autograd: bool, + featurewise: bool, + benchmark: bool + ): + pass + + def fetch_fn(self, idx): + pass diff --git a/batchglm/train/tf2/base/external.py b/batchglm/train/tf2/base/external.py new file mode 100644 index 00000000..08784cca --- /dev/null +++ b/batchglm/train/tf2/base/external.py @@ -0,0 +1,5 @@ +#from batchglm.models.base import _Estimator_Base +#from batchglm.xarray_sparse import SparseXArrayDataArray, SparseXArrayDataSet +from batchglm.train.tf2.base_glm.training_strategies import TrainingStrategies +#import batchglm.utils.stats as stat_utils +from batchglm import pkg_constants diff --git a/batchglm/train/tf2/base/model.py b/batchglm/train/tf2/base/model.py new file mode 100644 index 00000000..acce4dee --- /dev/null +++ b/batchglm/train/tf2/base/model.py @@ -0,0 +1,57 @@ +import abc +import logging +import tensorflow as tf +import numpy as np + +logger = logging.getLogger(__name__) + + +class ModelBase(tf.keras.Model, metaclass=abc.ABCMeta): + + def __init__(self): + super(ModelBase, self).__init__() + + @abc.abstractmethod + def call(self, inputs, training=False, mask=None): + pass + + +class LossBase(tf.keras.losses.Loss, metaclass=abc.ABCMeta): + + def __init__(self): + super(LossBase, self).__init__() + + @abc.abstractmethod + def call(self, y_true, y_pred): + pass + + +class ProcessModelBase: + + @abc.abstractmethod + def param_bounds(self, dtype): + pass + + def tf_clip_param( + self, + param, + name + ): + bounds_min, bounds_max = self.param_bounds(param.dtype) + return tf.clip_by_value( + param, + bounds_min[name], + bounds_max[name] + ) + + def np_clip_param( + self, + param, + name + ): + bounds_min, bounds_max = self.param_bounds(param.dtype) + return np.clip( + param, + bounds_min[name], + bounds_max[name] + ) diff --git a/batchglm/train/tf2/base/optim.py b/batchglm/train/tf2/base/optim.py new file mode 100644 index 00000000..5fc8d13b --- /dev/null +++ b/batchglm/train/tf2/base/optim.py @@ -0,0 +1,52 @@ +import abc +import logging +import tensorflow as tf + +logger = logging.getLogger("batchglm") + + +class OptimizerBase(tf.keras.optimizers.Optimizer, metaclass=abc.ABCMeta): + + def __init__(self, name): + 
super(OptimizerBase, self).__init__(name=name) + + @abc.abstractmethod + def _resource_apply_dense(self, grad, handle): + pass + + @abc.abstractmethod + def _resource_apply_sparse(self, grad, handle, apply_state): + pass + + @abc.abstractmethod + def _create_slots(self): + pass + + """ + @property + @abc.abstractmethod + def vars(self): + pass + + @property + @abc.abstractmethod + def gradients(self): + return None + + @property + @abc.abstractmethod + def hessians(self): + pass + + @property + @abc.abstractmethod + def fims(self): + pass + + @abc.abstractmethod + def step(self, learning_rate): + pass + """ + @abc.abstractmethod + def get_config(self): + pass diff --git a/batchglm/train/tf2/base_glm/README.md b/batchglm/train/tf2/base_glm/README.md new file mode 100644 index 00000000..eea79ccc --- /dev/null +++ b/batchglm/train/tf2/base_glm/README.md @@ -0,0 +1,2 @@ +# Classes with GLM specific code. +All noise models that are in the GLM category inherit all of these classes. \ No newline at end of file diff --git a/batchglm/train/tf2/base_glm/__init__.py b/batchglm/train/tf2/base_glm/__init__.py new file mode 100644 index 00000000..a662e17d --- /dev/null +++ b/batchglm/train/tf2/base_glm/__init__.py @@ -0,0 +1,10 @@ +from .processModel import ProcessModelGLM +from .model import GLM, LossGLM + +from .estimator import Estimator +from .vars import ModelVarsGLM +from .layers import LinearLocGLM, LinearScaleGLM, LinkerLocGLM, LinkerScaleGLM +from .layers import LikelihoodGLM, UnpackParamsGLM +from .layers_gradients import JacobianGLM, HessianGLM, FIMGLM +from .optim import NR, IRLS +from .training_strategies import TrainingStrategies diff --git a/batchglm/train/tf2/base_glm/estimator.py b/batchglm/train/tf2/base_glm/estimator.py new file mode 100644 index 00000000..11cddd75 --- /dev/null +++ b/batchglm/train/tf2/base_glm/estimator.py @@ -0,0 +1,485 @@ +import abc +import logging +import numpy as np +import scipy +import tensorflow as tf +from .model import GLM +from .training_strategies import TrainingStrategies +from .external import TFEstimator, _EstimatorGLM +from .optim import NR, IRLS +from .external import pkg_constants +import time + +logger = logging.getLogger("batchglm") + +class Estimator(TFEstimator, _EstimatorGLM, metaclass=abc.ABCMeta): + """ + Estimator for Generalized Linear Models (GLMs). + """ + model: GLM + _train_loc: bool + _train_scale: bool + _initialized: bool = False + noise_model: str + + def initialize(self, **kwargs): + self.values = [] + self.times = [] + self.converged = [] + self._initialized = True + + def finalize(self, **kwargs): + """ + Evaluate all tensors that need to be exported from session and save these as class attributes + and close session. + + Changes .model entry from tf-based EstimatorGraph to numpy based Model instance and + transfers relevant attributes. 
+ """ + a_var, b_var = self.model.unpack_params([self.model.params, self.model.model_vars.a_var.get_shape()[0]]) + self.model = self.get_model_container(self._input_data) + self.model._a_var = a_var + self.model._b_var = b_var + self._loss = tf.reduce_sum(-self._log_likelihood / self.input_data.num_observations) + + def __init__( + self, + input_data, + dtype, + ): + + self._input_data = input_data + + TFEstimator.__init__( + self=self, + input_data=input_data, + dtype=dtype, + ) + _EstimatorGLM.__init__( + self=self, + model=None, + input_data=input_data + ) + + def train_sequence(self, training_strategy: []): + for strategy in training_strategy: + self.train( + batched_model=strategy['use_batching'], + optimizer=strategy['optim_algo'], + convergence_criteria=strategy['convergence_criteria'], + stopping_criteria=strategy['stopping_criteria'], + batch_size=strategy['batch_size'] if 'batch_size' in strategy else 500, + learning_rate=strategy['learning_rate'] if 'learning_rate' in strategy else 1e-2, + autograd=strategy['autograd'] if 'autograd' in strategy else False, + featurewise=strategy['featurewise'] if 'featurewise' in strategy else True + ) + + def _train( + self, + noise_model: str, + batched_model: bool = True, + batch_size: int = 500, + optimizer_object: tf.keras.optimizers.Optimizer = tf.keras.optimizers.Adam(), + optimizer_enum: TrainingStrategies = TrainingStrategies.DEFAULT, + convergence_criteria: str = "step", + stopping_criteria: int = 1000, + autograd: bool = False, + featurewise: bool = True, + benchmark: bool = False, + ): + + if not self._initialized: + raise RuntimeError("Cannot train the model: \ + Estimator not initialized. Did you forget to call estimator.initialize() ?") + + if autograd and optimizer_enum.value['hessian']: + logger.warning("Automatic differentiation is currently not supported for hessians. \ + Falling back to closed form. Only Jacobians are calculated using autograd.") + + self.noise_model = noise_model + # Slice data and create batches + data_ids = tf.data.Dataset.from_tensor_slices( + (tf.range(self._input_data.num_observations, name="sample_index", dtype=tf.dtypes.int64)) + ) + if batched_model: + data = data_ids.shuffle(buffer_size=2 * batch_size).repeat().batch(batch_size) + else: + data = data_ids.shuffle(buffer_size=2 * batch_size).batch(batch_size, drop_remainder=True) + input_list = data.map(self.fetch_fn, num_parallel_calls=pkg_constants.TF_NUM_THREADS) + + # Iterate until conditions are fulfilled. + train_step = 0 + + # Set all to convergence status = False, this is needed if multiple + # training strategies are run: + converged_current = np.repeat( + False, repeats=self.model.model_vars.n_features) + + def convergence_decision(convergence_status, train_step): + if convergence_criteria == "step": + return train_step < stopping_criteria + elif convergence_criteria == "all_converged": + return np.any(np.logical_not(convergence_status)) + elif convergence_criteria == "both": + return np.any(np.logical_not(convergence_status)) and train_step < stopping_criteria + else: + raise ValueError("convergence_criteria %s not recognized." 
% convergence_criteria) + + # fill with highest possible number: + ll_current = np.zeros([self._input_data.num_features], self.dtype) + np.nextafter(np.inf, 0, dtype=self.dtype) + + dataset_iterator = iter(input_list) + calc_separated = False + if optimizer_enum.value["hessian"] is True or optimizer_enum.value["fim"] is True: + second_order_optim = True + calc_separated = optimizer_enum.value['calc_separated'] + update_func = optimizer_object.perform_parameter_update + else: + update_func = optimizer_object.apply_gradients + second_order_optim = False + n_obs = self._input_data.num_observations + + curr_norm_loc = np.sqrt(np.sum(np.square( + np.abs(self.model.params.numpy()[self.model.model_vars.idx_train_loc, :])), axis=0)) + curr_norm_scale = np.sqrt(np.sum(np.square( + np.abs(self.model.params.numpy()[self.model.model_vars.idx_train_scale, :])), axis=0)) + + batch_features = False + while convergence_decision(converged_current, train_step): + # ### Iterate over the batches of the dataset. + # x_batch is a tuple (idx, (X_tensor, design_loc_tensor, design_scale_tensor, size_factors_tensor)) + if benchmark: + t0_epoch = time.time() + + not_converged = np.logical_not(self.model.model_vars.converged) + ll_prev = ll_current.copy() + if train_step % 10 == 0: + logger.info('step %i', train_step) + + if not batched_model: + results = None + x_batch = None + first_batch = True + for x_batch_tuple in input_list: + x_batch = self.getModelInput(x_batch_tuple, batch_features, not_converged) + + current_results = self.model(x_batch) + if first_batch: + results = list(current_results) + first_batch = False + else: + for i, x in enumerate(current_results): + results[i] += x + + else: + x_batch_tuple = next(dataset_iterator) + x_batch = self.getModelInput(x_batch_tuple, batch_features, not_converged) + + results = self.model(x_batch) + if second_order_optim: + if calc_separated: + update_func([x_batch, *results, False, n_obs], True, False, batch_features, ll_prev) + if self._train_scale: + update_func([x_batch, *results, False, n_obs], False, True, batch_features, ll_prev) + else: + update_func([x_batch, *results, False, n_obs], True, True, batch_features, ll_prev) + features_updated = self.model.model_vars.updated + else: + if batch_features: + indices = tf.where(not_converged) + update_var = tf.transpose(tf.scatter_nd( + indices, + tf.transpose(results[1]), + shape=(self.model.model_vars.n_features, results[1].get_shape()[0]) + )) + else: + update_var = results[1] + update_func([(update_var, self.model.params)]) + features_updated = not_converged + + if benchmark: + self.values.append(self.model.trainable_variables[0].numpy().copy()) + + # Update converged status + prev_norm_loc = curr_norm_loc.copy() + prev_norm_scale = curr_norm_scale.copy() + converged_prev = converged_current.copy() + ll_current = self.loss.norm_neg_log_likelihood(results[0]).numpy() + + if batch_features: + indices = tf.where(not_converged) + updated_lls = tf.scatter_nd(indices, ll_current, shape=ll_prev.shape) + ll_current = np.where(features_updated, updated_lls.numpy(), ll_prev) + + if batched_model: + jac_normalization = batch_size + else: + jac_normalization = self._input_data.num_observations + if optimizer_enum.value["optim_algo"] in ['irls', 'irls_gd', 'irls_gd_tr', 'irls_tr']: + grad_numpy = tf.abs(tf.concat((results[1], results[2]), axis=1)) + elif optimizer_enum.value["optim_algo"] in ['nr', 'nr_tr']: + grad_numpy = tf.abs(results[1]) + else: + grad_numpy = tf.abs(tf.transpose(results[1])) + if batch_features: + indices 
= tf.where(not_converged) + grad_numpy = tf.scatter_nd(indices, grad_numpy, shape=(self.model.model_vars.n_features, + self.model.params.get_shape()[0])) + grad_numpy = grad_numpy.numpy() + convergences = self.calculate_convergence(converged_prev, ll_prev, prev_norm_loc, prev_norm_scale, + ll_current, jac_normalization, grad_numpy, features_updated) + converged_current, converged_f, converged_g, converged_x = convergences + + self.model.model_vars.convergence_update(converged_current, features_updated) + num_converged = np.sum(converged_current).astype("int32") + if np.sum(converged_current) != np.sum(converged_prev): + if featurewise and not batch_features: + batch_features = True + self.model.batch_features = batch_features + logger.info("Step: %i loss: %f, converged %i, updated %i, (logs: %i, grad: %i, x_step: %i)", + train_step, + np.sum(ll_current), + num_converged, + np.sum(features_updated).astype("int32"), + np.sum(converged_f), np.sum(converged_g), np.sum(converged_x)) + train_step += 1 + if benchmark: + t1_epoch = time.time() + self.times.append(t1_epoch-t0_epoch) + self.converged.append(num_converged) + + # Evaluate final params + self._log_likelihood = results[0].numpy() + self._fisher_inv = tf.zeros(shape=()).numpy() + self._hessian = tf.zeros(shape=()).numpy() + + if optimizer_enum.value["hessian"] is True: + self._hessian = results[2].numpy() + self._jacobian = results[1].numpy() + elif optimizer_enum.value["fim"] is True: + self._fisher_inv = tf.concat([results[3], results[4]], axis=0).numpy() + self._jacobian = tf.concat([results[1], results[2]], axis=0).numpy() + else: + self._jacobian = results[1].numpy() + + def getModelInput(self, x_batch_tuple: tuple, batch_features: bool, not_converged): + + if batch_features: + x_tensor, design_loc_tensor, design_scale_tensor, size_factors_tensor = x_batch_tuple + if isinstance(self._input_data.x, scipy.sparse.csr_matrix): + not_converged_idx = np.where(not_converged)[0] + feature_columns = tf.sparse.split( + x_tensor, + num_split=self.model.model_vars.n_features, + axis=1) + feature_columns = [feature_columns[i] for i in not_converged_idx] + x_tensor = tf.sparse.concat(axis=1, sp_inputs=feature_columns) + if not isinstance(x_tensor, tf.sparse.SparseTensor): + raise RuntimeError("x_tensor now dense!!!") + else: + x_tensor = tf.boolean_mask(tensor=x_tensor, mask=not_converged, axis=1) + x_batch = (x_tensor, design_loc_tensor, design_scale_tensor, size_factors_tensor) + else: + x_batch = x_batch_tuple + + return x_batch + + def calculate_convergence(self, converged_prev, ll_prev, prev_norm_loc, prev_norm_scale, ll_current, + jac_normalization, grad_numpy, features_updated): + def get_convergence(converged_previous, condition1, condition2): + return np.logical_or(converged_previous, np.logical_and(condition1, condition2)) + + def get_convergence_by_method(converged_previous, condition1, condition2): + return np.logical_and(np.logical_not(converged_previous), np.logical_and(condition1, condition2)) + + def calc_x_step(idx_train, prev_norm): + if len(idx_train) > 0 and len(self.values) > 1: + curr_norm = np.sqrt(np.sum(np.square( + np.abs(self.model.params.numpy()[idx_train, :]) + ), axis=0)) + return np.abs(curr_norm - prev_norm) + else: + return np.zeros([self.model.model_vars.n_features]) + np.nextafter(np.inf, 0, dtype=self.dtype) + + x_norm_loc = calc_x_step(self.model.model_vars.idx_train_loc, prev_norm_loc) + x_norm_scale = calc_x_step(self.model.model_vars.idx_train_scale, prev_norm_scale) + + ll_converged = np.abs(ll_prev - 
ll_current) / ll_prev < pkg_constants.LLTOL_BY_FEATURE + + converged_current = get_convergence(converged_prev, ll_converged, features_updated) + + # those features which were not converged in the prev run, but converged now + converged_f = get_convergence_by_method(converged_prev, ll_converged, features_updated) + grad_loc = np.sum(grad_numpy[:, self.model.model_vars.idx_train_loc], axis=1) + grad_norm_loc = grad_loc / jac_normalization + grad_scale = np.sum(grad_numpy[:, self.model.model_vars.idx_train_scale], axis=1) + grad_norm_scale = grad_scale / jac_normalization + + converged_current = get_convergence(converged_current, + grad_norm_loc < pkg_constants.GTOL_BY_FEATURE_LOC, + grad_norm_scale < pkg_constants.GTOL_BY_FEATURE_SCALE) + # those features which were not converged in the prev run, but converged now + converged_g = get_convergence_by_method(converged_prev, + grad_norm_loc < pkg_constants.GTOL_BY_FEATURE_LOC, + grad_norm_scale < pkg_constants.GTOL_BY_FEATURE_SCALE) + + # Step length: + converged_current = get_convergence(converged_current, + x_norm_loc < pkg_constants.XTOL_BY_FEATURE_LOC, + x_norm_scale < pkg_constants.XTOL_BY_FEATURE_SCALE) + + # those features which were not converged in the prev run, but converged now + converged_x = get_convergence_by_method(converged_prev, + x_norm_loc < pkg_constants.XTOL_BY_FEATURE_LOC, + x_norm_scale < pkg_constants.XTOL_BY_FEATURE_SCALE) + return converged_current, converged_f, converged_g, converged_x + + def get_optimizer_object(self, optimizer, learning_rate): + + optimizer = optimizer.lower() + + if optimizer == "gd": + return tf.keras.optimizers.SGD(learning_rate=learning_rate), TrainingStrategies.GD + if optimizer == "adam": + return tf.keras.optimizers.Adam(learning_rate=learning_rate), TrainingStrategies.ADAM + if optimizer == "adagrad": + return tf.keras.optimizers.Adagrad(learning_rate=learning_rate), TrainingStrategies.ADAGRAD + if optimizer == "rmsprop": + return tf.keras.optimizers.RMSprop(learning_rate=learning_rate), TrainingStrategies.RMSPROP + if optimizer == "irls": + return IRLS(dtype=self.dtype, + trusted_region_mode=False, + model=self.model, + name="IRLS"), TrainingStrategies.IRLS + if optimizer == "irls_tr": + return IRLS(dtype=self.dtype, + trusted_region_mode=True, + model=self.model, + name="IRLS_TR"), TrainingStrategies.IRLS_TR + if optimizer == "irls_gd": + return IRLS(dtype=self.dtype, + trusted_region_mode=False, + model=self.model, + name="IRLS_GD"), TrainingStrategies.IRLS_GD + if optimizer == "irls_gd_tr": + return IRLS(dtype=self.dtype, + trusted_region_mode=True, + model=self.model, + name="IRLS_GD_TR"), TrainingStrategies.IRLS_GD_TR + if optimizer == "nr": + return NR(dtype=self.dtype, + trusted_region_mode=False, + model=self.model, + name="NR"), TrainingStrategies.NR + if optimizer == "nr_tr": + return NR(dtype=self.dtype, + trusted_region_mode=True, + model=self.model, + name="NR_TR"), TrainingStrategies.NR_TR + + logger.warning("No valid optimizer given. Default optimizer Adam chosen.") + return tf.keras.optimizers.Adam(learning_rate=learning_rate), TrainingStrategies.ADAM + + def fetch_fn(self, idx): + """ + Documentation of tensorflow coding style in this function: + tf.py_func defines a python function (the getters of the InputData object slots) + as a tensorflow operation. Here, the shape of the tensor is lost and + has to be set with set_shape. For size factors, we use explicit broadcasting + as explained below. + """ + # Catch dimension collapse error if idx is only one element long, ie. 
0D: + if len(idx.shape) == 0: + idx = tf.expand_dims(idx, axis=0) + + if isinstance(self._input_data.x, scipy.sparse.csr_matrix): + + x_tensor_idx, x_tensor_val, x = tf.py_function( + func=self._input_data.fetch_x_sparse, + inp=[idx], + Tout=[np.int64, np.float64, np.int64], + ) + # Note on Tout: np.float64 for val seems to be required to avoid crashing v1.12. + x_tensor_idx = tf.cast(x_tensor_idx, dtype=tf.int64) + x = tf.cast(x, dtype=tf.int64) + x_tensor_val = tf.cast(x_tensor_val, dtype=self.dtype) + x_tensor = tf.SparseTensor(x_tensor_idx, x_tensor_val, x) + x_tensor = tf.cast(x_tensor, dtype=self.dtype) + + else: + + x_tensor = tf.py_function( + func=self._input_data.fetch_x_dense, + inp=[idx], + Tout=self._input_data.x.dtype, + ) + + x_tensor.set_shape(idx.get_shape().as_list() + [self._input_data.num_features]) + x_tensor = tf.cast(x_tensor, dtype=self.dtype) + + design_loc_tensor = tf.py_function( + func=self._input_data.fetch_design_loc, + inp=[idx], + Tout=self._input_data.design_loc.dtype, + ) + design_loc_tensor.set_shape(idx.get_shape().as_list() + [self._input_data.num_design_loc_params]) + design_loc_tensor = tf.cast(design_loc_tensor, dtype=self.dtype) + + design_scale_tensor = tf.py_function( + func=self._input_data.fetch_design_scale, + inp=[idx], + Tout=self._input_data.design_scale.dtype, + ) + design_scale_tensor.set_shape(idx.get_shape().as_list() + [self._input_data.num_design_scale_params]) + design_scale_tensor = tf.cast(design_scale_tensor, dtype=self.dtype) + + if self._input_data.size_factors is not None and self.noise_model in ["nb", "norm"]: + size_factors_tensor = tf.py_function( + func=self._input_data.fetch_size_factors, + inp=[idx], + Tout=self._input_data.size_factors.dtype, + ) + + size_factors_tensor.set_shape(idx.get_shape()) + size_factors_tensor = tf.expand_dims(size_factors_tensor, axis=-1) + size_factors_tensor = tf.cast(size_factors_tensor, dtype=self.dtype) + + else: + size_factors_tensor = tf.constant(1, shape=[1, 1], dtype=self.dtype) + + # feature batching + return x_tensor, design_loc_tensor, design_scale_tensor, size_factors_tensor + + @staticmethod + def get_init_from_model(init_a, init_b, input_data, init_model): + # Locations model: + if isinstance(init_a, str) and (init_a.lower() == "auto" or init_a.lower() == "init_model"): + my_loc_names = set(input_data.loc_names) + my_loc_names = my_loc_names.intersection(set(init_model.input_data.loc_names)) + + init_loc = np.zeros([input_data.num_loc_params, input_data.num_features]) + for parm in my_loc_names: + init_idx = np.where(init_model.input_data.loc_names == parm)[0] + my_idx = np.where(input_data.loc_names == parm)[0] + init_loc[my_idx] = init_model.a_var[init_idx] + + init_a = init_loc + + # Scale model: + if isinstance(init_b, str) and (init_b.lower() == "auto" or init_b.lower() == "init_model"): + my_scale_names = set(input_data.scale_names) + my_scale_names = my_scale_names.intersection(init_model.input_data.scale_names) + + init_scale = np.zeros([input_data.num_scale_params, input_data.num_features]) + for parm in my_scale_names: + init_idx = np.where(init_model.input_data.scale_names == parm)[0] + my_idx = np.where(input_data.scale_names == parm)[0] + init_scale[my_idx] = init_model.b_var[init_idx] + + init_b = init_scale + + return init_a, init_b + + @abc.abstractmethod + def get_model_container(self, input_data): + pass diff --git a/batchglm/train/tf2/base_glm/external.py b/batchglm/train/tf2/base_glm/external.py new file mode 100644 index 00000000..9188d2b0 --- /dev/null +++ 
b/batchglm/train/tf2/base_glm/external.py @@ -0,0 +1,9 @@ +from batchglm.train.tf2.base import ProcessModelBase, ModelBase, TFEstimator +from batchglm.train.tf2.base import OptimizerBase, LossBase +#from batchglm.train.tf2.glm_nb import NR, IRLS + +from batchglm.models.base_glm import InputDataGLM, _ModelGLM, _EstimatorGLM + +#import batchglm.train.tf.ops as op_utils +from batchglm.utils.linalg import groupwise_solve_lm +from batchglm import pkg_constants diff --git a/batchglm/train/tf2/base_glm/layers.py b/batchglm/train/tf2/base_glm/layers.py new file mode 100644 index 00000000..8ced3a4b --- /dev/null +++ b/batchglm/train/tf2/base_glm/layers.py @@ -0,0 +1,268 @@ +from typing import Union + +import abc +import tensorflow as tf + +from .processModel import ProcessModelGLM + + +class UnpackParamsGLM(tf.keras.layers.Layer, ProcessModelGLM): + + """ + Layer that slices the parameter tensor into mean and variance block. + """ + + def __init__(self): + super(UnpackParamsGLM, self).__init__() + + def call(self, inputs, **kwargs): + """ + :param inputs: tuple (params, border) + Must contain the parameter matrix (params) and the first index + of the variance block within the parameters matrix (border) + + :return tf.Tensor, tf.Tensor + The two returned tensor correspond to the mean and variance block + of the parameter matrix. + """ + params, border = inputs + a_var = params[0:border] # loc obs + b_var = params[border:] # scale obs + a_var = self.tf_clip_param(a_var, "a_var") + b_var = self.tf_clip_param(b_var, "b_var") + return a_var, b_var + + +class LinearLocGLM(tf.keras.layers.Layer, ProcessModelGLM): + + """ + Computes the dot product between the design matrix of the mean model and the mean block of the parameter matrix. + """ + + def __init__(self): + super(LinearLocGLM, self).__init__() + + def _eta_loc( + self, + a_var: tf.Tensor, + design_loc: tf.Tensor, + constraints_loc: Union[tf.Tensor, None] = None, + size_factors: Union[tf.Tensor, None] = None + ): + """ + Does the actual computation of eta_loc. + + :param a_var: tf.Tensor + the mean block of the parameter matrix + :param design_loc: tf.Tensor + the design matrix of the mean model + :param contraints_loc: tf.Tensor, optional + ??? # TODO + :param size_factors: tf.Tensor, optional + ??? # TODO + + :return tf.Tensor + the mean values for each individual distribution, encoded in linker space. + """ + if constraints_loc is not None: + eta_loc = tf.matmul( + design_loc, + tf.matmul(constraints_loc, a_var) + ) + else: + eta_loc = tf.matmul(design_loc, a_var) + + if size_factors is not None and size_factors.shape != (1, 1): + eta_loc = self.with_size_factors(eta_loc, size_factors) + + eta_loc = self.tf_clip_param(eta_loc, "eta_loc") + + return eta_loc + + @abc.abstractmethod + def with_size_factors(self, eta_loc, size_factors): + """ + Calculates eta_loc with size_factors. Is noise model specific and needs to be implemented in the inheriting + layer. + :param eta_loc: tf.Tensor + the mean values for each individual distribution, encoded in linker space + """ + + def call(self, inputs, **kwargs): + """ + Calculates the eta_loc tensor, containing the mean values for each individual distribution, + encoded in linker space. + + :param input: tuple + Must contain a_var, design_loc, constraints_loc and size_factors in this order, where + contraints_loc and size_factor can be None. + + :return tf.Tensor + the mean values for each individual distribution, encoded in linker space. 
+ """ + return self._eta_loc(*inputs) + + +class LinearScaleGLM(tf.keras.layers.Layer, ProcessModelGLM): + + """ + Computes the dot product between the design matrix of the variance model + and the variance block of the parameter matrix. + """ + + def __init__(self): + super(LinearScaleGLM, self).__init__() + + def _eta_scale( + self, + b_var: tf.Tensor, + design_scale: tf.Tensor, + constraints_scale: Union[tf.Tensor, None] = None + ): + """ + Does the actual computation of eta_scale. + + :param b_var: tf.Tensor + the variance block of the parameter matrix + :param design_scale: tf.Tensor + the design matrix of the mean model + :param contraints_scale: tf.Tensor, optional + ??? # TODO + + :return tf.Tensor + the variance values for each individual distribution, encoded in linker space. + """ + if constraints_scale is not None: + eta_scale = tf.matmul( + design_scale, + tf.matmul(constraints_scale, b_var) + ) + else: + eta_scale = tf.matmul(design_scale, b_var) + + eta_scale = self.tf_clip_param(eta_scale, "eta_scale") + + return eta_scale + + def call(self, inputs, **kwargs): + """ + Calculates the eta_scale tensor, containing the variance values for each individual distribution, + encoded in linker space. + + :param input: tuple + Must contain b_var, design_scale and constraints_loc in this order, where + contraints_loc can be None. + + :return tf.Tensor + the variance values for each individual distribution, encoded in linker space. + """ + return self._eta_scale(*inputs) + + +class LinkerLocGLM(tf.keras.layers.Layer): + + """ + Translation from linker to data space for the mean model. + """ + + def __init__(self): + super(LinkerLocGLM, self).__init__() + + @abc.abstractmethod + def _inv_linker(self, loc: tf.Tensor): + """ + Translates the given mean values from linker to data space. Depends on the given noise model and needs to + be implemented in the inheriting layer. + + :param loc: tf. Tensor + the mean values for each individual distribution, encoded in linker space. + + :return tf.Tensor + the mean values for each individual distribution, encoded in data space. + """ + + def call(self, eta_loc: tf.Tensor, **kwargs): + """ + Calls the distribution specific linker function to translate from linker to data space. + + :param eta_loc: tf.Tensor + the mean values for each individual distribution, encoded in linker space. + + :return tf.Tensor + the mean values for each individual distribution, encoded in data space. + """ + loc = self._inv_linker(eta_loc) + return loc + + +class LinkerScaleGLM(tf.keras.layers.Layer): + + """ + Translation from linker to data space for the variance model. + """ + + def __init__(self): + super(LinkerScaleGLM, self).__init__() + + @abc.abstractmethod + def _inv_linker(self, scale: tf.Tensor): + pass + + def call(self, eta_scale: tf.Tensor, **kwargs): + """ + Calls the distribution specific linker function to translate from linker to data space. + + :param eta_scale: tf.Tensor + the variance values for each individual distribution, encoded in linker space. + + :return tf.Tensor + the variance values for each individual distribution, encoded in data space. 
+ """ + scale = self._inv_linker(eta_scale) + return scale + + +class LikelihoodGLM(tf.keras.layers.Layer, ProcessModelGLM): + + """ + Contains the computation of the distribution specific log-likelihood function + """ + + def __init__(self, dtype): + super(LikelihoodGLM, self).__init__() + self.ll_dtype = dtype + + @abc.abstractmethod + def _ll(self, eta_loc, eta_scale, loc, scale, x, n_features): + """ + Does the actual likelihood calculation. Depends on the given noise model and needs to be implemented in the + inheriting layer. + + :param eta_loc: tf.Tensor + the mean values for each individual distribution, encoded in linker space. + :param eta_scale: tf.Tensor + the variance values for each individual distribution, encoded in linker space. + :param loc: tf.Tensor + the mean values for each individual distribution, encoded in data space. + :param scale: tf.Tensor + the variance values for each individual distribution, encoded in data space. + :param x: tf.Tensor + the input data + :param n_features + number of features. + + :return tf.Tensor + the log-likelihoods of each individual data point. + """ + + def call(self, inputs, **kwargs): + """ + Calls the distribution specific log-likelihood function. + + :param inputs: tuple + Must contain eta_loc, eta_scale, loc, scale, x, n_features in this order. + + :return tf.Tensor + the log-likelihoods of each individual data point. + """ + return self._ll(*inputs) diff --git a/batchglm/train/tf2/base_glm/layers_gradients.py b/batchglm/train/tf2/base_glm/layers_gradients.py new file mode 100644 index 00000000..01b7dfb7 --- /dev/null +++ b/batchglm/train/tf2/base_glm/layers_gradients.py @@ -0,0 +1,450 @@ +import abc +import tensorflow as tf + + +class Gradient(tf.keras.layers.Layer): + + """Superclass for Jacobians, Hessian, FIM""" + + def __init__(self, model_vars, compute_a, compute_b, dtype): + super(Gradient, self).__init__() + self.model_vars = model_vars + self.compute_a = compute_a + self.compute_b = compute_b + self.grad_dtype = dtype + + @abc.abstractmethod + def call(self, inputs, **kwargs): + pass + + @staticmethod + def calc_design_mat(design_mat, constraints): + if constraints is not None: + xh = tf.matmul(design_mat, constraints) + else: + xh = design_mat + return xh + + # Here, we use the einsum to efficiently perform the two outer products and the marginalisation. + @staticmethod + def create_specific_block(w, xh_loc, xh_scale): + return tf.einsum('ofc,od->fcd', tf.einsum('of,oc->ofc', w, xh_loc), xh_scale) + + +class FIMGLM(Gradient): + """ + Compute expected fisher information matrix (FIM) + for iteratively re-weighted least squares (IWLS or IRLS) parameter updates for GLMs. + """ + + def call(self, inputs, **kwargs): + return self._fim_analytic(*inputs) + + def _fim_analytic(self, x, design_loc, design_scale, loc, scale, concat=False) -> tf.Tensor: + """ + Compute the closed-form of the base_glm_all model fim + by evalutating its terms grouped by observations. + """ + + def _a_byobs(): + """ + Compute the mean model diagonal block of the + closed form fim of base_glm_all model by observation across features + for a batch of observations. + """ + w = self._weight_fim_aa(x=x, loc=loc, scale=scale) # [observations x features] + # The computation of the fim block requires two outer products between + # feature-wise constants and the coefficient wise design matrix entries, for each observation. + # The resulting tensor is observations x features x coefficients x coefficients which + # is too large too store in memory in most cases. 
However, the full 4D tensor is never + # actually needed but only its marginal across features, the final hessian block shape. + # Here, we use the einsum to efficiently perform the two outer products and the marginalisation. + xh = self.calc_design_mat(design_loc, self.model_vars.constraints_loc) + + fim_block = self.create_specific_block(w, xh, xh) + return fim_block + + def _b_byobs(): + """ + Compute the dispersion model diagonal block of the + closed form fim of base_glm_all model by observation across features. + """ + w = self._weight_fim_bb(x=x, loc=loc, scale=scale) # [observations=1 x features] + # The computation of the fim block requires two outer products between + # feature-wise constants and the coefficient wise design matrix entries, for each observation. + # The resulting tensor is observations x features x coefficients x coefficients which + # is too large too store in memory in most cases. However, the full 4D tensor is never + # actually needed but only its marginal across features, the final hessian block shape. + # Here, we use the Einstein summation to efficiently perform the two outer products and the marginalisation. + xh = self.calc_design_mat(design_scale, self.model_vars.constraints_scale) + + fim_block = self.create_specific_block(w, xh, xh) + return fim_block + + # The full fisher information matrix is block-diagonal with the cross-model + # blocks all zero. Accordingly, mean and dispersion model updates can be + # treated independently and the full fisher information matrix is never required. + # Here, the non-zero model-wise diagonal blocks are computed and returned + # as a dictionary. The according score function vectors are also returned as a dictionary. + + if self.compute_a and self.compute_b: + fim_a = _a_byobs() + fim_b = _b_byobs() + + elif self.compute_a and not self.compute_b: + fim_a = _a_byobs() + fim_b = tf.zeros(fim_a.get_shape(), self.grad_dtype) + elif not self.compute_a and self.compute_b: + fim_a = tf.zeros(fim_a.get_shape(), self.grad_dtype) + fim_b = _b_byobs() + else: + fim_a = tf.zeros_like(self.model_vars.a_var, dtype=self.grad_dtype) + fim_b = tf.zeros_like(self.model_vars.b_var, dtype=self.grad_dtype) + + if concat: + fim = tf.concat([fim_a, fim_b], axis=1) + return fim + else: + return fim_a, fim_b + + @abc.abstractmethod + def _weight_fim_aa( + self, + x, + loc, + scale + ): + """ + Compute for mean model IWLS update for a GLM. + + :param loc: tf.tensor observations x features + Value of mean model by observation and feature. + :param scale: tf.tensor observations x features + Value of dispersion model by observation and feature. + + :return tuple of tf.tensors + Constants with respect to coefficient index for + Fisher information matrix and score function computation. + """ + pass + + @abc.abstractmethod + def _weight_fim_bb( + self, + x, + loc, + scale + ): + """ + Compute for dispersion model IWLS update for a GLM. + + :param x: tf.tensor observations x features + Observation by observation and feature. + :param loc: tf.tensor observations x features + Value of mean model by observation and feature. + :param scale: tf.tensor observations x features + Value of dispersion model by observation and feature. + + :return tuple of tf.tensors + Constants with respect to coefficient index for + Fisher information matrix and score function computation. 
+ """ + pass + + +class JacobianGLM(Gradient): + + def call(self, inputs, **kwargs): + return self._jac_analytic(*inputs) + + def _jac_analytic(self, x, design_loc, design_scale, loc, scale, concat) -> tf.Tensor: + """ + Compute the closed-form of the base_glm_all model jacobian + by evalutating its terms grouped by observations. + + :param x: tf.tensor observations x features + Observation by observation and feature. + :param loc: tf.tensor observations x features + Value of mean model by observation and feature. + :param scale: tf.tensor observations x features + Value of dispersion model by observation and feature. + """ + + def _a_byobs(): + """ + Compute the mean model block of the jacobian. + + :return Jblock: tf.tensor features x coefficients + Block of jacobian. + """ + w = self._weights_jac_a(x=x, loc=loc, scale=scale) # [observations, features] + xh = self.calc_design_mat(design_loc, self.model_vars.constraints_loc) # [observations, coefficient] + + jblock = tf.matmul(tf.transpose(w), xh) # [features, coefficients] + return jblock + + def _b_byobs(): + """ + Compute the dispersion model block of the jacobian. + + :return Jblock: tf.tensor features x coefficients + Block of jacobian. + """ + w = self._weights_jac_b(x=x, loc=loc, scale=scale) # [observations, features] + xh = self.calc_design_mat(design_scale, self.model_vars.constraints_scale) # [observations, coefficient] + + jblock = tf.matmul(tf.transpose(w), xh) # [features, coefficients] + return jblock + + if self.compute_a and self.compute_b: + j_a = _a_byobs() + j_b = _b_byobs() + elif self.compute_a and not self.compute_b: + j_a = _a_byobs() + j_b = tf.zeros((j_a.get_shape()[0], self.model_vars.b_var.get_shape()[0]), dtype=self.grad_dtype) + elif not self.compute_a and self.compute_b: + j_b = _b_byobs() + j_a = tf.zeros((j_b.get_shape()[0], self.model_vars.b_var.get_shape()[0]), dtype=self.grad_dtype) + else: + j_a = tf.transpose(tf.zeros_like(self.model_vars.a_var, dtype=self.grad_dtype)) + j_b = tf.transpose(tf.zeros_like(self.model_vars.b_var, dtype=self.grad_dtype)) + + if concat: + j = tf.concat([j_a, j_b], axis=1) + return j + else: + return j_a, j_b + + @abc.abstractmethod + def _weights_jac_a( + self, + x, + loc, + scale + ): + """ + Compute the coefficient index invariant part of the + mean model gradient. + + :param x: tf.tensor observations x features + Observation by observation and feature. + :param loc: tf.tensor observations x features + Value of mean model by observation and feature. + :param scale: tf.tensor observations x features + Value of dispersion model by observation and feature. + + :return const: tf.tensor observations x features + Coefficient invariant terms of hessian of + given observations and features. + """ + pass + + @abc.abstractmethod + def _weights_jac_b( + self, + x, + loc, + scale + ): + """ + Compute the coefficient index invariant part of the + dispersion model gradient. + + :param x: tf.tensor observations x features + Observation by observation and feature. + :param loc: tf.tensor observations x features + Value of mean model by observation and feature. + :param scale: tf.tensor observations x features + Value of dispersion model by observation and feature. + + :return const: tf.tensor observations x features + Coefficient invariant terms of hessian of + given observations and features. + """ + pass + + +class HessianGLM(Gradient): + """ + Compute the closed-form of the base_glm_all model hessian + by evaluating its terms grouped by observations. 
+ + Has three sub-functions which built the specific blocks of the hessian + and one sub-function which concatenates the blocks into a full hessian. + """ + + def call(self, inputs, **kwargs): + return self._hessian_analytic(*inputs) + + def _hessian_analytic(self, x, design_loc, design_scale, loc, scale, concat) -> tf.Tensor: + """ + Compute the closed-form of the base_glm_all model hessian + by evaluating its terms grouped by observations. + + Has three sub-functions which built the specific blocks of the hessian + and one sub-function which concatenates the blocks into a full hessian. + """ + + def _aa_byobs_batched(): + """ + Compute the mean model diagonal block of the + closed form hessian of base_glm_all model by observation across features + for a batch of observations. + """ + w = self._weight_hessian_aa(x=x, loc=loc, scale=scale) # [observations x features] + # The computation of the hessian block requires two outer products between + # feature-wise constants and the coefficient wise design matrix entries, for each observation. + # The resulting tensor is observations x features x coefficients x coefficients which + # is too large too store in memory in most cases. However, the full 4D tensor is never + # actually needed but only its marginal across features, the final hessian block shape. + # Here, we use the einsum to efficiently perform the two outer products and the marginalisation. + xh = self.calc_design_mat(design_loc, self.model_vars.constraints_loc) + + hblock = self.create_specific_block(w, xh, xh) + return hblock + + def _bb_byobs_batched(): + """ + Compute the dispersion model diagonal block of the + closed form hessian of base_glm_all model by observation across features. + """ + w = self._weight_hessian_bb(x=x, loc=loc, scale=scale) # [observations x features] + # The computation of the hessian block requires two outer products between + # feature-wise constants and the coefficient wise design matrix entries, for each observation. + # The resulting tensor is observations x features x coefficients x coefficients which + # is too large too store in memory in most cases. However, the full 4D tensor is never + # actually needed but only its marginal across features, the final hessian block shape. + # Here, we use the Einstein summation to efficiently perform the two outer products and the marginalisation. + xh = self.calc_design_mat(design_scale, self.model_vars.constraints_scale) + + hblock = self.create_specific_block(w, xh, xh) + return hblock + + def _ab_byobs_batched(): + """ + Compute the mean-dispersion model off-diagonal block of the + closed form hessian of base_glm_all model by observastion across features. + + Note that there are two blocks of the same size which can + be compute from each other with a transpose operation as + the hessian is symmetric. + """ + w = self._weight_hessian_ab(x=x, loc=loc, scale=scale) # [observations x features] + # The computation of the hessian block requires two outer products between + # feature-wise constants and the coefficient wise design matrix entries, for each observation. + # The resulting tensor is observations x features x coefficients x coefficients which + # is too large too store in memory in most cases. However, the full 4D tensor is never + # actually needed but only its marginal across features, the final hessian block shape. + # Here, we use the Einstein summation to efficiently perform the two outer products and the marginalisation. 
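+            # Editor's note (illustrative, not part of the original patch): in
+            # create_specific_block above, the einsum indices read o = observations,
+            # f = features, c and d = coefficients of the two design matrices:
+            #     tf.einsum('of,oc->ofc', w, xh_loc)        per-observation outer product
+            #     tf.einsum('ofc,od->fcd', ..., xh_scale)   second outer product, summed over o
+            # so the result is a [features x coefficients x coefficients] block and the full
+            # [observations x features x coefficients x coefficients] tensor is never materialised.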
+ xhloc = self.calc_design_mat(design_loc, self.model_vars.constraints_loc) + xhscale = self.calc_design_mat(design_scale, self.model_vars.constraints_scale) + + hblock = self.create_specific_block(w, xhloc, xhscale) + return hblock + + if self.compute_a and self.compute_b: + h_aa = _aa_byobs_batched() + h_bb = _bb_byobs_batched() + h_ab = _ab_byobs_batched() + h_ba = tf.transpose(h_ab, perm=[0, 2, 1]) + elif self.compute_a and not self.compute_b: + h_aa = _aa_byobs_batched() + h_bb = tf.zeros_like(h_aa, dtype=self.grad_dtype) + h_ab = tf.zeros_like(h_aa, dtype=self.grad_dtype) + h_ba = tf.zeros_like(h_aa, dtype=self.grad_dtype) + elif not self.compute_a and self.compute_b: + h_bb = _bb_byobs_batched() + h_aa = tf.zeros_like(h_bb, dtype=self.grad_dtype) + h_ab = tf.zeros_like(h_bb, dtype=self.grad_dtype) + h_ba = tf.zeros_like(h_bb, dtype=self.grad_dtype) + else: + h_aa = tf.zeros((), dtype=self.grad_dtype) + h_bb = tf.zeros((), dtype=self.grad_dtype) + h_ab = tf.zeros((), dtype=self.grad_dtype) + h_ba = tf.zeros((), dtype=self.grad_dtype) + + if concat: + h = tf.concat( + [tf.concat([h_aa, h_ab], axis=2), + tf.concat([h_ba, h_bb], axis=2)], + axis=1 + ) + return h + else: + return h_aa, h_ab, h_ba, h_bb + + @abc.abstractmethod + def _weight_hessian_aa( + self, + x, + loc, + scale + ): + """ + Compute the coefficient index invariant part of the + mean model block of the hessian. + + :param x: tf.tensor observations x features + Observation by observation and feature. + :param loc: tf.tensor observations x features + Value of mean model by observation and feature. + :param scale: tf.tensor observations x features + Value of dispersion model by observation and feature. + + :return const: tf.tensor observations x features + Coefficient invariant terms of hessian of + given observations and features. + """ + pass + + @abc.abstractmethod + def _weight_hessian_bb( + self, + x, + loc, + scale + ): + """ + Compute the coefficient index invariant part of the + dispersion model block of the hessian. + + :param x: tf.tensor observations x features + Observation by observation and feature. + :param loc: tf.tensor observations x features + Value of mean model by observation and feature. + :param scale: tf.tensor observations x features + Value of dispersion model by observation and feature. + + :return const: tf.tensor observations x features + Coefficient invariant terms of hessian of + given observations and features. + """ + pass + + @abc.abstractmethod + def _weight_hessian_ab( + self, + x, + loc, + scale + ): + """ + Compute the coefficient index invariant part of the + mean-dispersion model block of the hessian. + + Note that there are two blocks of the same size which can + be compute from each other with a transpose operation as + the hessian is symmetric. + + :param x: tf.tensor observations x features + Observation by observation and feature. + :param loc: tf.tensor observations x features + Value of mean model by observation and feature. + :param scale: tf.tensor observations x features + Value of dispersion model by observation and feature. + + :return const: tf.tensor observations x features + Coefficient invariant terms of hessian of + given observations and features. 
+ """ + pass diff --git a/batchglm/train/tf2/base_glm/model.py b/batchglm/train/tf2/base_glm/model.py new file mode 100644 index 00000000..cbf2d6d1 --- /dev/null +++ b/batchglm/train/tf2/base_glm/model.py @@ -0,0 +1,226 @@ +import logging +import tensorflow as tf +import numpy as np +from .external import ModelBase, LossBase +from .processModel import ProcessModelGLM +from .training_strategies import TrainingStrategies + +logger = logging.getLogger("batchglm") + + +class GLM(ModelBase, ProcessModelGLM): + + """ + base GLM class containg the model call. + """ + + TS: {} = TrainingStrategies.DEFAULT.value + compute_a: bool = True + compute_b: bool = True + + def __init__( + self, + model_vars, + unpack_params: tf.keras.layers.Layer, + linear_loc: tf.keras.layers.Layer, + linear_scale: tf.keras.layers.Layer, + linker_loc: tf.keras.layers.Layer, + linker_scale: tf.keras.layers.Layer, + likelihood: tf.keras.layers.Layer, + jacobian: tf.keras.layers.Layer, + hessian: tf.keras.layers.Layer, + fim: tf.keras.layers.Layer, + use_gradient_tape: bool = False + ): + super(GLM, self).__init__() + self.model_vars = model_vars + self.params = tf.Variable(tf.concat( + [ + model_vars.init_a_clipped, + model_vars.init_b_clipped, + ], + axis=0 + ), name="params", trainable=True) + + self.unpack_params = unpack_params + self.linear_loc = linear_loc + self.linear_scale = linear_scale + self.linker_loc = linker_loc + self.linker_scale = linker_scale + self.likelihood = likelihood + self.jacobian = jacobian + self.hessian = hessian + self.fim = fim + self.use_gradient_tape = use_gradient_tape + self.params_copy = None + self.batch_features = False + + def _call_parameters(self, inputs, keep_previous_params_copy=False): + if not keep_previous_params_copy: + if self.batch_features: + self.params_copy = tf.Variable(tf.boolean_mask(tensor=self.params, + mask=tf.logical_not(self.model_vars.converged), + axis=1), trainable=True) + else: + self.params_copy = self.params + design_loc, design_scale, size_factors = inputs + a_var, b_var = self.unpack_params([self.params_copy, self.model_vars.a_var.get_shape()[0]]) + eta_loc = self.linear_loc([a_var, design_loc, self.model_vars.constraints_loc, size_factors]) + eta_scale = self.linear_scale([b_var, design_scale, self.model_vars.constraints_scale]) + loc = self.linker_loc(eta_loc) + scale = self.linker_scale(eta_scale) + return eta_loc, eta_scale, loc, scale, a_var, b_var + + def calc_ll(self, inputs, keep_previous_params_copy=False): + parameters = self._call_parameters(inputs[1:], keep_previous_params_copy) + log_probs = self.likelihood([*parameters[:-2], inputs[0], np.sum(self.model_vars.updated)]) + return (log_probs, *parameters[2:]) + + def _calc_jacobians(self, inputs, concat, transpose=True): + """ + calculates jacobian. + + :param inputs: TODO + :param concat: boolean + if true, concatenates the loc and scale block. + :param transpose: bool + transpose the gradient if true. + autograd returns gradients with respect to the shape of self.params. + But analytic differentiation returns it the other way, which is + often needed for downstream operations (e.g. 
hessian) + Therefore, if self.use_gradient_tape, it will transpose if transpose == False + """ + + with tf.GradientTape(persistent=True) as g: + log_probs, loc, scale, a_var, b_var = self.calc_ll(inputs) + + if self.use_gradient_tape: + + if self.compute_a: + if self.compute_b: + if concat: + jacobians = g.gradient(log_probs, self.params_copy) + if not transpose: + jacobians = tf.transpose(jacobians) + else: + jac_a = g.gradient(log_probs, a_var) + jac_b = g.gradient(log_probs, b_var) + if not transpose: + jac_a = tf.transpose(jac_a) + jac_b = tf.transpose(jac_b) + else: + jac_a = g.gradient(log_probs, a_var) + jac_b = tf.zeros((jac_a.get_shape()[0], b_var.get_shape()[1]), b_var.dtype) + if concat: + jacobians = tf.concat([jac_a, jac_b], axis=0) + if not transpose: + jacobians = tf.transpose(jacobians) + else: + jac_b = g.gradient(log_probs, b_var) + jac_a = tf.zeros((jac_b.get_shape()[0], a_var.get_shape()[0]), a_var.dtype) + if concat: + jacobians = tf.concat([jac_a, jac_b], axis=0) + if not transpose: + jacobians = tf.transpose(jacobians) + + else: + + if concat: + jacobians = self.jacobian([*inputs[0:3], loc, scale, True]) + if transpose: + jacobians = tf.transpose(jacobians) + else: + jac_a, jac_b = self.jacobian([*inputs[0:3], loc, scale, False]) + + del g + if concat: + return loc, scale, log_probs, tf.negative(jacobians) + return loc, scale, log_probs, tf.negative(jac_a), tf.negative(jac_b) + + def call(self, inputs, training=False, mask=None): + # X_data, design_loc, design_scale, size_factors = inputs + + # This is for first order optimizations, which get the full jacobian + + concat = self.TS["concat_grads"] + + if self.TS["jacobian"] is True: + _, _, log_probs, jacobians = self._calc_jacobians(inputs, concat=concat) + return log_probs, jacobians + + # This is for SecondOrder NR/NR_TR + if self.TS["hessian"] is True: + + # with tf.GradientTape(persistent=True) as g2: + if concat: + loc, scale, log_probs, jacobians = self._calc_jacobians(inputs, concat=True, transpose=False) + else: + loc, scale, log_probs, jac_a, jac_b = self._calc_jacobians(inputs, concat=False, transpose=False) + # results_arr = [jacobians[:, i] for i in tf.range(self.params_copy.get_shape()[0])] + + ''' + autograd not yet working. TODO: Search error in the following code: + + if self.use_gradient_tape: + + i = tf.constant(0, tf.int32) + h_tensor_array = tf.TensorArray( # hessian slices [:,:,j] + dtype=self.params_copy.dtype, + size=self.params_copy.get_shape()[0], + clear_after_read=False + ) + while i < self.params_copy.get_shape()[0]: + grad = g2.gradient(results_arr[i], self.params_copy) + h_tensor_array.write(index=i, value=grad) + i += 1 + + # h_tensor_array is a TensorArray, reshape this into a tensor so that it can be used + # in down-stream computation graphs. + + hessians = tf.transpose(tf.reshape( + h_tensor_array.stack(), + tf.stack((self.params_copy.get_shape()[0], + self.params_copy.get_shape()[0], + self.params_copy.get_shape()[1])) + ), perm=[2, 1, 0]) + hessians = tf.negative(hessians) + ''' + # else: + if concat: + hessians = tf.negative(self.hessian([*inputs[0:3], loc, scale, True])) + return log_probs, jacobians, hessians + else: + hes_aa, hes_ab, hes_ba, hes_bb = self.hessian([*inputs[0:3], loc, scale, False]) + return log_probs, jac_a, jac_b, tf.negative(hes_aa), tf.negative(hes_ab), tf.negative(hes_ba), tf.negative(hes_bb) + # del g2 # need to delete this GradientTape because persistent is True. 
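+        # Editor's note (illustrative, not part of the original patch): call() always returns
+        # log_probs as its first element, followed by gradient information whose number of
+        # entries depends on the training-strategy flags (jacobian only, jacobian + hessian,
+        # or jacobian + FIM, each either concatenated or split into loc/scale blocks). The
+        # estimator consumes these tuples positionally (results[0], results[1], ...).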
+ + + # This is for SecondOrder IRLS/IRLS_GD/IRLS_TR/IRLS_GD_TR + if self.TS["fim"] is True: + + + + if concat: + loc, scale, log_probs, jacobians = self._calc_jacobians(inputs, concat=True, transpose=False) + fims = self.fim([*inputs[0:3], loc, scale, True]) + + return log_probs, tf.negative(jacobians), fims + else: + loc, scale, log_probs, jac_a, jac_b = self._calc_jacobians(inputs, concat=False, transpose=False) + fim_a, fim_b = self.fim([*inputs[0:3], loc, scale, False]) + + return log_probs, jac_a, jac_b, fim_a, fim_b + + raise ValueError("No gradient calculation specified.") + + +class LossGLM(LossBase): + + def norm_log_likelihood(self, log_probs): + return tf.reduce_mean(log_probs, axis=0, name="log_likelihood") + + def norm_neg_log_likelihood(self, log_probs): + return - self.norm_log_likelihood(log_probs) + + def call(self, y_true, log_probs): + return tf.reduce_sum(self.norm_neg_log_likelihood(log_probs)) diff --git a/batchglm/train/tf2/base_glm/optim.py b/batchglm/train/tf2/base_glm/optim.py new file mode 100644 index 00000000..04bd2f16 --- /dev/null +++ b/batchglm/train/tf2/base_glm/optim.py @@ -0,0 +1,535 @@ +from .external import pkg_constants +import tensorflow as tf +from .external import OptimizerBase +import abc +import numpy as np + + +class SecondOrderOptim(OptimizerBase, metaclass=abc.ABCMeta): + + """ + Superclass for NR and IRLS + """ + + def _norm_log_likelihood(self, log_probs): + return tf.reduce_mean(log_probs, axis=0, name="log_likelihood") + + def _norm_neg_log_likelihood(self, log_probs): + return - self._norm_log_likelihood(log_probs) + + def _resource_apply_dense(self, grad, handle, apply_state=None): + + update_op = handle.assign_add(grad, read_value=False) + + return update_op + + def _resource_apply_sparse(self, grad, handle, apply_state=None): + + raise NotImplementedError('Applying SparseTensor currently not possible.') + + def get_config(self): + + config = {"name": "SOO"} + return config + + def _create_slots(self, var_list): + + self.add_slot(var_list[0], 'mu_r') + + def _trust_region_ops( + self, + x_batch, + likelihood, + proposed_vector, + proposed_gain, + compute_a, + compute_b, + batch_features, + ll_prev + ): + # Load hyper-parameters: + assert pkg_constants.TRUST_REGION_ETA0 < pkg_constants.TRUST_REGION_ETA1, \ + "eta0 must be smaller than eta1" + assert pkg_constants.TRUST_REGION_ETA1 <= pkg_constants.TRUST_REGION_ETA2, \ + "eta1 must be smaller than or equal to eta2" + assert pkg_constants.TRUST_REGION_T1 <= 1, "t1 must be smaller than 1" + assert pkg_constants.TRUST_REGION_T2 >= 1, "t1 must be larger than 1" + # Set trust region hyper-parameters + eta0 = tf.constant(pkg_constants.TRUST_REGION_ETA0, dtype=self._dtype) + eta1 = tf.constant(pkg_constants.TRUST_REGION_ETA1, dtype=self._dtype) + eta2 = tf.constant(pkg_constants.TRUST_REGION_ETA2, dtype=self._dtype) + if self.gd and compute_b: + t1 = tf.constant(pkg_constants.TRUST_REGIONT_T1_IRLS_GD_TR_SCALE, dtype=self._dtype) + else: + t1 = tf.constant(pkg_constants.TRUST_REGION_T1, dtype=self._dtype) + t2 = tf.constant(pkg_constants.TRUST_REGION_T2, dtype=self._dtype) + upper_bound = tf.constant(pkg_constants.TRUST_REGION_UPPER_BOUND, dtype=self._dtype) + + # Phase I: Perform a trial update. + # Propose parameter update: + + self.model.params_copy.assign_sub(proposed_vector) + # Phase II: Evaluate success of trial update and complete update cycle. 
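+        # Editor's note (illustrative, not part of the original patch): this is a standard
+        # trust-region acceptance test. delta_f_actual is the realised decrease of the
+        # normalised negative log-likelihood and delta_f_ratio compares it with the predicted
+        # gain of the proposed step. Per feature, the step is accepted if delta_f_actual > eta0
+        # (and the feature has not converged); the radius is shrunk by t1 if the step is
+        # rejected or the ratio is <= eta1, grown by t2 (capped at TRUST_REGION_UPPER_BOUND)
+        # if the step is accepted and the ratio is > eta2, and kept otherwise.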
+ # Include parameter updates only if update improves cost function: + new_likelihood = self.model.calc_ll([*x_batch], keep_previous_params_copy=True)[0] + delta_f_actual = self._norm_neg_log_likelihood(likelihood) - self._norm_neg_log_likelihood(new_likelihood) + + if batch_features: + + indices = tf.where(tf.logical_not(self.model.model_vars.converged)) + updated_lls = tf.scatter_nd(indices, delta_f_actual, shape=ll_prev.shape) + delta_f_actual = np.where(self.model.model_vars.converged, ll_prev, updated_lls.numpy()) + update_var = tf.transpose(tf.scatter_nd( + indices, + tf.transpose(proposed_vector), + shape=(self.model.model_vars.n_features, proposed_vector.get_shape()[0]) + )) + + gain_var = tf.transpose(tf.scatter_nd( + indices, + proposed_gain, + shape=([self.model.model_vars.n_features]))) + else: + update_var = proposed_vector + gain_var = proposed_gain + delta_f_ratio = tf.divide(delta_f_actual, gain_var) + + # Compute parameter updates.g + update_theta = tf.logical_and(delta_f_actual > eta0, tf.logical_not(self.model.model_vars.converged)) + update_theta_numeric = tf.expand_dims(tf.cast(update_theta, self._dtype), axis=0) + keep_theta_numeric = tf.ones_like(update_theta_numeric) - update_theta_numeric + if batch_features: + params = tf.transpose(tf.scatter_nd( + indices, + tf.transpose(self.model.params_copy), + shape=(self.model.model_vars.n_features, self.model.params.get_shape()[0]) + )) + + theta_new_tr = tf.add( + tf.multiply(self.model.params, keep_theta_numeric), + tf.multiply(params, update_theta_numeric) + ) + + + #self.model.params.assign_(tf.multiply(params, update_theta_numeric)) + + else: + params = self.model.params_copy + theta_new_tr = tf.add( + tf.multiply(params + update_var, keep_theta_numeric), # old values + tf.multiply(params, update_theta_numeric) # new values + ) + self.model.params.assign(theta_new_tr) + self.model.model_vars.updated = update_theta.numpy() + + # Update trusted region accordingly: + decrease_radius = tf.logical_or( + delta_f_actual <= eta0, + tf.logical_and(delta_f_ratio <= eta1, tf.logical_not(self.model.model_vars.converged)) + ) + increase_radius = tf.logical_and( + delta_f_actual > eta0, + tf.logical_and(delta_f_ratio > eta2, tf.logical_not(self.model.model_vars.converged)) + ) + keep_radius = tf.logical_and(tf.logical_not(decrease_radius), + tf.logical_not(increase_radius)) + radius_update = tf.add_n([ + tf.multiply(t1, tf.cast(decrease_radius, self._dtype)), + tf.multiply(t2, tf.cast(increase_radius, self._dtype)), + tf.multiply(tf.ones_like(t1), tf.cast(keep_radius, self._dtype)) + ]) + + if self.gd and compute_b and not compute_a: + tr_radius = self.tr_radius_b + else: + tr_radius = self.tr_radius + + radius_new = tf.minimum(tf.multiply(tr_radius, radius_update), upper_bound) + tr_radius.assign(radius_new) + + def __init__(self, dtype: tf.dtypes.DType, trusted_region_mode: bool, model: tf.keras.Model, name: str): + + self.model = model + self.gd = name in ['IRLS_GD', 'IRLS_GD_TR'] + + super(SecondOrderOptim, self).__init__(name) + + self._dtype = dtype + self.trusted_region_mode = trusted_region_mode + if trusted_region_mode: + + self.tr_radius = tf.Variable( + np.zeros(shape=[self.model.model_vars.n_features]) + pkg_constants.TRUST_REGION_RADIUS_INIT, + dtype=self._dtype, trainable=False + ) + if self.gd: + self.tr_radius_b = tf.Variable( + np.zeros(shape=[self.model.model_vars.n_features]) + pkg_constants.TRUST_REGION_RADIUS_INIT, + dtype=self._dtype, trainable=False + ) + + self.tr_ll_prev = 
tf.Variable(np.zeros(shape=[self.model.model_vars.n_features]), trainable=False) + self.tr_pred_gain = tf.Variable(np.zeros(shape=[self.model.model_vars.n_features]), trainable=False) + + else: + + self.tr_radius = tf.Variable(np.array([np.inf]), dtype=self._dtype, trainable=False) + + @abc.abstractmethod + def perform_parameter_update(self, inputs): + pass + + def _newton_type_update(self, lhs, rhs, psd): + + new_rhs = tf.expand_dims(rhs, axis=-1) + res = tf.linalg.lstsq(lhs, new_rhs, fast=False) + delta_t = tf.squeeze(res, axis=-1) + update_tensor = tf.transpose(delta_t) + return update_tensor + + def _pad_updates( + self, + update_raw, + compute_a, + compute_b + ): + # Pad update vectors to receive update tensors that match + # the shape of model_vars.params. + if compute_a: + if compute_b: + netwon_type_update = update_raw + else: + netwon_type_update = tf.concat([ + update_raw, + tf.zeros(shape=(self.model.model_vars.b_var.get_shape()[0], update_raw.get_shape()[1]), + dtype=self._dtype) + ], axis=0) + + elif compute_b: + netwon_type_update = tf.concat([ + tf.zeros(shape=(self.model.model_vars.a_var.get_shape()[0], update_raw.get_shape()[1]), + dtype=self._dtype), + update_raw + ], axis=0) + + else: + raise ValueError("No training necessary") + + return netwon_type_update + + def _trust_region_update( + self, + update_raw, + radius_container, + n_obs=None + ): + update_magnitude_sq = tf.reduce_sum(tf.square(update_raw), axis=0) + update_magnitude = tf.where( + condition=update_magnitude_sq > 0, + x=tf.sqrt(update_magnitude_sq), + y=tf.zeros_like(update_magnitude_sq) + ) + update_magnitude_inv = tf.where( + condition=update_magnitude > 0, + x=tf.divide( + tf.ones_like(update_magnitude), + update_magnitude + ), + y=tf.zeros_like(update_magnitude) + ) + update_norm = tf.multiply(update_raw, update_magnitude_inv) + # the following switch is for irls_gd_tr (linear instead of newton) + if n_obs is not None: + update_magnitude /= n_obs + update_scale = tf.minimum( + radius_container, + update_magnitude + ) + proposed_vector = tf.multiply( + update_norm, + update_scale + ) + + return proposed_vector + + def _trust_region_newton_cost_gain( + self, + proposed_vector, + neg_jac, + hessian_fim, + n_obs + ): + pred_cost_gain = tf.add( + tf.einsum( + 'ni,in->n', + neg_jac, + proposed_vector + ) / n_obs, + 0.5 * tf.einsum( + 'nix,xin->n', + tf.einsum('inx,nij->njx', + tf.expand_dims(proposed_vector, axis=-1), + hessian_fim), + tf.expand_dims(proposed_vector, axis=0) + ) / tf.square(n_obs) + ) + return pred_cost_gain + + +class NR(SecondOrderOptim): + + def _get_updates(self, lhs, rhs, psd, compute_a, compute_b): + + update_raw = self._newton_type_update(lhs=lhs, rhs=rhs, psd=psd) + update = self._pad_updates(update_raw, compute_a, compute_b) + + return update_raw, update + + def perform_parameter_update(self, inputs, compute_a=True, compute_b=True, batch_features=False, prev_ll=None): + + x_batch, log_probs, jacobians, hessians, psd, n_obs = inputs + if not (compute_a or compute_b): + raise ValueError( + "Nothing can be trained. 
Please make sure at least one of train_mu and train_r is set to True.") + + update_raw, update = self._get_updates(hessians, jacobians, psd, compute_a, compute_b) + + if self.trusted_region_mode: + + n_obs = tf.cast(n_obs, dtype=self._dtype) + if batch_features: + radius_container = tf.boolean_mask( + tensor=self.tr_radius, + mask=tf.logical_not(self.model.model_vars.converged)) + else: + radius_container = self.tr_radius + tr_proposed_vector = self._trust_region_update( + update_raw=update_raw, + radius_container=radius_container + ) + tr_pred_cost_gain = self._trust_region_newton_cost_gain( + proposed_vector=tr_proposed_vector, + neg_jac=jacobians, + hessian_fim=hessians, + n_obs=n_obs + ) + + tr_proposed_vector_pad = self._pad_updates( + update_raw=tr_proposed_vector, + compute_a=compute_a, + compute_b=compute_b + ) + + self._trust_region_ops( + x_batch=x_batch, + likelihood=log_probs, + proposed_vector=tr_proposed_vector_pad, + proposed_gain=tr_pred_cost_gain, + compute_a=compute_a, + compute_b=compute_b, + batch_features=batch_features, + ll_prev=prev_ll + ) + + else: + if batch_features: + indices = tf.where(tf.logical_not(self.model.model_vars.converged)) + update_var = tf.transpose( + tf.scatter_nd( + indices, + tf.transpose(update), + shape=(self.model.model_vars.n_features, update.get_shape()[0]) + ) + ) + else: + update_var = update + self.model.params.assign_sub(update_var) + + +class IRLS(SecondOrderOptim): + + def _calc_proposed_vector_and_pred_cost_gain( + self, + update_x, + radius_container, + n_obs, + gd, + neg_jac_x, + fim_x=None + ): + """ + Calculates the proposed vector and predicted cost gain for either mean or scale part. + :param update_x: tf.tensor coefficients x features ? TODO + + :param radius_container: tf.tensor ? x ? TODO + + :param n_obs: ? TODO + Number of observations in current batch. + :param gd: boolean + If True, the proposed vector and predicted cost gain are + calculated by linear functions related to IRLS_GD(_TR) optimizer. + If False, use newton functions for IRLS_TR optimizer instead. + :param neg_jac_x: tf.Tensor coefficients x features ? TODO + Upper (mu part) or lower (r part) of negative jacobian matrix + :param fim_x + Upper (mu part) or lower (r part) of Fisher Inverse Matrix. + Defaults to None, is only needed if gd is False + :return proposed_vector_x, pred_cost_gain_x + Returns proposed vector and predicted cost gain after + trusted region update for either mu or r part, depending on x + """ + + proposed_vector_x = self._trust_region_update( + update_raw=update_x, + radius_container=radius_container, + n_obs=n_obs if gd else None + ) + # here, functions have different number of arguments, thus + # must be written out + if gd: + pred_cost_gain_x = self._trust_region_linear_cost_gain( + proposed_vector=proposed_vector_x, + neg_jac=neg_jac_x + ) + else: + pred_cost_gain_x = self._trust_region_newton_cost_gain( + proposed_vector=proposed_vector_x, + neg_jac=neg_jac_x, + hessian_fim=fim_x, + n_obs=n_obs + ) + + return proposed_vector_x, pred_cost_gain_x + + def _trust_region_linear_cost_gain( + self, + proposed_vector, + neg_jac + ): + pred_cost_gain = tf.reduce_sum(tf.multiply( + proposed_vector, + tf.transpose(neg_jac) + ), axis=0) + return pred_cost_gain + + def perform_parameter_update(self, inputs, compute_a=True, compute_b=True, batch_features=False, prev_ll=None): + + x_batch, log_probs, jac_a, jac_b, fim_a, fim_b, psd, n_obs = inputs + if not (compute_a or compute_b): + raise ValueError( + "Nothing can be trained. 
Please make sure at least one of train_mu and train_r is set to True.") + # Compute a and b model updates separately. + if compute_a: + # The FIM of the mean model is guaranteed to be + # positive semi-definite and can therefore be inverted + # with the Cholesky decomposition. This information is + # passed here with psd=True. + update_a = self._newton_type_update( + lhs=fim_a, + rhs=jac_a, + psd=True + ) + if compute_b: + + if self.gd: + update_b = tf.transpose(jac_b) + + else: + update_b = self._newton_type_update( + lhs=fim_b, + rhs=jac_b, + psd=False + ) + + if not self.trusted_region_mode: + if compute_a: + if compute_b: + update_raw = tf.concat([update_a, update_b], axis=0) + else: + update_raw = update_a + else: + update_raw = update_b + + update = self._pad_updates( + update_raw=update_raw, + compute_a=compute_a, + compute_b=compute_b + ) + + if batch_features: + indices = tf.where(tf.logical_not(self.model.model_vars.converged)) + update_var = tf.transpose( + tf.scatter_nd( + indices, + tf.transpose(update), + shape=(self.model.model_vars.n_features, update.get_shape()[0]) + ) + ) + else: + update_var = update + self.model.params.assign_sub(update_var) + + else: + + n_obs = tf.cast(n_obs, dtype=self._dtype) + # put together update_raw based on proposed vector and cost gain depending on train_r and train_mu + if compute_b: + if compute_a: + if batch_features: + radius_container = tf.boolean_mask( + tensor=self.tr_radius, + mask=tf.logical_not(self.model.model_vars.converged)) + else: + radius_container = self.tr_radius + tr_proposed_vector_b, tr_pred_cost_gain_b = self._calc_proposed_vector_and_pred_cost_gain( + update_b, radius_container, n_obs, self.gd, jac_b, fim_b) + + tr_proposed_vector_a, tr_pred_cost_gain_a = self._calc_proposed_vector_and_pred_cost_gain( + update_a, radius_container, n_obs, False, jac_a, fim_a) + + tr_update_raw = tf.concat([tr_proposed_vector_a, tr_proposed_vector_b], axis=0) + tr_pred_cost_gain = tf.add(tr_pred_cost_gain_a, tr_pred_cost_gain_b) + + else: + radius_container = self.tr_radius_b if self.gd else self.tr_radius + if batch_features: + radius_container = tf.boolean_mask( + tensor=radius_container, + mask=tf.logical_not(self.model.model_vars.converged)) + + tr_proposed_vector_b, tr_pred_cost_gain_b = self._calc_proposed_vector_and_pred_cost_gain( + update_b, radius_container, n_obs, self.gd, jac_b, fim_b) + + # directly apply output of calc_proposed_vector_and_pred_cost_gain to tr_update_raw + # and tr_pred_cost_gain + tr_update_raw = tr_proposed_vector_b + tr_pred_cost_gain = tr_pred_cost_gain_b + else: + if batch_features: + radius_container = tf.boolean_mask( + tensor=self.tr_radius, + mask=tf.logical_not(self.model.model_vars.converged)) + else: + radius_container = self.tr_radius + # here train_r is False AND train_mu is true, so the output of the function can directly be applied to + # tr_update_raw and tr_pred_cost_gain, similar to train_r = True and train_mu = False + tr_update_raw, tr_pred_cost_gain = self._calc_proposed_vector_and_pred_cost_gain( + update_a, radius_container, n_obs, False, jac_a, fim_a) + + # perform update + tr_update = self._pad_updates( + update_raw=tr_update_raw, + compute_a=compute_a, + compute_b=compute_b + ) + + self._trust_region_ops( + x_batch, + log_probs, + tr_update, + tr_pred_cost_gain, + compute_a, + compute_b, + batch_features, + prev_ll + ) diff --git a/batchglm/train/tf2/base_glm/processModel.py b/batchglm/train/tf2/base_glm/processModel.py new file mode 100644 index 00000000..4b6aedf7 --- /dev/null +++ 
b/batchglm/train/tf2/base_glm/processModel.py @@ -0,0 +1,9 @@ +from .external import ProcessModelBase +import abc + + +class ProcessModelGLM(ProcessModelBase): + + @abc.abstractmethod + def param_bounds(self, dtype): + pass diff --git a/batchglm/train/tf2/base_glm/training_strategies.py b/batchglm/train/tf2/base_glm/training_strategies.py new file mode 100644 index 00000000..63f295d3 --- /dev/null +++ b/batchglm/train/tf2/base_glm/training_strategies.py @@ -0,0 +1,111 @@ +from enum import Enum + + +class TrainingStrategies(Enum): + + AUTO = None + + DEFAULT = \ + { + "optim_algo": "default_adam", + "jacobian": True, + "hessian": False, + "fim": False, + "concat_grads": True + } + + GD = \ + { + "optim_algo": "gd", + "jacobian": True, + "hessian": False, + "fim": False, + "concat_grads": True + } + + ADAM = \ + { + "optim_algo": "adam", + "jacobian": True, + "hessian": False, + "fim": False, + "concat_grads": True + } + + ADAGRAD = \ + { + "optim_algo": "adagrad", + "jacobian": True, + "hessian": False, + "fim": False, + "concat_grads": True + } + + RMSPROP = \ + { + "optim_algo": "rmsprop", + "jacobian": True, + "hessian": False, + "fim": False, + "concat_grads": True + } + + IRLS = \ + { + "optim_algo": "irls", + "jacobian": False, + "hessian": False, + "fim": True, + "concat_grads": False, + "calc_separated": True + } + + IRLS_TR = \ + { + "optim_algo": "irls_tr", + "jacobian": False, + "hessian": False, + "fim": True, + "concat_grads": False, + "calc_separated": True + } + + IRLS_GD = \ + { + "optim_algo": "irls_gd", + "jacobian": False, + "hessian": False, + "fim": True, + "concat_grads": False, + "calc_separated": True + } + + IRLS_GD_TR = \ + { + "optim_algo": "irls_gd_tr", + "jacobian": False, + "hessian": False, + "fim": True, + "concat_grads": False, + "calc_separated": True + } + + NR = \ + { + "optim_algo": "nr", + "jacobian": False, + "hessian": True, + "fim": False, + "concat_grads": True, + "calc_separated": False + } + + NR_TR = \ + { + "optim_algo": "nr_tr", + "jacobian": False, + "hessian": True, + "fim": False, + "concat_grads": True, + "calc_separated": False + } diff --git a/batchglm/train/tf2/base_glm/vars.py b/batchglm/train/tf2/base_glm/vars.py new file mode 100644 index 00000000..4b0debca --- /dev/null +++ b/batchglm/train/tf2/base_glm/vars.py @@ -0,0 +1,86 @@ +import numpy as np +import tensorflow as tf +import abc + +from .model import ProcessModelGLM + + +class ModelVarsGLM(ProcessModelGLM): + """ Build tf.Variables to be optimzed and their constraints. + + a_var and b_var slices of the tf.Variable params which contains + all parameters to be optimized during model estimation. + Params is defined across both location and scale model so that + the hessian can be computed for the entire model. + a and b are the clipped parameter values which also contain + constraints and constrained dependent coefficients which are not + directly optimized. + """ + + constraints_loc: tf.Tensor + constraints_scale: tf.Tensor + params: tf.Variable + a_var: tf.Tensor + b_var: tf.Tensor + updated: np.ndarray + converged: np.ndarray + dtype: str + n_features: int + + def __init__( + self, + init_a: np.ndarray, + init_b: np.ndarray, + constraints_loc: np.ndarray, + constraints_scale: np.ndarray, + dtype: str + ): + """ + + :param init_a: nd.array (mean model size x features) + Initialisation for all parameters of mean model. + :param init_b: nd.array (dispersion model size x features) + Initialisation for all parameters of dispersion model. + :param dtype: Precision used in tensorflow. 
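+
+        Resulting parameter layout (illustrative; see the assignments below)::
+
+            params : tf.Variable of shape (init_a.shape[0] + init_b.shape[0], n_features)
+            a_var  = params[0:init_a.shape[0]]   # location model coefficients
+            b_var  = params[init_a.shape[0]:]    # scale model coefficients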
+ """ + self.constraints_loc = tf.convert_to_tensor(constraints_loc, dtype) + self.constraints_scale = tf.convert_to_tensor(constraints_scale, dtype) + + self.init_a = tf.convert_to_tensor(init_a, dtype=dtype) + self.init_b = tf.convert_to_tensor(init_b, dtype=dtype) + + self.init_a_clipped = self.tf_clip_param(self.init_a, "a_var") + self.init_b_clipped = self.tf_clip_param(self.init_b, "b_var") + + # Param is the only tf.Variable in the graph. + # a_var and b_var have to be slices of params. + self.params = tf.Variable(tf.concat( + [ + self.init_a_clipped, + self.init_b_clipped, + ], + axis=0 + ), name="params") + + a_var = self.params[0:init_a.shape[0]] + b_var = self.params[init_a.shape[0]:] + + self.a_var = self.tf_clip_param(a_var, "a_var") + self.b_var = self.tf_clip_param(b_var, "b_var") + + # Properties to follow gene-wise convergence. + self.updated = np.repeat(a=True, repeats=self.params.shape[1]) # Initialise to is updated. + self.converged = np.repeat(a=False, repeats=self.params.shape[1]) # Initialise to non-converged. + + self.dtype = dtype + self.n_features = self.params.shape[1] + self.idx_train_loc = np.arange(0, init_a.shape[0]) + self.idx_train_scale = np.arange(init_a.shape[0], init_a.shape[0]+init_b.shape[0]) + + @abc.abstractmethod + def param_bounds(self, dtype): + pass + + def convergence_update(self, status: np.ndarray, features_updated: np.ndarray): + self.converged = status.copy() + self.updated = features_updated diff --git a/batchglm/train/tf2/glm_beta/__init__.py b/batchglm/train/tf2/glm_beta/__init__.py new file mode 100644 index 00000000..a616f181 --- /dev/null +++ b/batchglm/train/tf2/glm_beta/__init__.py @@ -0,0 +1,5 @@ +from .processModel import ProcessModel +from .vars import ModelVars +from .estimator import Estimator + +from .model import BetaGLM diff --git a/batchglm/train/tf2/glm_beta/estimator.py b/batchglm/train/tf2/glm_beta/estimator.py new file mode 100644 index 00000000..d35cdea2 --- /dev/null +++ b/batchglm/train/tf2/glm_beta/estimator.py @@ -0,0 +1,239 @@ +import logging +from typing import Union + +import numpy as np + +from .external import closedform_beta_glm_logitmean, closedform_beta_glm_logsamplesize +from .external import InputDataGLM, Model +from .external import Estimator as GLMEstimator +from .model import BetaGLM, LossGLMBeta +from .processModel import ProcessModel +from .vars import ModelVars + + +class Estimator(GLMEstimator, ProcessModel): + """ + Estimator for Generalized Linear Models (GLMs) with beta distributed noise. + Uses a logit linker function for loc and log linker function for scale. + """ + + model: BetaGLM + + def __init__( + self, + input_data: InputDataGLM, + init_a: Union[np.ndarray, str] = "AUTO", + init_b: Union[np.ndarray, str] = "AUTO", + quick_scale: bool = False, + dtype="float64", + ): + """ + Performs initialisation and creates a new estimator. + + :param input_data: InputDataGLM + The input data + :param init_a: (Optional) + Low-level initial values for a. Can be: + + - str: + * "auto": automatically choose best initialization + * "random": initialize with random values + * "standard": initialize intercept with observed mean + * "init_model": initialize with another model (see `ìnit_model` parameter) + * "closed_form": try to initialize with closed form + - np.ndarray: direct initialization of 'a' + :param init_b: (Optional) + Low-level initial values for b. 
Can be: + + - str: + * "auto": automatically choose best initialization + * "random": initialize with random values + * "standard": initialize with zeros + * "init_model": initialize with another model (see `ìnit_model` parameter) + * "closed_form": try to initialize with closed form + - np.ndarray: direct initialization of 'b' + :param quick_scale: bool + Whether `scale` will be fitted faster and maybe less accurate. + Useful in scenarios where fitting the exact `scale` is not absolutely necessary. + :param dtype: Precision used in tensorflow. + """ + + self._train_loc = True + self._train_scale = True + + (init_a, init_b) = self.init_par( + input_data=input_data, + init_a=init_a, + init_b=init_b, + init_model=None + ) + init_a = init_a.astype(dtype) + init_b = init_b.astype(dtype) + if quick_scale: + self._train_scale = False + + self.model_vars = ModelVars( + init_a=init_a, + init_b=init_b, + constraints_loc=input_data.constraints_loc, + constraints_scale=input_data.constraints_scale, + dtype=dtype + ) + + super(Estimator, self).__init__( + input_data=input_data, + dtype=dtype + ) + + def train( + self, + batched_model=True, + batch_size: int = 500, + optimizer: str = "adam", + learning_rate: float = 1e-2, + convergence_criteria="step", + stopping_criteria=1000, + autograd=False, + featurewise = True, + benchmark: bool = False + ): + self.model = BetaGLM(model_vars=self.model_vars, dtype=self.model_vars.dtype, + compute_a=self._train_loc, compute_b=self._train_scale, use_gradient_tape=autograd) + self._loss = LossGLMBeta() + + optimizer_object, optimizer_enum = self.get_optimizer_object(optimizer, learning_rate) + self.model.TS = optimizer_enum.value + + super(Estimator, self)._train( + noise_model="beta", + batched_model=batched_model, + batch_size=batch_size, + optimizer_object=optimizer_object, + optimizer_enum=optimizer_enum, + convergence_criteria=convergence_criteria, + stopping_criteria=stopping_criteria, + autograd=autograd, + benchmark=benchmark + ) + + def get_model_container( + self, + input_data + ): + return Model(input_data=input_data) + + def init_par( + self, + input_data, + init_a, + init_b, + init_model + ): + r""" + standard: + Only initialise intercept and keep other coefficients as zero. 
+ + closed-form: + Initialize with Maximum Likelihood / Maximum of Momentum estimators + """ + + size_factors_init = input_data.size_factors + + if init_model is None: + groupwise_means = None + init_a_str = None + if isinstance(init_a, str): + init_a_str = init_a.lower() + # Chose option if auto was chosen + if init_a.lower() == "auto": + init_a = "closed_form" + + if init_a.lower() == "closed_form": + groupwise_means, init_a, rmsd_a = closedform_beta_glm_logitmean( + x=input_data.x, + design_loc=input_data.design_loc, + constraints_loc=input_data.constraints_loc, + size_factors=size_factors_init, + link_fn=lambda mean: np.log( + 1/(1/self.np_clip_param(mean, "mean")-1) + ) + ) + + # train mu, if the closed-form solution is inaccurate + self._train_loc = not (np.all(rmsd_a == 0) or rmsd_a.size == 0) + + logging.getLogger("batchglm").debug("Using closed-form MME initialization for mean") + elif init_a.lower() == "standard": + overall_means = np.mean(input_data.x, axis=0) + overall_means = self.np_clip_param(overall_means, "mean") + + init_a = np.zeros([input_data.num_loc_params, input_data.num_features]) + init_a[0, :] = np.log(overall_means/(1-overall_means)) + self._train_loc = True + + logging.getLogger("batchglm").debug("Using standard initialization for mean") + elif init_a.lower() == "all_zero": + init_a = np.zeros([input_data.num_loc_params, input_data.num_features]) + self._train_loc = True + + logging.getLogger("batchglm").debug("Using all_zero initialization for mean") + else: + raise ValueError("init_a string %s not recognized" % init_a) + logging.getLogger("batchglm").debug("Should train mean: %s", self._train_loc) + if isinstance(init_b, str): + if init_b.lower() == "auto": + init_b = "standard" + + if init_b.lower() == "standard": + groupwise_scales, init_b_intercept, rmsd_b = closedform_beta_glm_logsamplesize( + x=input_data.x, + design_scale=input_data.design_scale[:, [0]], + constraints=input_data.constraints_scale[[0], :][:, [0]], + size_factors=size_factors_init, + groupwise_means=None, + link_fn=lambda samplesize: np.log(self.np_clip_param(samplesize, "samplesize")) + ) + init_b = np.zeros([input_data.num_scale_params, input_data.num_features]) + init_b[0, :] = init_b_intercept + + logging.getLogger("batchglm").debug("Using standard-form MME initialization for dispersion") + elif init_b.lower() == "closed_form": + dmats_unequal = False + if input_data.num_design_loc_params == input_data.num_design_scale_params: + if np.any(input_data.design_loc != input_data.design_scale): + dmats_unequal = True + + inits_unequal = False + if init_a_str is not None: + if init_a_str != init_b: + inits_unequal = True + + if inits_unequal or dmats_unequal: + raise ValueError( + "cannot use closed_form init for scale model if scale model differs from loc model" + ) + + groupwise_scales, init_b, rmsd_b = closedform_beta_glm_logsamplesize( + x=input_data.x, + design_scale=input_data.design_scale, + constraints=input_data.constraints_scale, + size_factors=size_factors_init, + groupwise_means=groupwise_means, + link_fn=lambda samplesize: np.log(self.np_clip_param(samplesize, "samplesize")) + ) + + logging.getLogger("batchglm").debug("Using closed-form MME initialization for dispersion") + elif init_b.lower() == "all_zero": + init_b = np.zeros([input_data.num_scale_params, input_data.num_features]) + + logging.getLogger("batchglm").debug("Using standard initialization for dispersion") + else: + raise ValueError("init_b string %s not recognized" % init_b) + 
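+            # The closed-form initialisers above work on the (mean, samplesize) parameterisation
+            # of the beta distribution: loc = mean in (0, 1) with a logit link and
+            # scale = samplesize = alpha + beta with a log link (see layers.py of this module).
+            # Quick numeric cross-check (illustrative, requires scipy):
+            #
+            #   from scipy import stats
+            #   p, s, x = 0.3, 20.0, 0.25        # mean, samplesize, one observation in (0, 1)
+            #   a, b = p * s, (1.0 - p) * s      # standard (alpha, beta)
+            #   stats.beta.logpdf(x, a, b)       # matches the lgamma-based expression in layers.py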
logging.getLogger("batchglm").debug("Should train r: %s", self._train_scale) + else: + init_a, init_b = self.get_init_from_model(init_a=init_a, + init_b=init_b, + input_data=input_data, + init_model=init_model) + + return init_a, init_b diff --git a/batchglm/train/tf2/glm_beta/external.py b/batchglm/train/tf2/glm_beta/external.py new file mode 100644 index 00000000..f7b5d508 --- /dev/null +++ b/batchglm/train/tf2/glm_beta/external.py @@ -0,0 +1,12 @@ +from batchglm import pkg_constants +import batchglm.data as data_utils + +from batchglm.models.base_glm.utils import closedform_glm_mean, closedform_glm_scale +from batchglm.models.glm_beta import _EstimatorGLM, InputDataGLM, Model +from batchglm.models.glm_beta.utils import closedform_beta_glm_logitmean, closedform_beta_glm_logsamplesize +from batchglm.utils.linalg import groupwise_solve_lm + +from batchglm.train.tf2.base_glm import ProcessModelGLM, GLM, LossGLM, Estimator, ModelVarsGLM +from batchglm.train.tf2.base_glm import LinearLocGLM, LinearScaleGLM, LinkerLocGLM, LinkerScaleGLM, LikelihoodGLM, UnpackParamsGLM +from batchglm.train.tf2.base_glm import FIMGLM, JacobianGLM, HessianGLM + diff --git a/batchglm/train/tf2/glm_beta/layers.py b/batchglm/train/tf2/glm_beta/layers.py new file mode 100644 index 00000000..2eae4735 --- /dev/null +++ b/batchglm/train/tf2/glm_beta/layers.py @@ -0,0 +1,53 @@ +import tensorflow as tf +from .external import LinearLocGLM, LinearScaleGLM, LinkerLocGLM, LinkerScaleGLM, LikelihoodGLM, UnpackParamsGLM +from .processModel import ProcessModel + + +class UnpackParams(UnpackParamsGLM, ProcessModel): + """ + Full class. + """ + + +class LinearLoc(LinearLocGLM, ProcessModel): + + def with_size_factors(self, eta_loc, size_factors): + raise NotImplementedError("There are no size_factors for GLMs with Beta noise.") + + +class LinearScale(LinearScaleGLM, ProcessModel): + """ + Full Class + """ + + +class LinkerLoc(LinkerLocGLM): + + def _inv_linker(self, loc: tf.Tensor): + return 1 / (1 + tf.exp(-loc)) + + +class LinkerScale(LinkerScaleGLM): + + def _inv_linker(self, scale: tf.Tensor): + return tf.exp(scale) + + +class Likelihood(LikelihoodGLM, ProcessModel): + + def _ll(self, eta_loc, eta_scale, loc, scale, x, n_features): + + if isinstance(x, tf.SparseTensor): + one_minus_x = -tf.sparse.add(x, -tf.ones_like(loc)) + else: + one_minus_x = 1 - x + + one_minus_loc = 1 - loc + log_probs = tf.math.lgamma(scale) - tf.math.lgamma(loc * scale) \ + - tf.math.lgamma(one_minus_loc * scale) \ + + (scale * loc - 1) * tf.math.log(x) \ + + (one_minus_loc * scale - 1) * tf.math.log(one_minus_x) + + log_probs = self.tf_clip_param(log_probs, "log_probs") + + return log_probs diff --git a/batchglm/train/tf2/glm_beta/layers_gradients.py b/batchglm/train/tf2/glm_beta/layers_gradients.py new file mode 100644 index 00000000..566e9b44 --- /dev/null +++ b/batchglm/train/tf2/glm_beta/layers_gradients.py @@ -0,0 +1,144 @@ +import tensorflow as tf +from .external import FIMGLM, JacobianGLM, HessianGLM + + +class FIM(FIMGLM): + # No Fisher Information Matrices due to unsolvable E[log(X)] + + def _weight_fim_aa( + self, + x, + loc, + scale + ): + assert False, "not implemented" + + def _weight_fim_bb( + self, + x, + loc, + scale + ): + assert False, "not implemented" + + +class Jacobian(JacobianGLM): + + def _weights_jac_a( + self, + x, + loc, + scale, + ): + one_minus_loc = 1 - loc + if isinstance(x, tf.SparseTensor): + const1 = tf.math.log(tf.sparse.add(tf.zeros_like(loc), x).__div__(-tf.sparse.add(x, -tf.ones_like(loc)))) + else: + const1 = 
tf.math.log(x / (1 - x)) + const2 = - tf.math.digamma(loc * scale) + tf.math.digamma(one_minus_loc * scale) + const1 + const = const2 * scale * loc * one_minus_loc + return const + + def _weights_jac_b( + self, + x, + loc, + scale, + ): + if isinstance(x, tf.SparseTensor): + one_minus_x = - tf.sparse.add(x, -tf.ones_like(loc)) + else: + one_minus_x = 1 - x + one_minus_loc = 1 - loc + const = scale * (tf.math.digamma(scale) - tf.math.digamma(loc * scale) * loc - tf.math.digamma( + one_minus_loc * scale) * one_minus_loc + loc * tf.math.log(x) + one_minus_loc * tf.math.log( + one_minus_x)) + return const + + +class Hessian(HessianGLM): + + def _weight_hessian_aa( + self, + x, + loc, + scale, + ): + one_minus_loc = 1 - loc + loc_times_scale = loc * scale + one_minus_loc_times_scale = one_minus_loc * scale + + if isinstance(x, tf.SparseTensor): + # Using the dense matrix of the location model to serve the correct shapes for the sparse X. + const1 = tf.sparse.add(tf.zeros_like(loc), x).__div__(-tf.sparse.add(x, -tf.ones_like(loc))) + # Adding tf.zeros_like(loc) is a hack to avoid bug thrown by log on sparse matrix below, + # to_dense does not work. + else: + const1 = tf.math.log(x / (tf.ones_like(x) - x)) + + const2 = (1 - 2 * loc) * ( + - tf.math.digamma(loc_times_scale) + tf.math.digamma(one_minus_loc_times_scale) + const1) + const3 = loc * one_minus_loc_times_scale * ( + - tf.math.polygamma(tf.ones_like(loc), loc_times_scale) - tf.math.polygamma(tf.ones_like(loc), + one_minus_loc_times_scale)) + const = loc * one_minus_loc_times_scale * (const2 + const3) + return const + + def _weight_hessian_ab( + self, + x, + loc, + scale, + ): + one_minus_loc = 1 - loc + loc_times_scale = loc * scale + one_minus_loc_times_scale = one_minus_loc * scale + scalar_one = tf.constant(1, shape=(), dtype=self.dtype) + + if isinstance(x, tf.SparseTensor): + # Using the dense matrix of the location model to serve the correct shapes for the sparse X. + const1 = tf.sparse.add(tf.zeros_like(loc), x).__div__(-tf.sparse.add(x, -tf.ones_like(loc))) + # Adding tf.zeros_like(loc) is a hack to avoid bug thrown by log on sparse matrix below, + # to_dense does not work. + else: + const1 = tf.math.log(x / (1 - x)) + + const2 = - tf.math.digamma(loc_times_scale) + tf.math.digamma(one_minus_loc_times_scale) + const1 + const3 = scale * (- tf.math.polygamma(scalar_one, loc_times_scale) * loc + one_minus_loc * tf.math.polygamma( + scalar_one, + one_minus_loc_times_scale)) + + const = loc * one_minus_loc_times_scale * (const2 + const3) + + return const + + def _weight_hessian_bb( + self, + x, + loc, + scale, + ): + one_minus_loc = 1 - loc + loc_times_scale = loc * scale + one_minus_loc_times_scale = one_minus_loc * scale + scalar_one = tf.constant(1, shape=(), dtype=self.dtype) + + if isinstance(x, tf.SparseTensor): + # Using the dense matrix of the location model to serve the correct shapes for the sparse X. + const1 = tf.sparse.add(tf.zeros_like(loc), x).__div__(-tf.sparse.add(x, -tf.ones_like(loc))) + # Adding tf.zeros_like(loc) is a hack to avoid bug thrown by log on sparse matrix below, + # to_dense does not work. 
+ const2 = loc * (tf.math.log(tf.sparse.add(tf.zeros_like(loc), x)) - tf.math.digamma(loc_times_scale)) \ + - one_minus_loc * (tf.math.digamma(one_minus_loc_times_scale) + tf.math.log(const1)) \ + + tf.math.digamma(scale) + else: + const1 = tf.math.log(x / (1 - x)) + const2 = loc * (tf.math.log(x) - tf.math.digamma(loc_times_scale)) \ + - one_minus_loc * (tf.math.digamma(one_minus_loc_times_scale) + tf.math.log(const1)) \ + + tf.math.digamma(scale) + const3 = scale * (- tf.square(loc) * tf.math.polygamma(scalar_one, loc_times_scale) + + tf.math.polygamma(scalar_one, scale) + - tf.math.polygamma(scalar_one, one_minus_loc_times_scale) * tf.square(one_minus_loc)) + const = scale * (const2 + const3) + + return const diff --git a/batchglm/train/tf2/glm_beta/model.py b/batchglm/train/tf2/glm_beta/model.py new file mode 100644 index 00000000..435c9c53 --- /dev/null +++ b/batchglm/train/tf2/glm_beta/model.py @@ -0,0 +1,44 @@ +import logging + +from .layers import UnpackParams, LinearLoc, LinearScale, LinkerLoc, LinkerScale, Likelihood +from .external import GLM, LossGLM +from .layers_gradients import Jacobian, Hessian, FIM +from .processModel import ProcessModel + +logger = logging.getLogger(__name__) + + +class BetaGLM(GLM, ProcessModel): + + def __init__( + self, + model_vars, + dtype, + compute_a, + compute_b, + use_gradient_tape + ): + self.compute_a = compute_a + self.compute_b = compute_b + + super(BetaGLM, self).__init__( + model_vars=model_vars, + unpack_params=UnpackParams(), + linear_loc=LinearLoc(), + linear_scale=LinearScale(), + linker_loc=LinkerLoc(), + linker_scale=LinkerScale(), + likelihood=Likelihood(dtype), + jacobian=Jacobian(model_vars=model_vars, compute_a=compute_a, compute_b=compute_b, dtype=dtype), + hessian=Hessian(model_vars=model_vars, compute_a=compute_a, compute_b=compute_b, dtype=dtype), + fim=FIM(model_vars=model_vars, compute_a=compute_a, compute_b=compute_b, dtype=dtype), + use_gradient_tape=use_gradient_tape + + ) + + +class LossGLMBeta(LossGLM): + + """ + Full class + """ diff --git a/batchglm/train/tf2/glm_beta/processModel.py b/batchglm/train/tf2/glm_beta/processModel.py new file mode 100644 index 00000000..c21811a4 --- /dev/null +++ b/batchglm/train/tf2/glm_beta/processModel.py @@ -0,0 +1,45 @@ +from .external import ProcessModelGLM +import tensorflow as tf +import numpy as np +from .external import pkg_constants + + +class ProcessModel(ProcessModelGLM): + + def param_bounds( + self, + dtype + ): + if isinstance(dtype, tf.DType): + dmax = dtype.max + dtype = dtype.as_numpy_dtype + else: + dtype = np.dtype(dtype) + dmax = np.finfo(dtype).max + dtype = dtype.type + + zero = np.nextafter(0, np.inf, dtype=dtype) + one = np.nextafter(1, -np.inf, dtype=dtype) + + sf = dtype(pkg_constants.ACCURACY_MARGIN_RELATIVE_TO_LIMIT) + bounds_min = { + "a_var": np.log(zero / (1 - zero)) / sf, + "b_var": np.log(zero) / sf, + "eta_loc": np.log(zero / (1 - zero)) / sf, + "eta_scale": np.log(zero) / sf, + "mean": np.nextafter(0, np.inf, dtype=dtype), + "samplesize": np.nextafter(0, np.inf, dtype=dtype), + "probs": dtype(0), + "log_probs": np.log(zero), + } + bounds_max = { + "a_var": np.log(one / (1 - one)) / sf, + "b_var": np.nextafter(np.log(dmax), -np.inf, dtype=dtype) / sf, + "eta_loc": np.log(one / (1 - one)) / sf, + "eta_scale": np.nextafter(np.log(dmax), -np.inf, dtype=dtype) / sf, + "mean": one, + "samplesize": np.nextafter(dmax, -np.inf, dtype=dtype) / sf, + "probs": dtype(1), + "log_probs": dtype(0), + } + return bounds_min, bounds_max diff --git 
a/batchglm/train/tf2/glm_beta/vars.py b/batchglm/train/tf2/glm_beta/vars.py new file mode 100644 index 00000000..b1200abc --- /dev/null +++ b/batchglm/train/tf2/glm_beta/vars.py @@ -0,0 +1,8 @@ +from .model import ProcessModel +from .external import ModelVarsGLM + + +class ModelVars(ProcessModel, ModelVarsGLM): + """ + Full class. + """ diff --git a/batchglm/train/tf2/glm_nb/__init__.py b/batchglm/train/tf2/glm_nb/__init__.py new file mode 100644 index 00000000..f8cd6ee7 --- /dev/null +++ b/batchglm/train/tf2/glm_nb/__init__.py @@ -0,0 +1,5 @@ +from .processModel import ProcessModel +from .vars import ModelVars +from .estimator import Estimator + +from .model import NBGLM diff --git a/batchglm/train/tf2/glm_nb/estimator.py b/batchglm/train/tf2/glm_nb/estimator.py new file mode 100644 index 00000000..3cad4c19 --- /dev/null +++ b/batchglm/train/tf2/glm_nb/estimator.py @@ -0,0 +1,266 @@ +import logging +from typing import Union +import numpy as np + +from .external import InputDataGLM, Model +from .external import closedform_nb_glm_logmu, closedform_nb_glm_logphi + +from .model import NBGLM, LossGLMNB +from .vars import ModelVars +from .processModel import ProcessModel +from .external import Estimator as GLMEstimator + + +class Estimator(GLMEstimator, ProcessModel): + """ + Estimator for Generalized Linear Models (GLMs) with negative binomial noise. + Uses the natural logarithm as linker function. + """ + model: NBGLM + + def __init__( + self, + input_data: InputDataGLM, + init_a: Union[np.ndarray, str] = "AUTO", + init_b: Union[np.ndarray, str] = "AUTO", + quick_scale: bool = False, + dtype="float64", + ): + """ + Performs initialisation and creates a new estimator. + + :param input_data: InputDataGLM + The input data + :param init_a: (Optional) + Low-level initial values for a. Can be: + + - str: + * "auto": automatically choose best initialization + * "random": initialize with random values + * "standard": initialize intercept with observed mean + * "init_model": initialize with another model (see `ìnit_model` parameter) + * "closed_form": try to initialize with closed form + - np.ndarray: direct initialization of 'a' + :param init_b: (Optional) + Low-level initial values for b. Can be: + + - str: + * "auto": automatically choose best initialization + * "random": initialize with random values + * "standard": initialize with zeros + * "init_model": initialize with another model (see `ìnit_model` parameter) + * "closed_form": try to initialize with closed form + - np.ndarray: direct initialization of 'b' + :param quick_scale: bool + Whether `scale` will be fitted faster and maybe less accurate. + Useful in scenarios where fitting the exact `scale` is not absolutely necessary. + :param dtype: Precision used in tensorflow. 
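+
+        Example (minimal sketch; assumes `input_data` was already built as an
+        `InputDataGLM` from a count matrix plus location/scale design matrices)::
+
+            estimator = Estimator(input_data=input_data, quick_scale=False, dtype="float64")
+            estimator.initialize()
+            estimator.train(optimizer="adam", batch_size=500, stopping_criteria=1000)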
+ """ + self._train_loc = True + self._train_scale = True + + (init_a, init_b) = self.init_par( + input_data=input_data, + init_a=init_a, + init_b=init_b, + init_model=None + ) + init_a = init_a.astype(dtype) + init_b = init_b.astype(dtype) + if quick_scale: + self._train_scale = False + + self.model_vars = ModelVars( + init_a=init_a, + init_b=init_b, + constraints_loc=input_data.constraints_loc, + constraints_scale=input_data.constraints_scale, + dtype=dtype + ) + + super(Estimator, self).__init__( + input_data=input_data, + dtype=dtype + ) + + def train( + self, + batched_model: bool = True, + batch_size: int = 500, + optimizer: str = "adam", + learning_rate: float = 1e-2, + convergence_criteria: str = "step", + stopping_criteria: int = 1000, + autograd: bool = False, + featurewise: bool = True, + benchmark: bool = False + ): + self.model = NBGLM( + model_vars=self.model_vars, + dtype=self.model_vars.dtype, + compute_a=self._train_loc, + compute_b=self._train_scale, + use_gradient_tape=autograd + ) + + self._loss = LossGLMNB() + + optimizer_object, optimizer_enum = self.get_optimizer_object(optimizer, learning_rate) + self.model.TS = optimizer_enum.value + + super(Estimator, self)._train( + noise_model="nb", + batched_model=batched_model, + batch_size=batch_size, + optimizer_object=optimizer_object, + optimizer_enum=optimizer_enum, + convergence_criteria=convergence_criteria, + stopping_criteria=stopping_criteria, + autograd=autograd, + featurewise=featurewise, + benchmark=benchmark + ) + + def get_model_container( + self, + input_data + ): + return Model(input_data=input_data) + + def init_par( + self, + input_data, + init_a, + init_b, + init_model + ): + r""" + standard: + Only initialise intercept and keep other coefficients as zero. + + closed-form: + Initialize with Maximum Likelihood / Maximum of Momentum estimators + + Idea: + $$ + \theta &= f(x) \\ + \Rightarrow f^{-1}(\theta) &= x \\ + &= (D \cdot D^{+}) \cdot x \\ + &= D \cdot (D^{+} \cdot x) \\ + &= D \cdot x' = f^{-1}(\theta) + $$ + """ + + size_factors_init = input_data.size_factors + if size_factors_init is not None: + size_factors_init = np.expand_dims(size_factors_init, axis=1) + size_factors_init = np.broadcast_to( + array=size_factors_init, + shape=[input_data.num_observations, input_data.num_features] + ) + + if init_model is None: + groupwise_means = None + init_a_str = None + if isinstance(init_a, str): + init_a_str = init_a.lower() + # Chose option if auto was chosen + if init_a.lower() == "auto": + init_a = "standard" + + if init_a.lower() == "closed_form": + groupwise_means, init_a, rmsd_a = closedform_nb_glm_logmu( + x=input_data.x, + design_loc=input_data.design_loc, + constraints_loc=input_data.constraints_loc, + size_factors=size_factors_init, + link_fn=lambda loc: np.log(self.np_clip_param(loc, "loc")) + ) + + # train mu, if the closed-form solution is inaccurate + self._train_loc = not (np.all(rmsd_a == 0) or rmsd_a.size == 0) + + if input_data.size_factors is not None: + if np.any(input_data.size_factors != 1): + self._train_loc = True + + logging.getLogger("batchglm").debug("Using closed-form MLE initialization for mean") + logging.getLogger("batchglm").debug("Should train loc: %s", self._train_loc) + elif init_a.lower() == "standard": + overall_means = np.mean(input_data.x, axis=0) # directly calculate the mean + overall_means = self.np_clip_param(overall_means, "loc") + + init_a = np.zeros([input_data.num_loc_params, input_data.num_features]) + init_a[0, :] = np.log(overall_means) + self._train_loc = 
True + + logging.getLogger("batchglm").debug("Using standard initialization for mean") + logging.getLogger("batchglm").debug("Should train loc: %s", self._train_loc) + elif init_a.lower() == "all_zero": + init_a = np.zeros([input_data.num_loc_params, input_data.num_features]) + self._train_loc = True + + logging.getLogger("batchglm").debug("Using all_zero initialization for mean") + logging.getLogger("batchglm").debug("Should train loc: %s", self._train_loc) + else: + raise ValueError("init_a string %s not recognized" % init_a) + + if isinstance(init_b, str): + if init_b.lower() == "auto": + init_b = "standard" + + if init_b.lower() == "standard": + groupwise_scales, init_b_intercept, rmsd_b = closedform_nb_glm_logphi( + x=input_data.x, + design_scale=input_data.design_scale[:, [0]], + constraints=input_data.constraints_scale[[0], :][:, [0]], + size_factors=size_factors_init, + groupwise_means=None, + link_fn=lambda scale: np.log(self.np_clip_param(scale, "scale")) + ) + init_b = np.zeros([input_data.num_scale_params, input_data.num_features]) + init_b[0, :] = init_b_intercept + + logging.getLogger("batchglm").debug("Using standard-form MME initialization for dispersion") + logging.getLogger("batchglm").debug("Should train scale: %s", self._train_scale) + elif init_b.lower() == "closed_form": + dmats_unequal = False + if input_data.design_loc.shape[1] == input_data.design_scale.shape[1]: + if np.any(input_data.design_loc != input_data.design_scale): + dmats_unequal = True + + inits_unequal = False + if init_a_str is not None: + if init_a_str != init_b: + inits_unequal = True + + if inits_unequal or dmats_unequal: + raise ValueError( + "cannot use closed_form init for scale model if scale model differs from loc model" + ) + + groupwise_scales, init_b, rmsd_b = closedform_nb_glm_logphi( + x=input_data.x, + design_scale=input_data.design_scale, + constraints=input_data.constraints_scale, + size_factors=size_factors_init, + groupwise_means=groupwise_means, + link_fn=lambda scale: np.log(self.np_clip_param(scale, "scale")) + ) + + logging.getLogger("batchglm").debug("Using closed-form MME initialization for dispersion") + logging.getLogger("batchglm").debug("Should train scale: %s", self._train_scale) + elif init_b.lower() == "all_zero": + init_b = np.zeros([input_data.num_scale_params, input_data.x.shape[1]]) + + logging.getLogger("batchglm").debug("Using standard initialization for dispersion") + logging.getLogger("batchglm").debug("Should train scale: %s", self._train_scale) + else: + raise ValueError("init_b string %s not recognized" % init_b) + else: + init_a, init_b = self.get_init_from_model(init_a=init_a, + init_b=init_b, + input_data=input_data, + init_model=init_model) + + return init_a, init_b diff --git a/batchglm/train/tf2/glm_nb/external.py b/batchglm/train/tf2/glm_nb/external.py new file mode 100644 index 00000000..d5c3a2e7 --- /dev/null +++ b/batchglm/train/tf2/glm_nb/external.py @@ -0,0 +1,18 @@ +import batchglm.data as data_utils + +from batchglm.models.glm_nb import _EstimatorGLM, InputDataGLM, Model +from batchglm.models.base_glm.utils import closedform_glm_mean, closedform_glm_scale +from batchglm.models.glm_nb.utils import closedform_nb_glm_logmu, closedform_nb_glm_logphi + +from batchglm.utils.linalg import groupwise_solve_lm +from batchglm import pkg_constants + +from batchglm.train.tf2.base_glm import GLM +from batchglm.train.tf2.base_glm import ProcessModelGLM, ModelVarsGLM + +# import necessary base_glm layers +from batchglm.train.tf2.base_glm import LinearLocGLM, 
LinearScaleGLM, LinkerLocGLM +from batchglm.train.tf2.base_glm import LinkerScaleGLM, LikelihoodGLM, UnpackParamsGLM +from batchglm.train.tf2.base_glm import FIMGLM, JacobianGLM, HessianGLM +from batchglm.train.tf2.base_glm import LossGLM +from batchglm.train.tf2.base_glm import Estimator diff --git a/batchglm/train/tf2/glm_nb/layers.py b/batchglm/train/tf2/glm_nb/layers.py new file mode 100644 index 00000000..b180c9eb --- /dev/null +++ b/batchglm/train/tf2/glm_nb/layers.py @@ -0,0 +1,59 @@ +import tensorflow as tf +from .processModel import ProcessModel +from .external import LinearLocGLM, LinearScaleGLM, LinkerLocGLM +from .external import LinkerScaleGLM, LikelihoodGLM, UnpackParamsGLM + + +class UnpackParams(UnpackParamsGLM, ProcessModel): + """ + Full class. + """ + + +class LinearLoc(LinearLocGLM, ProcessModel): + + def with_size_factors(self, eta_loc, size_factors): + return tf.add(eta_loc, tf.math.log(size_factors)) + + +class LinearScale(LinearScaleGLM, ProcessModel): + """ + Full class. + """ + + +class LinkerLoc(LinkerLocGLM): + + def _inv_linker(self, loc: tf.Tensor): + return tf.exp(loc) + + +class LinkerScale(LinkerScaleGLM): + + def _inv_linker(self, scale: tf.Tensor): + return tf.exp(scale) + + +class Likelihood(LikelihoodGLM, ProcessModel): + + def _ll(self, eta_loc, eta_scale, loc, scale, x, n_features): + + # Log-likelihood: + log_r_plus_mu = tf.math.log(scale + loc) + if isinstance(x, tf.SparseTensor): + log_probs_sparse = x.__mul__(eta_loc - log_r_plus_mu) + log_probs_dense = tf.math.lgamma(tf.sparse.add(x, scale)) - \ + tf.math.lgamma(tf.sparse.add(x, tf.ones(shape=x.dense_shape, dtype=self.ll_dtype))) - \ + tf.math.lgamma(scale) + \ + tf.multiply(scale, eta_scale - log_r_plus_mu) + log_probs = tf.sparse.add(log_probs_sparse, log_probs_dense) + # log_probs.set_shape([None, n_features]) # need as shape completely lost. 
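+        # The dense branch below is the standard negative binomial log-pmf with mean `loc`
+        # and dispersion `scale` (= number of failures r). Quick numeric cross-check
+        # (illustrative, requires scipy):
+        #
+        #   from scipy import stats
+        #   mu, r, x = 5.0, 2.0, 3
+        #   stats.nbinom.logpmf(x, n=r, p=r / (r + mu))
+        #   # == lgamma(r + x) - lgamma(x + 1) - lgamma(r)
+        #   #    + x * (log(mu) - log(r + mu)) + r * (log(r) - log(r + mu))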
+ else: + log_probs = tf.math.lgamma(scale + x) - \ + tf.math.lgamma(x + tf.ones_like(x)) - \ + tf.math.lgamma(scale) + \ + tf.multiply(x, eta_loc - log_r_plus_mu) + \ + tf.multiply(scale, eta_scale - log_r_plus_mu) + + log_probs = self.tf_clip_param(log_probs, "log_probs") + return log_probs diff --git a/batchglm/train/tf2/glm_nb/layers_gradients.py b/batchglm/train/tf2/glm_nb/layers_gradients.py new file mode 100644 index 00000000..8ff079c6 --- /dev/null +++ b/batchglm/train/tf2/glm_nb/layers_gradients.py @@ -0,0 +1,144 @@ +import tensorflow as tf +from .external import FIMGLM, JacobianGLM, HessianGLM + + +class FIM(FIMGLM): + + def _weight_fim_aa( + self, + x, + loc, + scale + ): + const = tf.divide(scale, scale + loc) + w = tf.multiply(loc, const) + + return w + + def _weight_fim_bb( + self, + x, + loc, + scale + ): + return tf.zeros_like(scale) + + +class Jacobian(JacobianGLM): + + def _weights_jac_a( + self, + x, + loc, + scale, + ): + if isinstance(x, tf.SparseTensor): # or isinstance(x, tf.SparseTensorValue): + const = tf.sparse.add(x, tf.negative(loc)) + else: + const = tf.subtract(x, loc) + return tf.divide(tf.multiply(scale, const), tf.add(loc, scale)) + + def _weights_jac_b(self, x, loc, scale): + # Pre-define sub-graphs that are used multiple times: + scalar_one = tf.constant(1, shape=(), dtype=self.dtype) + if isinstance(x, tf.SparseTensor): # or isinstance(x, tf.SparseTensorValue): + scale_plus_x = tf.sparse.add(x, scale) + else: + scale_plus_x = scale + x + + r_plus_mu = scale + loc + + # Define graphs for individual terms of constant term of hessian: + const1 = tf.subtract( + tf.math.digamma(x=scale_plus_x), + tf.math.digamma(x=scale) + ) + const2 = tf.negative(scale_plus_x / r_plus_mu) + const3 = tf.add( + tf.math.log(scale), + scalar_one - tf.math.log(r_plus_mu) + ) + const = tf.add_n([const1, const2, const3]) # [observations, features] + const = scale * const + + return const + + +class Hessian(HessianGLM): + + def _weight_hessian_ab(self, x, loc, scale): + + if isinstance(x, tf.SparseTensor): + x_minus_mu = tf.sparse.add(x, -loc) + else: + x_minus_mu = x - loc + + const = tf.multiply( + loc * scale, + tf.divide( + x_minus_mu, + tf.square(loc + scale) + ) + ) + + return const + + def _weight_hessian_aa( + self, + x, + loc, + scale, + ): + if isinstance(x, tf.SparseTensor):# or isinstance(x, tf.SparseTensorValue): + x_by_scale_plus_one = tf.sparse.add(x.__div__(scale), tf.ones_like(scale)) + else: + x_by_scale_plus_one = x / scale + tf.ones_like(scale) + + const = tf.negative(tf.multiply( + loc, + tf.divide( + x_by_scale_plus_one, + tf.square((loc / scale) + tf.ones_like(loc)) + ) + )) + + return const + + def _weight_hessian_bb( + self, + x, + loc, + scale, + ): + if isinstance(x, tf.SparseTensor):# or isinstance(x, tf.SparseTensorValue): + scale_plus_x = tf.sparse.add(x, scale) + else: + scale_plus_x = x + scale + + scalar_one = tf.constant(1, shape=(), dtype=self.dtype) + scalar_two = tf.constant(2, shape=(), dtype=self.dtype) + # Pre-define sub-graphs that are used multiple times: + scale_plus_loc = scale + loc + # Define graphs for individual terms of constant term of hessian: + const1 = tf.add( + tf.math.digamma(x=scale_plus_x), + scale * tf.math.polygamma(a=scalar_one, x=scale_plus_x) + ) + const2 = tf.negative(tf.add( + tf.math.digamma(x=scale), + scale * tf.math.polygamma(a=scalar_one, x=scale) + )) + const3 = tf.negative(tf.divide( + tf.add( + loc * scale_plus_x, + scalar_two * scale * scale_plus_loc + ), + tf.square(scale_plus_loc) + )) + const4 = tf.add( + 
tf.math.log(scale), + scalar_two - tf.math.log(scale_plus_loc) + ) + const = tf.add_n([const1, const2, const3, const4]) + const = tf.multiply(scale, const) + return const diff --git a/batchglm/train/tf2/glm_nb/model.py b/batchglm/train/tf2/glm_nb/model.py new file mode 100644 index 00000000..665696ab --- /dev/null +++ b/batchglm/train/tf2/glm_nb/model.py @@ -0,0 +1,43 @@ +import logging + +from .external import LossGLM, GLM +from .layers import UnpackParams, LinearLoc, LinearScale, LinkerLoc, LinkerScale, Likelihood +from .layers_gradients import Jacobian, Hessian, FIM + +from .processModel import ProcessModel + +logger = logging.getLogger(__name__) + + +class NBGLM(GLM, ProcessModel): + + def __init__( + self, + model_vars, + dtype, + compute_a, + compute_b, + use_gradient_tape + ): + self.compute_a = compute_a + self.compute_b = compute_b + + super(NBGLM, self).__init__( + model_vars=model_vars, + unpack_params=UnpackParams(), + linear_loc=LinearLoc(), + linear_scale=LinearScale(), + linker_loc=LinkerLoc(), + linker_scale=LinkerScale(), + likelihood=Likelihood(dtype), + jacobian=Jacobian(model_vars=model_vars, compute_a=compute_a, compute_b=compute_b, dtype=dtype), + hessian=Hessian(model_vars=model_vars, compute_a=compute_a, compute_b=compute_b, dtype=dtype), + fim=FIM(model_vars=model_vars, compute_a=compute_a, compute_b=compute_b, dtype=dtype), + use_gradient_tape=use_gradient_tape + ) + + +class LossGLMNB(LossGLM): + """ + Full class + """ diff --git a/batchglm/train/tf2/glm_nb/processModel.py b/batchglm/train/tf2/glm_nb/processModel.py new file mode 100644 index 00000000..6a177f7f --- /dev/null +++ b/batchglm/train/tf2/glm_nb/processModel.py @@ -0,0 +1,42 @@ +from .external import ProcessModelGLM +import tensorflow as tf +import numpy as np +from .external import pkg_constants + + +class ProcessModel(ProcessModelGLM): + + def param_bounds( + self, + dtype + ): + if isinstance(dtype, tf.DType): + dmax = dtype.max + dtype = dtype.as_numpy_dtype + else: + dtype = np.dtype(dtype) + dmax = np.finfo(dtype).max + dtype = dtype.type + + sf = dtype(pkg_constants.ACCURACY_MARGIN_RELATIVE_TO_LIMIT) + bounds_min = { + "a_var": np.log(np.nextafter(0, np.inf, dtype=dtype)) / sf, + "b_var": np.log(np.nextafter(0, np.inf, dtype=dtype)) / sf, + "eta_loc": np.log(np.nextafter(0, np.inf, dtype=dtype)) / sf, + "eta_scale": np.log(np.nextafter(0, np.inf, dtype=dtype)) / sf, + "loc": np.nextafter(0, np.inf, dtype=dtype), + "scale": np.nextafter(0, np.inf, dtype=dtype), + "probs": dtype(0), + "log_probs": np.log(np.nextafter(0, np.inf, dtype=dtype)), + } + bounds_max = { + "a_var": np.nextafter(np.log(dmax), -np.inf, dtype=dtype) / sf, + "b_var": np.nextafter(np.log(dmax), -np.inf, dtype=dtype) / sf, + "eta_loc": np.nextafter(np.log(dmax), -np.inf, dtype=dtype) / sf, + "eta_scale": np.nextafter(np.log(dmax), -np.inf, dtype=dtype) / sf, + "loc": np.nextafter(dmax, -np.inf, dtype=dtype) / sf, + "scale": np.nextafter(dmax, -np.inf, dtype=dtype) / sf, + "probs": dtype(1), + "log_probs": dtype(0), + } + return bounds_min, bounds_max diff --git a/batchglm/train/tf2/glm_nb/vars.py b/batchglm/train/tf2/glm_nb/vars.py new file mode 100644 index 00000000..b1200abc --- /dev/null +++ b/batchglm/train/tf2/glm_nb/vars.py @@ -0,0 +1,8 @@ +from .model import ProcessModel +from .external import ModelVarsGLM + + +class ModelVars(ProcessModel, ModelVarsGLM): + """ + Full class. 
+ """ diff --git a/batchglm/train/tf2/glm_norm/__init__.py b/batchglm/train/tf2/glm_norm/__init__.py new file mode 100644 index 00000000..b6bf02af --- /dev/null +++ b/batchglm/train/tf2/glm_norm/__init__.py @@ -0,0 +1,5 @@ +from .processModel import ProcessModel +from .vars import ModelVars +from .estimator import Estimator + +from .model import NormGLM diff --git a/batchglm/train/tf2/glm_norm/estimator.py b/batchglm/train/tf2/glm_norm/estimator.py new file mode 100644 index 00000000..cdd32b0f --- /dev/null +++ b/batchglm/train/tf2/glm_norm/estimator.py @@ -0,0 +1,284 @@ +import logging +import numpy as np +import scipy.sparse +from typing import Union + +from .external import closedform_norm_glm_logsd +from .external import InputDataGLM, Model +from .external import Estimator as GLMEstimator +from .model import NormGLM, LossGLMNorm +from .processModel import ProcessModel +from .vars import ModelVars + + +logger = logging.getLogger("batchglm") + + +class Estimator(GLMEstimator, ProcessModel): + """ + Estimator for Generalized Linear Models (GLMs) with normal distributed noise. + Uses the identity function as linker function for loc and a log-linker function for scale. + """ + + model: NormGLM + loss: LossGLMNorm + + def __init__( + self, + input_data: InputDataGLM, + init_a: Union[np.ndarray, str] = "AUTO", + init_b: Union[np.ndarray, str] = "AUTO", + quick_scale: bool = False, + dtype="float64", + ): + """ + Performs initialisation and creates a new estimator. + + :param input_data: InputDataGLM + The input data + :param init_a: (Optional) + Low-level initial values for a. Can be: + + - str: + * "auto": automatically choose best initialization + * "all zero": initialize with zeros + * "random": initialize with random values + * "standard": initialize intercept with observed mean + * "init_model": initialize with another model (see `ìnit_model` parameter) + * "closed_form": try to initialize with closed form + - np.ndarray: direct initialization of 'a' + :param init_b: (Optional) + Low-level initial values for b. Can be: + + - str: + * "auto": automatically choose best initialization + * "random": initialize with random values + * "standard": initialize with zeros + * "init_model": initialize with another model (see `ìnit_model` parameter) + * "closed_form": try to initialize with closed form + - np.ndarray: direct initialization of 'b' + :param quick_scale: bool + Whether `scale` will be fitted faster and maybe less accurate. + Useful in scenarios where fitting the exact `scale` is not absolutely necessary. + :param dtype: Precision used in tensorflow. 
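+
+        The underlying model (sketch, ignoring the constraint matrices)::
+
+            mean = design_loc @ a           # identity link for the location model
+            sd   = exp(design_scale @ b)    # log link for the scale model
+            x_ij ~ Normal(mean_ij, sd_ij)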
+ """ + + self._train_loc = True + self._train_scale = True + + (init_a, init_b) = self.init_par( + input_data=input_data, + init_a=init_a, + init_b=init_b, + init_model=None + ) + init_a = init_a.astype(dtype) + init_b = init_b.astype(dtype) + if quick_scale: + self._train_scale = False + + self.model_vars = ModelVars( + init_a=init_a, + init_b=init_b, + constraints_loc=input_data.constraints_loc, + constraints_scale=input_data.constraints_scale, + dtype=dtype + ) + + super(Estimator, self).__init__( + input_data=input_data, + dtype=dtype + ) + + def train( + self, + batched_model=True, + batch_size: int = 500, + optimizer: str = "adam", + learning_rate: float = 1e-2, + convergence_criteria="step", + stopping_criteria=1000, + autograd=False, + featurewise = True, + benchmark: bool = False + ): + + self.model = NormGLM( + model_vars=self.model_vars, + dtype=self.model_vars.dtype, + compute_a=self._train_loc, + compute_b=self._train_scale, + use_gradient_tape=autograd + ) + + self._loss = LossGLMNorm() + + optimizer_object, optimizer_enum = self.get_optimizer_object(optimizer, learning_rate) + self.model.TS = optimizer_enum.value + + super(Estimator, self)._train( + noise_model="norm", + batched_model=batched_model, + batch_size=batch_size, + optimizer_object=optimizer_object, + optimizer_enum=optimizer_enum, + convergence_criteria=convergence_criteria, + stopping_criteria=stopping_criteria, + autograd=autograd, + featurewise=featurewise, + benchmark=benchmark + + ) + + def get_model_container( + self, + input_data + ): + return Model(input_data=input_data) + + def init_par(self, input_data, init_a, init_b, init_model): + r""" + standard: + Only initialise intercept and keep other coefficients as zero. + + closed-form: + Initialize with Maximum Likelihood / Maximum of Momentum estimators + """ + + size_factors_init = input_data.size_factors + if size_factors_init is not None: + size_factors_init = np.expand_dims(size_factors_init, axis=1) + size_factors_init = np.broadcast_to( + array=size_factors_init, + shape=[input_data.num_observations, input_data.num_features] + ) + + sf_given = False + if input_data.size_factors is not None: + if np.any(np.abs(input_data.size_factors - 1.) > 1e-8): + sf_given = True + + is_ols_model = input_data.design_scale.shape[1] == 1 and \ + np.all(np.abs(input_data.design_scale - 1.) < 1e-8) and not sf_given + + if init_model is None: + groupwise_means = None + init_a_str = None + if isinstance(init_a, str): + init_a_str = init_a.lower() + # Chose option if auto was chosen + if init_a.lower() == "auto": + init_a = "closed_form" + + if init_a.lower() == "closed_form" or init_a.lower() == "standard": + design_constr = np.matmul(input_data.design_loc, input_data.constraints_loc) + # Iterate over genes if X is sparse to avoid large sparse tensor. + # If X is dense, the least square problem can be vectorised easily. + if isinstance(input_data.x, scipy.sparse.csr_matrix): + init_a, rmsd_a, _, _ = np.linalg.lstsq( + np.matmul(design_constr.T, design_constr), + input_data.x.T.dot(design_constr).T, # need double .T because of dot product on sparse. 
+ rcond=None
+ )
+ else:
+ init_a, rmsd_a, _, _ = np.linalg.lstsq(
+ np.matmul(design_constr.T, design_constr),
+ np.matmul(design_constr.T, input_data.x),
+ rcond=None
+ )
+ groupwise_means = None
+ if is_ols_model:
+ self._train_loc = False
+
+ logger.debug("Using OLS initialization for location model")
+ elif init_a.lower() == "all_zero":
+ init_a = np.zeros([input_data.num_loc_params, input_data.num_features])
+ self._train_loc = True
+
+ logger.debug("Using all_zero initialization for mean")
+ else:
+ raise ValueError("init_a string %s not recognized" % init_a)
+ logger.debug("Should train location model: %s", self._train_loc)
+
+ if isinstance(init_b, str):
+ if init_b.lower() == "auto":
+ init_b = "standard"
+
+ if is_ols_model:
+ # Estimate the residual variance of the OLS fit via Var(x) = E[x^2] - E[mean_model^2].
+ if isinstance(input_data.x, scipy.sparse.csr_matrix):
+ expect_xsq = np.asarray(np.mean(input_data.x.power(2), axis=0))
+ else:
+ expect_xsq = np.expand_dims(np.mean(np.square(input_data.x), axis=0), axis=0)
+ mean_model = np.matmul(
+ np.matmul(input_data.design_loc, input_data.constraints_loc),
+ init_a
+ )
+ expect_x_sq = np.mean(np.square(mean_model), axis=0)
+ variance = (expect_xsq - expect_x_sq)
+ init_b = np.log(np.sqrt(variance))
+ self._train_scale = False
+
+ logger.debug("Using residuals from OLS estimate for variance estimate")
+ elif init_b.lower() == "closed_form":
+ dmats_unequal = False
+ if input_data.design_loc.shape[1] == input_data.design_scale.shape[1]:
+ if np.any(input_data.design_loc != input_data.design_scale):
+ dmats_unequal = True
+
+ inits_unequal = False
+ if init_a_str is not None:
+ if init_a_str != init_b:
+ inits_unequal = True
+
+ # Watch out: init_mean is full obs x features matrix and is very large in many cases.
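+ # The closed_form path below reuses the groupwise means computed for the location model,
+ # so it is only valid if the scale model matches the loc model and both were initialised
+ # with the same strategy; otherwise a ValueError is raised.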
+ if inits_unequal or dmats_unequal:
+ raise ValueError(
+ "cannot use closed_form init for scale model "
+ "if scale model differs from loc model"
+ )
+
+ groupwise_scales, init_b, rmsd_b = closedform_norm_glm_logsd(
+ x=input_data.x,
+ design_scale=input_data.design_scale,
+ constraints=input_data.constraints_scale,
+ size_factors=size_factors_init,
+ groupwise_means=groupwise_means,
+ link_fn=lambda sd: np.log(self.np_clip_param(sd, "sd"))
+ )
+
+ # train scale, if the closed-form solution is inaccurate
+ self._train_scale = not (np.all(rmsd_b == 0) or rmsd_b.size == 0)
+
+ logger.debug("Using closed-form MME initialization for standard deviation")
+ elif init_b.lower() == "standard":
+ groupwise_scales, init_b_intercept, rmsd_b = closedform_norm_glm_logsd(
+ x=input_data.x,
+ design_scale=input_data.design_scale[:, [0]],
+ constraints=input_data.constraints_scale[[0], :][:, [0]],
+ size_factors=size_factors_init,
+ groupwise_means=None,
+ link_fn=lambda sd: np.log(self.np_clip_param(sd, "sd"))
+ )
+ init_b = np.zeros([input_data.num_scale_params, input_data.num_features])
+ init_b[0, :] = init_b_intercept
+
+ # train scale, if the closed-form solution is inaccurate
+ self._train_scale = not (np.all(rmsd_b == 0) or rmsd_b.size == 0)
+
+ logger.debug("Using closed-form MME initialization for standard deviation")
+ logger.debug("Should train sd: %s", self._train_scale)
+ elif init_b.lower() == "all_zero":
+ init_b = np.zeros([input_data.num_scale_params, input_data.num_features])
+
+ logger.debug("Using all_zero initialization for standard deviation")
+ else:
+ raise ValueError("init_b string %s not recognized" % init_b)
+ logger.debug("Should train sd: %s", self._train_scale)
+ else:
+ init_a, init_b = self.get_init_from_model(init_a=init_a,
+ init_b=init_b,
+ input_data=input_data,
+ init_model=init_model)
+
+ return init_a, init_b
diff --git a/batchglm/train/tf2/glm_norm/external.py b/batchglm/train/tf2/glm_norm/external.py new file mode 100644 index 00000000..4b290d2e --- /dev/null +++ b/batchglm/train/tf2/glm_norm/external.py @@ -0,0 +1,12 @@
+import batchglm.data as data_utils
+
+from batchglm.models.glm_norm import _EstimatorGLM, InputDataGLM, Model
+from batchglm.models.base_glm.utils import closedform_glm_mean, closedform_glm_scale
+from batchglm.models.glm_norm.utils import closedform_norm_glm_mean, closedform_norm_glm_logsd
+
+from batchglm.utils.linalg import groupwise_solve_lm
+from batchglm import pkg_constants
+
+from batchglm.train.tf2.base_glm import ProcessModelGLM, GLM, LossGLM, Estimator, ModelVarsGLM
+from batchglm.train.tf2.base_glm import LinearLocGLM, LinearScaleGLM, LinkerLocGLM, LinkerScaleGLM, LikelihoodGLM, UnpackParamsGLM
+from batchglm.train.tf2.base_glm import FIMGLM, JacobianGLM, HessianGLM
diff --git a/batchglm/train/tf2/glm_norm/layers.py b/batchglm/train/tf2/glm_norm/layers.py new file mode 100644 index 00000000..ba067352 --- /dev/null +++ b/batchglm/train/tf2/glm_norm/layers.py @@ -0,0 +1,49 @@
+import tensorflow as tf
+import numpy as np
+from .external import LinearLocGLM, LinearScaleGLM, LinkerLocGLM, LinkerScaleGLM, LikelihoodGLM, UnpackParamsGLM
+from .processModel import ProcessModel
+
+
+class UnpackParams(UnpackParamsGLM, ProcessModel):
+ """
+ Full class.
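+ Inherits the normal-model parameter bounds from ProcessModel; the parameter unpacking
+ itself is taken unchanged from UnpackParamsGLM.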
+ """ + + +class LinearLoc(LinearLocGLM, ProcessModel): + + def with_size_factors(self, eta_loc, size_factors): + return tf.multiply(eta_loc, size_factors) + + +class LinearScale(LinearScaleGLM, ProcessModel): + """ + Full Class + """ + + +class LinkerLoc(LinkerLocGLM): + + def _inv_linker(self, loc: tf.Tensor): + return loc + + +class LinkerScale(LinkerScaleGLM): + + def _inv_linker(self, scale: tf.Tensor): + return tf.math.exp(scale) + + +class Likelihood(LikelihoodGLM, ProcessModel): + + def _ll(self, eta_loc, eta_scale, loc, scale, x, n_features): + + const = tf.constant(-0.5 * np.log(2 * np.pi), shape=(), dtype=self.ll_dtype) + if isinstance(x, tf.SparseTensor): + log_probs = const - eta_scale - 0.5 * tf.math.square(tf.divide(tf.sparse.add(x, - loc), scale)) + # log_probs.set_shape([None, a_var.shape[1]]) # Need this so as shape is completely lost. + else: + log_probs = const - eta_scale - 0.5 * tf.math.square(tf.divide(x - loc, scale)) + log_probs = self.tf_clip_param(log_probs, "log_probs") + + return log_probs diff --git a/batchglm/train/tf2/glm_norm/layers_gradients.py b/batchglm/train/tf2/glm_norm/layers_gradients.py new file mode 100644 index 00000000..e2b35119 --- /dev/null +++ b/batchglm/train/tf2/glm_norm/layers_gradients.py @@ -0,0 +1,116 @@ +import tensorflow as tf +from .external import FIMGLM, JacobianGLM, HessianGLM + + +class FIM(FIMGLM): + + def _weight_fim_aa( + self, + x, + loc, + scale + ): + w = tf.square(tf.divide(tf.ones_like(scale), scale)) + + return w + + def _weight_fim_bb( + self, + x, + loc, + scale + ): + w = tf.constant(2, shape=loc.shape, dtype=self.dtype) + + return w + + +class Jacobian(JacobianGLM): + + def _weights_jac_a( + self, + x, + loc, + scale, + ): + if isinstance(x, tf.SparseTensor): + const1 = tf.sparse.add(x, -loc) + const = tf.divide(const1, tf.square(scale)) + else: + const1 = tf.subtract(x, loc) + const = tf.divide(const1, tf.square(scale)) + return const + + def _weights_jac_b( + self, + x, + loc, + scale, + ): + scalar_one = tf.constant(1, shape=(), dtype=self.dtype) + if isinstance(x, tf.SparseTensor): + const = tf.negative(scalar_one) + tf.math.square( + tf.divide(tf.sparse.add(x, -loc), scale) + ) + else: + const = tf.negative(scalar_one) + tf.math.square( + tf.divide(tf.subtract(x, loc), scale) + ) + return const + + +class Hessian(HessianGLM): + + def _weight_hessian_ab( + self, + x, + loc, + scale, + ): + scalar_two = tf.constant(2, shape=(), dtype=self.dtype) + if isinstance(x, tf.SparseTensor): + x_minus_loc = tf.sparse.add(x, -loc) + else: + x_minus_loc = x - loc + + const = - tf.multiply(scalar_two, + tf.divide( + x_minus_loc, + tf.square(scale) + ) + ) + return const + + def _weight_hessian_aa( + self, + x, + loc, + scale, + ): + scalar_one = tf.constant(1, shape=(), dtype=self.dtype) + const = - tf.divide(scalar_one, tf.square(scale)) + + return const + + def _weight_hessian_bb( + self, + x, + loc, + scale, + ): + scalar_two = tf.constant(2, shape=(), dtype=self.dtype) + if isinstance(x, tf.SparseTensor): + x_minus_loc = tf.sparse.add(x, -loc) + else: + x_minus_loc = x - loc + + const = - tf.multiply( + scalar_two, + tf.math.square( + tf.divide( + x_minus_loc, + scale + ) + ) + ) + return const diff --git a/batchglm/train/tf2/glm_norm/model.py b/batchglm/train/tf2/glm_norm/model.py new file mode 100644 index 00000000..e5b74734 --- /dev/null +++ b/batchglm/train/tf2/glm_norm/model.py @@ -0,0 +1,55 @@ +import logging + +from .layers import UnpackParams, LinearLoc, LinearScale, LinkerLoc, LinkerScale, Likelihood +from 
.layers_gradients import Jacobian, Hessian, FIM +from .external import GLM, LossGLM +from .processModel import ProcessModel + +logger = logging.getLogger(__name__) + + +class NormGLM(GLM, ProcessModel): + + def __init__( + self, + model_vars, + dtype, + compute_a, + compute_b, + use_gradient_tape + ): + self.compute_a = compute_a + self.compute_b = compute_b + + super(NormGLM, self).__init__( + model_vars=model_vars, + unpack_params=UnpackParams(), + linear_loc=LinearLoc(), + linear_scale=LinearScale(), + linker_loc=LinkerLoc(), + linker_scale=LinkerScale(), + likelihood=Likelihood(dtype), + jacobian=Jacobian( + model_vars=model_vars, + compute_a=self.compute_a, + compute_b=self.compute_b, + dtype=dtype), + hessian=Hessian( + model_vars=model_vars, + compute_a=self.compute_a, + compute_b=self.compute_b, + dtype=dtype), + fim=FIM( + model_vars=model_vars, + compute_a=self.compute_a, + compute_b=self.compute_b, + dtype=dtype), + use_gradient_tape=use_gradient_tape + ) + + +class LossGLMNorm(LossGLM): + + """ + Full class + """ diff --git a/batchglm/train/tf2/glm_norm/processModel.py b/batchglm/train/tf2/glm_norm/processModel.py new file mode 100644 index 00000000..629099ff --- /dev/null +++ b/batchglm/train/tf2/glm_norm/processModel.py @@ -0,0 +1,42 @@ +from .external import ProcessModelGLM +import tensorflow as tf +import numpy as np +from .external import pkg_constants + + +class ProcessModel(ProcessModelGLM): + + def param_bounds( + self, + dtype + ): + if isinstance(dtype, tf.DType): + dmax = dtype.max + dtype = dtype.as_numpy_dtype + else: + dtype = np.dtype(dtype) + dmax = np.finfo(dtype).max + dtype = dtype.type + + sf = dtype(pkg_constants.ACCURACY_MARGIN_RELATIVE_TO_LIMIT) + bounds_min = { + "a_var": np.nextafter(-dmax, np.inf, dtype=dtype) / sf, + "b_var": np.log(np.nextafter(0, np.inf, dtype=dtype)) / sf, + "eta_loc": np.nextafter(-dmax, np.inf, dtype=dtype) / sf, + "eta_scale": np.log(np.nextafter(0, np.inf, dtype=dtype)) / sf, + "mean": np.nextafter(-dmax, np.inf, dtype=dtype) / sf, + "sd": np.nextafter(0, np.inf, dtype=dtype), + "probs": dtype(0), + "log_probs": np.log(np.nextafter(0, np.inf, dtype=dtype)), + } + bounds_max = { + "a_var": np.nextafter(dmax, -np.inf, dtype=dtype) / sf, + "b_var": np.nextafter(np.log(dmax), -np.inf, dtype=dtype) / sf, + "eta_loc": np.nextafter(dmax, -np.inf, dtype=dtype) / sf, + "eta_scale": np.nextafter(np.log(dmax), -np.inf, dtype=dtype) / sf, + "mean": np.nextafter(dmax, -np.inf, dtype=dtype) / sf, + "sd": np.nextafter(dmax, -np.inf, dtype=dtype) / sf, + "probs": dtype(1), + "log_probs": dtype(0), + } + return bounds_min, bounds_max diff --git a/batchglm/train/tf2/glm_norm/vars.py b/batchglm/train/tf2/glm_norm/vars.py new file mode 100644 index 00000000..b1200abc --- /dev/null +++ b/batchglm/train/tf2/glm_norm/vars.py @@ -0,0 +1,8 @@ +from .model import ProcessModel +from .external import ModelVarsGLM + + +class ModelVars(ProcessModel, ModelVarsGLM): + """ + Full class. + """ diff --git a/batchglm/train/tf2/ops.py b/batchglm/train/tf2/ops.py new file mode 100644 index 00000000..56fbf48b --- /dev/null +++ b/batchglm/train/tf2/ops.py @@ -0,0 +1,59 @@ +import tensorflow as tf +from typing import Union + + +def swap_dims(tensor, axis0, axis1, exec_transpose=True, return_perm=False, name="swap_dims"): + """ + Swaps two dimensions in a given tensor. 
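+ For example, applying it to a tensor of shape ``(A, B, C)`` with ``axis0=0`` and ``axis1=1``
+ yields a tensor of shape ``(B, A, C)``.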
+
+ :param tensor: The tensor whose axes should be swapped
+ :param axis0: The first axis which should be swapped with `axis1`
+ :param axis1: The second axis which should be swapped with `axis0`
+ :param exec_transpose: Should the transpose operation be applied?
+ :param return_perm: Should the permutation argument for `tf.transpose` be returned?
+ Automatically true if `exec_transpose` is False
+ :param name: The name scope of this op
+ :return: either retval, (retval, permutation) or permutation
+ """
+ with tf.name_scope(name):
+ rank = tf.range(tf.rank(tensor))
+ idx0 = rank[axis0]
+ idx1 = rank[axis1]
+ perm0 = tf.where(tf.equal(rank, idx0), tf.tile(tf.expand_dims(idx1, 0), [tf.size(rank)]), rank)
+ perm1 = tf.where(tf.equal(rank, idx1), tf.tile(tf.expand_dims(idx0, 0), [tf.size(rank)]), perm0)
+
+ if exec_transpose:
+ retval = tf.transpose(tensor, perm1)
+
+ if return_perm:
+ return retval, perm1
+ else:
+ return retval
+ else:
+ return perm1
+
+
+def stacked_lstsq(L, b, rcond=1e-10, name="stacked_lstsq"):
+ r"""
+ Solve `Lx = b`, via SVD least squares, cutting off small singular values
+
+ :param L: tensor of shape (..., M, K)
+ :param b: tensor of shape (..., M, N).
+ :param rcond: threshold for inverse
+ :param name: name scope of this op
+ :return: x of shape (..., K, N)
+ """
+ with tf.name_scope(name):
+ u, s, v = tf.linalg.svd(L, full_matrices=False)
+ # tf.Tensor has no `.max` method; use tf.reduce_max for the per-problem largest singular value.
+ s_max = tf.reduce_max(s, axis=-1, keepdims=True)
+ s_min = rcond * s_max
+
+ # Invert only singular values above the cut-off; reciprocal/conj live under tf.math in TF2.
+ inv_s = tf.where(s >= s_min, tf.math.reciprocal(s), tf.zeros_like(s))
+
+ x = tf.einsum(
+ '...MK,...MN->...KN',
+ v,
+ tf.einsum('...K,...MK,...MN->...KN', inv_s, u, b)
+ )
+
+ return tf.math.conj(x)
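A minimal smoke-test sketch for the two helpers added in ops.py (an illustration only: it assumes the module is importable as `batchglm.train.tf2.ops` in the installed package; the variable names and shapes are made up for the example):

    import numpy as np
    import tensorflow as tf

    from batchglm.train.tf2.ops import swap_dims, stacked_lstsq

    # swap_dims: exchange the first two axes of a rank-3 tensor.
    t = tf.zeros([2, 3, 4])
    print(swap_dims(t, 0, 1).shape)   # expected: (3, 2, 4)

    # stacked_lstsq: solve a batch of 5 least-squares problems L x = b,
    # with L of shape (5, 10, 3) and b of shape (5, 10, 2).
    L = tf.constant(np.random.rand(5, 10, 3))
    b = tf.constant(np.random.rand(5, 10, 2))
    x = stacked_lstsq(L, b)
    print(x.shape)                    # expected: (5, 3, 2)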