Skip to content

Commit

Permalink
Add torch_coherence function to spectral.py
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed Apr 16, 2024
1 parent 5139f0a commit 229b460
Show file tree
Hide file tree
Showing 2 changed files with 336 additions and 1 deletion.
167 changes: 166 additions & 1 deletion bnpm/spectral.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Union, Tuple, List, Dict, Any, Optional
import functools
import math

import scipy.signal
import scipy.stats
Expand Down Expand Up @@ -513,4 +514,168 @@ def filtfilt_simple_fft(
idx=slice(0, x.shape[-1]),
)
out = out.real if use_real else out
return out
return out


def torch_coherence(
x: torch.Tensor,
y: torch.Tensor,
fs: float = 1.0,
window: str = 'hann',
nperseg: Optional[int] = None,
noverlap: Optional[int] = None,
nfft: Optional[int] = None,
detrend: str = 'constant',
axis: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Computes the magnitude-squared coherence between two signals using a PyTorch
implementation. This function gives identical results to the
scipy.signal.coherence. \n
The primary difference in implementation between this and scipy's coherence
is that this uses an accumulation method for Welch's method, while scipy
just makes a large array with all the overlapping windows. Therefore, this
method uses less memory and is faster for large windows but is slower for
small windows and there is a very small amount of numerical error due to the
accumulation. \n
RH 2024
Args:
x (torch.Tensor):
First input signal.
y (torch.Tensor):
Second input signal.
fs (float):
Sampling frequency of the input signal. (Default is 1.0)
window (str):
Type of window to apply. Supported window types are the same as
`scipy.signal.get_window`. (Default is 'hann')
nperseg (Optional[int]):
Length of each segment. (Default is ``None``, which uses ``len(x) //
8``)
noverlap (Optional[int]):
Number of points to overlap between segments. (Default is ``None``,
which uses ``nperseg // 2``)
nfft (Optional[int]):
Number of points in the FFT used for each segment. (Default is
``None``, which sets it equal to `nperseg`)
detrend (str):
Specifies how to detrend each segment. Supported values are
'constant' or 'linear'. (Default is 'constant')
axis (int):
Axis along which the coherence is calculated. (Default is -1)
Returns:
Tuple[torch.Tensor, torch.Tensor]:
- freqs (torch.Tensor): Frequencies for which the coherence is computed.
- coherence (torch.Tensor): Magnitude-squared coherence values.
Example:
.. highlight:: python
.. code-block:: python
x = torch.randn(1024)
y = torch.randn(1024)
freqs, coherence = torch_coherence(x, y, fs=256)
"""
## Convert axis to positive
axis = axis % len(x.shape)

## Check dimensions
### They should either be the same or one of them should be 1
if not (x.shape == y.shape):
assert all([x.shape[ii] in [1, y.shape[ii]] for ii in range(len(x.shape))]), f"x and y should have the same shape or one of them should have shape 1 at each dimension. Found x.shape={x.shape} and y.shape={y.shape}"

if nperseg is None:
nperseg = len(x) // 8

if noverlap is None:
noverlap = nperseg // 2

if nfft is None:
nfft = nperseg

if window is not None:
window = scipy.signal.get_window(window, nperseg)
window = torch.tensor(window, dtype=x.dtype, device=x.device)

## Detrend the signals
def detrend_constant(y, axis):
y = y - torch.mean(y, axis=axis, keepdim=True)
return y
def detrend_linear(y, axis):
"""
Uses least squares approach to remove linear trend.
"""
## Move axis to end
y_dims_to = [ii for ii in range(len(y.shape)) if ii != axis] + [axis]
y = y.permute(*y_dims_to)
n = y.shape[-1]
## Prepare the design matrix
X = torch.ones(n, 2, dtype=y.dtype, device=y.device)
X[:, 1] = torch.arange(n, dtype=y.dtype, device=y.device)
## Compute the coefficients
beta = torch.linalg.lstsq(X, y)[0]
## Remove the trend
y = y - X @ beta
## Move axis back to original position (argsort y_dims_to)
y_dims_from = [y_dims_to.index(ii) for ii in range(len(y.shape))]
y = y.permute(*y_dims_from)
return y

if detrend == 'constant':
fn_detrend = detrend_constant
elif detrend == 'linear':
fn_detrend = detrend_linear
else:
raise ValueError(f"detrend must be 'constant' or 'linear'. Found {detrend}")

## Initialize the coherence arrays
### Get broadcasted dimensions: max(x, y) at each dimension, and nfft at axis
x_shape = list(x.shape)
y_shape = list(y.shape)
out_shape = [max(x_shape[i], y_shape[i]) for i in range(len(x_shape))]
out_shape[axis] = nfft // 2 + 1 ## rfft returns only non-negative frequencies (0 to fs/2 inclusive )

## Initialize sums for Welch's method
### Prepare complex dtype
dtype_complex = x.dtype.to_complex()
f_cross_sum = torch.zeros(out_shape, dtype=dtype_complex, device=x.device)
psd1_sum = torch.zeros(out_shape, dtype=dtype_complex, device=x.device)
psd2_sum = torch.zeros(out_shape, dtype=dtype_complex, device=x.device)

## Perform Welch's averaging of FFT segments
num_segments = (x.shape[axis] - nperseg) // (nperseg - noverlap) + 1
### Pad window with [None] dims to match x and y
window = window[(None,) * axis + (slice(None),) + (None,) * (len(x.shape) - axis - 1)]
for ii in range(num_segments):
start = ii * (nperseg - noverlap)
end = start + nperseg
fn_get_segment = lambda x, axis, start, end: torch.fft.rfft(fn_detrend(torch_helpers.slice_along_dim(x, axis, slice(start, end)), axis=axis) * window, n=nfft, dim=axis)
segment1 = fn_get_segment(x, axis, start, end)
segment2 = fn_get_segment(y, axis, start, end)
f_cross_sum += torch.conj(segment1) * segment2
psd1_sum += torch.conj(segment1) * segment1
psd2_sum += torch.conj(segment2) * segment2

## Averaging the sums
f_cross = f_cross_sum / num_segments
psd1 = psd1_sum.real / num_segments
psd2 = psd2_sum.real / num_segments

## Compute coherence
coherence = torch.abs(f_cross) ** 2 / (psd1 * psd2)

## Generate frequency axis
freqs = np.fft.rfftfreq(nfft, d=1 / fs)

## Take the positive part of the frequency spectrum
### NOTE: This is not necessary as the coherence is symmetric (always odd and real)
# pos_mask = freqs >= 0
# ### slice along axis
# freqs = freqs[pos_mask]
# coherence = torch_helpers.slice_along_dim(coherence, axis=axis, idx=pos_mask)

return freqs, coherence
170 changes: 170 additions & 0 deletions bnpm/tests/test_coherence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import pytest
import torch
import numpy as np
import scipy.signal
from hypothesis import given, strategies as st

from ..spectral import torch_coherence

# Test with basic sinusoidal inputs
def test_basic_functionality():
np.random.seed(0) # For reproducibility
t = np.linspace(0, 1, 1000, endpoint=False)
x = np.sin(2 * np.pi * 5 * t) # 5 Hz sinusoid
y = np.sin(2 * np.pi * 5 * t + np.pi/4) # 5 Hz sinusoid, phase shifted

x_torch = torch.tensor(x)
y_torch = torch.tensor(y)

# Use default parameters
fs = 1.0
nperseg = 256

freqs_pytorch, coherence_pytorch = torch_coherence(x_torch, y_torch, fs=fs, nperseg=nperseg)
freqs_scipy, coherence_scipy = scipy.signal.coherence(x, y, fs=fs, nperseg=nperseg)

# Check if the results are close enough
assert np.allclose(coherence_pytorch.numpy(), coherence_scipy, atol=1e-2), "Coherence values do not match closely enough."

# Test varying sampling frequencies
@pytest.mark.parametrize("fs", [0.5, 1.0, 2.0, 10.0])
def test_varying_sampling_frequency(fs):
np.random.seed(0) # For reproducibility
t = np.linspace(0, 1, 1000, endpoint=False)
x = np.sin(2 * np.pi * 5 * t) # 5 Hz sinusoid
y = np.sin(2 * np.pi * 5 * t + np.pi/4) # 5 Hz sinusoid, phase shifted

x_torch = torch.tensor(x)
y_torch = torch.tensor(y)

nperseg = 256

freqs_pytorch, coherence_pytorch = torch_coherence(x_torch, y_torch, fs=fs, nperseg=nperseg)
freqs_scipy, coherence_scipy = scipy.signal.coherence(x, y, fs=fs, nperseg=nperseg)

# Check if the results are close enough
assert np.allclose(coherence_pytorch.numpy(), coherence_scipy, atol=1e-2), f"Coherence values do not match closely enough for fs={fs}."

# Test different window types
@pytest.mark.parametrize("window", ['hann', 'hamming', 'blackman'])
def test_different_window_types(window):
np.random.seed(0) # For reproducibility
t = np.linspace(0, 1, 1000, endpoint=False)
x = np.sin(2 * np.pi * 5 * t) # 5 Hz sinusoid
y = np.sin(2 * np.pi * 5 * t + np.pi/4) # 5 Hz sinusoid, phase shifted

x_torch = torch.tensor(x)
y_torch = torch.tensor(y)

fs = 1.0
nperseg = 256

freqs_pytorch, coherence_pytorch = torch_coherence(x_torch, y_torch, fs=fs, window=window, nperseg=nperseg)
freqs_scipy, coherence_scipy = scipy.signal.coherence(x, y, fs=fs, window=window, nperseg=nperseg)

# Check if the results are close enough
assert np.allclose(coherence_pytorch.numpy(), coherence_scipy, atol=1e-2), f"Coherence values do not match closely enough with window type={window}."

# Test varying segment lengths
@pytest.mark.parametrize("nperseg", [128, 256, 512])
def test_varying_segment_lengths(nperseg):
np.random.seed(0) # For reproducibility
t = np.linspace(0, 1, 1000, endpoint=False)
x = np.sin(2 * np.pi * 5 * t) # 5 Hz sinusoid
y = np.sin(2 * np.pi * 5 * t + np.pi/4) # 5 Hz sinusoid, phase shifted

x_torch = torch.tensor(x)
y_torch = torch.tensor(y)

fs = 1.0
window = 'hann'

freqs_pytorch, coherence_pytorch = torch_coherence(x_torch, y_torch, fs=fs, window=window, nperseg=nperseg)
freqs_scipy, coherence_scipy = scipy.signal.coherence(x, y, fs=fs, window=window, nperseg=nperseg)

# Check if the results are close enough
assert np.allclose(coherence_pytorch.numpy(), coherence_scipy, atol=1e-2), f"Coherence values do not match closely enough for segment length={nperseg}."

# Test varying overlap sizes
@pytest.mark.parametrize("noverlap", [0, 1, 2, 64, 128, 192])
def test_overlap_sizes(noverlap):
np.random.seed(0) # For reproducibility
t = np.linspace(0, 1, 1000, endpoint=False)
x = np.sin(2 * np.pi * 5 * t) # 5 Hz sinusoid
y = np.sin(2 * np.pi * 5 * t + np.pi/4) # 5 Hz sinusoid, phase shifted

x_torch = torch.tensor(x)
y_torch = torch.tensor(y)

fs = 1.0
nperseg = 256 # Fixed segment length for consistency in comparison

freqs_pytorch, coherence_pytorch = torch_coherence(x_torch, y_torch, fs=fs, nperseg=nperseg, noverlap=noverlap)
freqs_scipy, coherence_scipy = scipy.signal.coherence(x, y, fs=fs, nperseg=nperseg, noverlap=noverlap)

# Check if the results are close enough
assert np.allclose(coherence_pytorch.numpy(), coherence_scipy, atol=1e-2), f"Coherence values do not match closely enough for overlap size={noverlap}."

# Test varying FFT lengths
@pytest.mark.parametrize("nfft", [256, 512, 1024])
def test_fft_lengths(nfft):
np.random.seed(0) # For reproducibility
t = np.linspace(0, 1, 1000, endpoint=False)
x = np.sin(2 * np.pi * 5 * t) # 5 Hz sinusoid
y = np.sin(2 * np.pi * 5 * t + np.pi/4) # 5 Hz sinusoid, phase shifted

x_torch = torch.tensor(x)
y_torch = torch.tensor(y)

fs = 1.0
nperseg = 256 # Maintain constant segment size to isolate the effect of nfft

freqs_pytorch, coherence_pytorch = torch_coherence(x_torch, y_torch, fs=fs, nperseg=nperseg, nfft=nfft)
freqs_scipy, coherence_scipy = scipy.signal.coherence(x, y, fs=fs, nperseg=nperseg, nfft=nfft)

# Check if the results are close enough
assert np.allclose(coherence_pytorch.numpy(), coherence_scipy, atol=1e-2), f"Coherence values do not match closely enough for FFT length={nfft}."

# Test detrending methods
@pytest.mark.parametrize("detrend", ['constant', 'linear'])
def test_detrending_methods(detrend):
np.random.seed(0) # For reproducibility
t = np.linspace(0, 1, 1000, endpoint=False)
x = np.sin(2 * np.pi * 5 * t) + np.linspace(0, 1, 1000) # Sinusoid with linear trend
y = np.sin(2 * np.pi * 5 * t + np.pi/4) + np.linspace(1, 0, 1000) # Sinusoid with inverse linear trend

x_torch = torch.tensor(x)
y_torch = torch.tensor(y)

fs = 1.0
nperseg = 256

freqs_pytorch, coherence_pytorch = torch_coherence(x_torch, y_torch, fs=fs, nperseg=nperseg, detrend=detrend)
freqs_scipy, coherence_scipy = scipy.signal.coherence(x, y, fs=fs, nperseg=nperseg, detrend=detrend)

# Check if the results are close enough
assert np.allclose(coherence_pytorch.numpy(), coherence_scipy, atol=1e-2), f"Coherence values do not match closely enough for detrend method={detrend}."

# Test multi-dimensional input
def test_multi_dimensional_input():
np.random.seed(0) # For reproducibility
t = np.linspace(0, 1, 1000, endpoint=False)
x = np.sin(2 * np.pi * 5 * t) # 5 Hz sinusoid
y = np.sin(2 * np.pi * 5 * t + np.pi/4) # 5 Hz sinusoid, phase shifted

# Extend to 2D by repeating the array
x = np.tile(x, (10, 1))
y = np.tile(y, (10, 1))

x_torch = torch.tensor(x)
y_torch = torch.tensor(y)

fs = 1.0
nperseg = 256

freqs_pytorch, coherence_pytorch = torch_coherence(x_torch, y_torch, fs=fs, nperseg=nperseg)
freqs_scipy, coherence_scipy = scipy.signal.coherence(x, y, fs=fs, nperseg=nperseg, axis=1)

# Check if the results are close enough, comparing each ensemble member's coherence
for i in range(10):
assert np.allclose(coherence_pytorch[i].numpy(), coherence_scipy[i], atol=1e-2), "Coherence values do not match for multi-dimensional input."

0 comments on commit 229b460

Please sign in to comment.