Skip to content

Commit

Permalink
move out circ stats to torch_helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed May 6, 2024
1 parent 26382be commit c0156c4
Showing 1 changed file with 14 additions and 188 deletions.
202 changes: 14 additions & 188 deletions bnpm/circular.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Union, Optional
import functools

import numpy as np
import torch
import matplotlib.pyplot as plt

import functools

from . import misc

Expand All @@ -12,7 +14,7 @@
"""


def _circ_operator(a, period=2*np.pi):
def _circ_operator(a, period: float = 2*np.pi):
"""
Helper function for circular arithmetic. \n
This function is used to ensure that the output of a circular operation is
Expand All @@ -32,7 +34,7 @@ def _circ_operator(a, period=2*np.pi):
return ((a + period/2) % period) - period/2


def circ_subtract(a, b, period=2*np.pi):
def circ_subtract(a, b, period: float = 2*np.pi):
"""
Modular subtraction of ``b`` from ``a``. \n
This is equivalent to the distance from ``a`` to ``b`` on a circle. \n
Expand All @@ -53,7 +55,7 @@ def circ_subtract(a, b, period=2*np.pi):
return _circ_operator(a - b, period=period)


def circ_add(a, b, period=2*np.pi):
def circ_add(a, b, period: float = 2*np.pi):
"""
Modular addition of ``a`` and ``b``. \n
This is equivalent to the sum of ``a`` and ``b`` on a circle. \n
Expand All @@ -74,7 +76,14 @@ def circ_add(a, b, period=2*np.pi):
return _circ_operator(a + b, period=period)

