From b2f9d72fae414aef3e8be844444636bbcca407e8 Mon Sep 17 00:00:00 2001 From: RichieHakim Date: Wed, 27 Mar 2024 19:49:43 -0400 Subject: [PATCH] Add torch support to geometric_mean function --- bnpm/stats.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/bnpm/stats.py b/bnpm/stats.py index aca5053..8de7892 100644 --- a/bnpm/stats.py +++ b/bnpm/stats.py @@ -1,5 +1,6 @@ import numpy as np import scipy.stats +import torch def ttest_paired_ratio(a, b): """ @@ -36,7 +37,7 @@ def ttest_paired_ratio(a, b): return p_val -def geometric_mean(a): +def geometric_mean(a, axis=0, nan_policy="omit"): """ Computes the geometric mean of an array of data. This is useful for computing the geometric mean of ratios. @@ -47,7 +48,23 @@ def geometric_mean(a): a (np.ndarray or torch.Tensor): Array of data. """ - return np.exp(np.mean(np.log(a))) + if isinstance(a, (np.ndarray, list)): + mean, nanmean, isnan, exp, log = np.mean, np.nanmean, np.isnan, np.exp, np.log + elif isinstance(a, torch.Tensor): + mean, nanmean, isnan, exp, log = torch.mean, torch.nanmean, torch.isnan, torch.exp, torch.log + else: + raise ValueError("Data must be a numpy array or a torch tensor.") + + if nan_policy == "omit": + mean = nanmean + elif nan_policy == "propagate": + mean = mean + elif nan_policy == "raise": + mean = mean + if isnan(a).any(): + raise ValueError("Data contains nan values.") + + return exp(mean(log(a), axis=axis)) def sparsity(a, axis=0):