diff --git a/bnpm/circular.py b/bnpm/circular.py index 6595f62..d7e12b2 100644 --- a/bnpm/circular.py +++ b/bnpm/circular.py @@ -1,8 +1,10 @@ +from typing import Union, Optional +import functools + import numpy as np import torch import matplotlib.pyplot as plt -import functools from . import misc @@ -12,7 +14,7 @@ """ -def _circ_operator(a, period=2*np.pi): +def _circ_operator(a, period: float = 2*np.pi): """ Helper function for circular arithmetic. \n This function is used to ensure that the output of a circular operation is @@ -32,7 +34,7 @@ def _circ_operator(a, period=2*np.pi): return ((a + period/2) % period) - period/2 -def circ_subtract(a, b, period=2*np.pi): +def circ_subtract(a, b, period: float = 2*np.pi): """ Modular subtraction of ``b`` from ``a``. \n This is equivalent to the distance from ``a`` to ``b`` on a circle. \n @@ -53,7 +55,7 @@ def circ_subtract(a, b, period=2*np.pi): return _circ_operator(a - b, period=period) -def circ_add(a, b, period=2*np.pi): +def circ_add(a, b, period: float = 2*np.pi): """ Modular addition of ``a`` and ``b``. \n This is equivalent to the sum of ``a`` and ``b`` on a circle. \n @@ -74,7 +76,14 @@ def circ_add(a, b, period=2*np.pi): return _circ_operator(a + b, period=period) @misc.wrapper_flexible_args(['dim', 'axis']) -def circ_diff(arr, period=2*np.pi, axis=-1, prepend=None, append=None, n=1): +def circ_diff( + arr: Union[np.ndarray, torch.Tensor], + period: Union[float, np.ndarray, torch.Tensor] = 2*np.pi, + axis: int = -1, + prepend: int = None, + append: int = None, + n: int = 1 +): """ Modular derivative (like np.diff) of an array. \n Calculates the circular difference between adjacent elements of ``arr``. \n @@ -168,186 +177,3 @@ def moduloCounter_to_linearCounter(trace, modulus, modulus_value=None, diff_thre plt.plot(trace_times) return trace_times - - -def _circfuncs_common(samples, high, low): - """ - Helper function for circular statistics. \n - This function is used to ensure that the output of a circular operation is - always within the range [low, high). \n - RH 2024 - - Args: - samples (np.ndarray or torch.Tensor): - Input values - high (float or np.ndarray or torch.Tensor): - High value - low (float or np.ndarray or torch.Tensor): - Low value - - Returns: - output (np.ndarray or torch.Tensor): - Output values - """ - if isinstance(samples, torch.Tensor): - nan, pi = (torch.tensor(v, dtype=samples.dtype, device=samples.device) for v in [torch.nan, np.pi]) - sin, cos = torch.sin, torch.cos - else: - nan, pi = (np.array(v, dtype=samples.dtype) for v in [np.nan, np.pi]) - sin, cos = np.sin, np.cos, - - if samples.size == 0: - return nan, nan - - ## sin and cos - samples = (samples - low) * 2.0 * pi / (high - low) - sin_samp = sin(samples) - cos_samp = cos(samples) - - return sin_samp, cos_samp - - -def circmean(samples, high=2*np.pi, low=0, axis=None, nan_policy='propagate'): - """ - Circular mean of samples. Equivalent results to scipy.stats.circmean. \n - RH 2024 - - Args: - samples (np.ndarray or torch.Tensor): - Input values - high (float or np.ndarray or torch.Tensor): - High value - low (float or np.ndarray or torch.Tensor): - Low value - axis (int): - Axis along which to take the mean - nan_policy (str): - Policy for handling NaN values: \n - * 'propagate' - Propagate NaN values. - * 'omit' - Ignore NaN values. - * 'raise' - Raise an error if NaN values are present. - - Returns: - mean (np.ndarray or torch.Tensor): - Mean values - """ - - if nan_policy != 'propagate': - raise NotImplementedError("Only 'propagate' nan_policy is supported") - - if isinstance(samples, torch.Tensor): - pi = torch.tensor(np.pi, dtype=samples.dtype, device=samples.device) - arctan2, sum, nansum = torch.atan2, torch.sum, torch.nansum - else: - pi = np.array(np.pi, dtype=samples.dtype) - arctan2, sum, nansum = np.arctan2, np.sum, np.nansum - - if nan_policy == 'omit': - fn_sum = nansum - elif nan_policy == 'propagate': - fn_sum = sum - elif nan_policy == 'raise': - if torch.any(torch.isnan(samples)): - raise ValueError("NaN values are present in the input") - fn_sum = sum - else: - raise ValueError("Invalid nan_policy") - - sin_samp, cos_samp = _circfuncs_common(samples, high, low) - sin_sum = fn_sum(sin_samp, axis) - cos_sum = fn_sum(cos_samp, axis) - res = arctan2(sin_sum, cos_sum) - - res[res < 0] += 2 * pi - res = res[()] - - return res*(high - low)/2.0/pi + low - - -def cirvar(samples, high=2*np.pi, low=0, axis=None, nan_policy='propagate'): - """ - Circular variance of samples. Equivalent results to scipy.stats.circvar. \n - RH 2024 - - Args: - samples (np.ndarray or torch.Tensor): - Input values - high (float or np.ndarray or torch.Tensor): - High value - low (float or np.ndarray or torch.Tensor): - Low value - axis (int): - Axis along which to take the variance - nan_policy (str): - Policy for handling NaN values. Can only be 'propagate' for now. - - Returns: - variance (np.ndarray or torch.Tensor): - Variance values - """ - - if nan_policy != 'propagate': - raise NotImplementedError("Only 'propagate' nan_policy is supported") - - if isinstance(samples, torch.Tensor): - sqrt = torch.sqrt - else: - sqrt = np.sqrt - - sin_samp, cos_samp = _circfuncs_common(samples, high, low) - sin_mean = sin_samp.mean(axis) - cos_mean = cos_samp.mean(axis) - - R = sqrt(sin_mean**2 + cos_mean**2) - - return 1 - R - - -def circstd(samples, high=2*np.pi, low=0, axis=None, nan_policy='propagate', normalize=False): - """ - Circular standard deviation of samples. Equivalent results to - scipy.stats.circstd. \n - RH 2024 - - Args: - samples (np.ndarray or torch.Tensor): - Input values - high (float or np.ndarray or torch.Tensor): - High value - low (float or np.ndarray or torch.Tensor): - Low value - axis (int): - Axis along which to take the standard deviation - nan_policy (str): - Policy for handling NaN values. Can only be 'propagate' for now. - normalize (bool): - Whether to normalize the standard deviation. If True, the result is - equal to ``sqrt(-2*log(R))`` and does not depend on the variable - units. If False (default), the returned value is scaled by - ``((high-low)/(2*pi))``. - - - Returns: - std (np.ndarray or torch.Tensor): - Standard deviation values - """ - - if nan_policy != 'propagate': - raise NotImplementedError("Only 'propagate' nan_policy is supported") - - if isinstance(samples, torch.Tensor): - pi = torch.tensor(np.pi, dtype=samples.dtype, device=samples.device) - sqrt, log = torch.sqrt, torch.log - else: - pi = np.array(np.pi, dtype=samples.dtype) - sqrt, log = np.sqrt, np.log - - sin_samp, cos_samp = _circfuncs_common(samples, high, low) - sin_mean = sin_samp.mean(axis) - cos_mean = cos_samp.mean(axis) - R = sqrt(sin_mean**2 + cos_mean**2) - - res = sqrt(-2*log(R)) - if not normalize: - res *= (high-low)/(2.*pi) - return res \ No newline at end of file