@misc.wrapper_flexible_args(['dim', 'axis'])
def circ_diff(arr, period=2*np.pi, axis=-1, prepend=None, append=None, n=1):
def circ_diff(
arr: Union[np.ndarray, torch.Tensor],
period: Union[float, np.ndarray, torch.Tensor] = 2*np.pi,
axis: int = -1,
prepend: int = None,
append: int = None,
n: int = 1
):
"""
Modular derivative (like np.diff) of an array. \n
Calculates the circular difference between adjacent elements of ``arr``. \n
Expand Down Expand Up @@ -168,186 +177,3 @@ def moduloCounter_to_linearCounter(trace, modulus, modulus_value=None, diff_thre
plt.plot(trace_times)

return trace_times


def _circfuncs_common(samples, high, low):
"""
Helper function for circular statistics. \n
This function is used to ensure that the output of a circular operation is
always within the range [low, high). \n
RH 2024
Args:
samples (np.ndarray or torch.Tensor):
Input values
high (float or np.ndarray or torch.Tensor):
High value
low (float or np.ndarray or torch.Tensor):
Low value
Returns:
output (np.ndarray or torch.Tensor):
Output values
"""
if isinstance(samples, torch.Tensor):
nan, pi = (torch.tensor(v, dtype=samples.dtype, device=samples.device) for v in [torch.nan, np.pi])
sin, cos = torch.sin, torch.cos
else:
nan, pi = (np.array(v, dtype=samples.dtype) for v in [np.nan, np.pi])
sin, cos = np.sin, np.cos,

if samples.size == 0:
return nan, nan

## sin and cos
samples = (samples - low) * 2.0 * pi / (high - low)
sin_samp = sin(samples)
cos_samp = cos(samples)

return sin_samp, cos_samp


def circmean(samples, high=2*np.pi, low=0, axis=None, nan_policy='propagate'):
"""
Circular mean of samples. Equivalent results to scipy.stats.circmean. \n
RH 2024
Args:
samples (np.ndarray or torch.Tensor):
Input values
high (float or np.ndarray or torch.Tensor):
High value
low (float or np.ndarray or torch.Tensor):
Low value
axis (int):
Axis along which to take the mean
nan_policy (str):
Policy for handling NaN values: \n
* 'propagate' - Propagate NaN values.
* 'omit' - Ignore NaN values.
* 'raise' - Raise an error if NaN values are present.
Returns:
mean (np.ndarray or torch.Tensor):
Mean values
"""

if nan_policy != 'propagate':
raise NotImplementedError("Only 'propagate' nan_policy is supported")

if isinstance(samples, torch.Tensor):
pi = torch.tensor(np.pi, dtype=samples.dtype, device=samples.device)
arctan2, sum, nansum = torch.atan2, torch.sum, torch.nansum
else:
pi = np.array(np.pi, dtype=samples.dtype)
arctan2, sum, nansum = np.arctan2, np.sum, np.nansum

if nan_policy == 'omit':
fn_sum = nansum
elif nan_policy == 'propagate':
fn_sum = sum
elif nan_policy == 'raise':
if torch.any(torch.isnan(samples)):
raise ValueError("NaN values are present in the input")
fn_sum = sum
else:
raise ValueError("Invalid nan_policy")

sin_samp, cos_samp = _circfuncs_common(samples, high, low)
sin_sum = fn_sum(sin_samp, axis)
cos_sum = fn_sum(cos_samp, axis)
res = arctan2(sin_sum, cos_sum)

res[res < 0] += 2 * pi
res = res[()]

return res*(high - low)/2.0/pi + low


def cirvar(samples, high=2*np.pi, low=0, axis=None, nan_policy='propagate'):
"""
Circular variance of samples. Equivalent results to scipy.stats.circvar. \n
RH 2024
Args:
samples (np.ndarray or torch.Tensor):
Input values
high (float or np.ndarray or torch.Tensor):
High value
low (float or np.ndarray or torch.Tensor):
Low value
axis (int):
Axis along which to take the variance
nan_policy (str):
Policy for handling NaN values. Can only be 'propagate' for now.
Returns:
variance (np.ndarray or torch.Tensor):
Variance values
"""

if nan_policy != 'propagate':
raise NotImplementedError("Only 'propagate' nan_policy is supported")

if isinstance(samples, torch.Tensor):
sqrt = torch.sqrt
else:
sqrt = np.sqrt

sin_samp, cos_samp = _circfuncs_common(samples, high, low)
sin_mean = sin_samp.mean(axis)
cos_mean = cos_samp.mean(axis)

R = sqrt(sin_mean**2 + cos_mean**2)

return 1 - R


def circstd(samples, high=2*np.pi, low=0, axis=None, nan_policy='propagate', normalize=False):
"""
Circular standard deviation of samples. Equivalent results to
scipy.stats.circstd. \n
RH 2024
Args:
samples (np.ndarray or torch.Tensor):
Input values
high (float or np.ndarray or torch.Tensor):
High value
low (float or np.ndarray or torch.Tensor):
Low value
axis (int):
Axis along which to take the standard deviation
nan_policy (str):
Policy for handling NaN values. Can only be 'propagate' for now.
normalize (bool):
Whether to normalize the standard deviation. If True, the result is
equal to ``sqrt(-2*log(R))`` and does not depend on the variable
units. If False (default), the returned value is scaled by
``((high-low)/(2*pi))``.
Returns:
std (np.ndarray or torch.Tensor):
Standard deviation values
"""

if nan_policy != 'propagate':
raise NotImplementedError("Only 'propagate' nan_policy is supported")

if isinstance(samples, torch.Tensor):
pi = torch.tensor(np.pi, dtype=samples.dtype, device=samples.device)
sqrt, log = torch.sqrt, torch.log
else:
pi = np.array(np.pi, dtype=samples.dtype)
sqrt, log = np.sqrt, np.log

sin_samp, cos_samp = _circfuncs_common(samples, high, low)
sin_mean = sin_samp.mean(axis)
cos_mean = cos_samp.mean(axis)
R = sqrt(sin_mean**2 + cos_mean**2)

res = sqrt(-2*log(R))
if not normalize:
res *= (high-low)/(2.*pi)
return res

0 comments on commit c0156c4

Please sign in to comment.