Skip to content

Commit

Permalink
Merge pull request #318 from neurodsp-tools/zeropad
Browse files Browse the repository at this point in the history
[ENH] Zero padding + Welch's PSD
  • Loading branch information
TomDonoghue authored May 14, 2024
2 parents aef60f3 + b8c861d commit b52f552
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 27 deletions.
18 changes: 10 additions & 8 deletions neurodsp/spectral/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,27 +54,29 @@ def check_mt_settings(n_samples, fs, bandwidth, n_tapers):
fs : float
Sampling rate, in Hz.
bandwidth : float or None
Bandwidth of the multitaper window, in Hz. If None, will use
8 * fs / n_samples.
Bandwidth of the multitaper window, in Hz.
If None, will use 8 * fs / n_samples.
n_tapers : int or None
Number of tapers to use. If None, will use bandwidth * n_samples / fs
Number of tapers to use.
If None, will use bandwidth * n_samples / fs.
Returns
-------
nw : float
Standardized half bandwidth (used to compute DPSS)
Standardized half bandwidth (used to compute DPSS).
n_tapers : int
Number of tapers.
"""
"""

# set bandwidth
if bandwidth is None:
bandwidth = 8 * fs / n_samples # MNE default
bandwidth = 8 * fs / n_samples # MNE default

# check bandwidth - break if alpha < 1
alpha = n_samples * bandwidth / (fs * 2)
if alpha < 1:
raise ValueError("Bandwidth too narrow for signal length and sampling rate. Try increasing bandwidth. n_samples * bandwidth / (fs * 2) must be >1")
raise ValueError("Bandwidth too narrow for signal length and sampling rate. "
"Try increasing bandwidth. n_samples * bandwidth / (fs * 2) must be >1.")

# compute nw
nw = bandwidth * n_samples / (fs * 2)
Expand All @@ -83,4 +85,4 @@ def check_mt_settings(n_samples, fs, bandwidth, n_tapers):
if n_tapers is None:
n_tapers = int(2 * nw)

return nw, n_tapers
return nw, n_tapers
54 changes: 36 additions & 18 deletions neurodsp/spectral/power.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@

import numpy as np
from scipy.signal import spectrogram, medfilt
from scipy.fft import next_fast_len

from neurodsp.utils.core import get_avg_func
from neurodsp.utils.data import create_freqs
from neurodsp.utils.decorators import multidim
from neurodsp.utils.checks import check_param_options
from neurodsp.utils.outliers import discard_outliers
from neurodsp.timefrequency.wavelets import compute_wavelet_transform
from neurodsp.spectral.utils import trim_spectrum
from neurodsp.spectral.utils import trim_spectrum, window_pad
from neurodsp.spectral.checks import check_spg_settings, check_mt_settings

###################################################################################################
Expand Down Expand Up @@ -70,7 +71,8 @@ def compute_spectrum(sig, fs, method='welch', **kwargs):


SPECTRUM_INPUTS = {
'welch' : ['avg_type', 'window', 'nperseg', 'noverlap', 'f_range', 'outlier_percent'],
'welch' : ['avg_type', 'window', 'nperseg', 'noverlap', 'nfft', \
'fast_len', 'f_range', 'outlier_percent'],
'wavelet' : ['freqs', 'avg_type', 'n_cycles', 'scaling', 'norm'],
'medfilt' : ['filt_len', 'f_range'],
}
Expand Down Expand Up @@ -136,8 +138,8 @@ def compute_spectrum_wavelet(sig, fs, freqs, avg_type='mean', **kwargs):


def compute_spectrum_welch(sig, fs, avg_type='mean', window='hann',
nperseg=None, noverlap=None,
f_range=None, outlier_percent=None):
nperseg=None, noverlap=None, nfft=None,
fast_len=False, f_range=None, outlier_percent=None):
"""Compute the power spectral density using Welch's method.
Parameters
Expand All @@ -161,6 +163,12 @@ def compute_spectrum_welch(sig, fs, avg_type='mean', window='hann',
noverlap : int, optional
Number of points to overlap between segments.
If None, noverlap = nperseg // 8.
nfft : int, optional
Number of samples per window. Requires nfft > nperseg.
Windows are zero-padded by the difference, nfft - nperseg.
fast_len : bool, optional, default: False
Moves nperseg to the fastest length to reduce computation.
See scipy.fft.next_fast_len for details.
f_range : list of [float, float], optional
Frequency range to sub-select from the power spectrum.
outlier_percent : float, optional
Expand Down Expand Up @@ -196,6 +204,18 @@ def compute_spectrum_welch(sig, fs, avg_type='mean', window='hann',

# Calculate the short time Fourier transform with signal.spectrogram
nperseg, noverlap = check_spg_settings(fs, window, nperseg, noverlap)

# Pad signal if requested
if nfft is not None and nfft < nperseg:
raise ValueError('nfft must be greater than nperseg.')
elif nfft is not None:
npad = nfft - nperseg
noverlap = nperseg // 8 if noverlap is None else noverlap
sig, nperseg, noverlap = window_pad(sig, nperseg, noverlap, npad, fast_len)
elif fast_len:
nperseg = next_fast_len(nperseg)

# Compute spectrogram
freqs, _, spg = spectrogram(sig, fs, window, nperseg, noverlap)

# Throw out outliers if indicated
Expand Down Expand Up @@ -272,14 +292,13 @@ def compute_spectrum_multitaper(sig, fs, bandwidth=None, n_tapers=None,
fs : float
Sampling rate, in Hz.
bandwidth : float, optional
Frequency bandwidth of multi-taper window function. Default is
8 * fs / n_samples.
Frequency bandwidth of multi-taper window function.
If not provided, defaults to 8 * fs / n_samples.
n_tapers : int, optional
Number of slepian windows used to compute the spectrum. Default is
bandwidth * n_samples / fs.
low_bias : bool, optional
If True, only use tapers with concentration ratio > 0.9. Default is
True.
Number of slepian windows used to compute the spectrum.
If not provided, defaults to bandwidth * n_samples / fs.
low_bias : bool, optional, default: True
If True, only use tapers with concentration ratio > 0.9.
eigenvalue_weighting : bool, optional
If True, weight spectral estimates by the concentration ratio of
their respective tapers before combining. Default is True.
Expand All @@ -293,8 +312,7 @@ def compute_spectrum_multitaper(sig, fs, bandwidth=None, n_tapers=None,
Examples
--------
Compute the power spectrum of a simulated time series using the
multitaper method:
Compute the power spectrum of a simulated time series using the multitaper method:
>>> from neurodsp.sim import sim_combined
>>> sig = sim_combined(n_seconds=10, fs=500,
Expand All @@ -311,19 +329,19 @@ def compute_spectrum_multitaper(sig, fs, bandwidth=None, n_tapers=None,
nw, n_tapers = check_mt_settings(sig_len, fs, bandwidth, n_tapers)

# Create slepian sequences
slepian_sequences, ratios = dpss(sig_len, nw, n_tapers,
return_ratios=True)
slepian_sequences, ratios = dpss(sig_len, nw, n_tapers, return_ratios=True)

# Drop tapers with low concentration
if low_bias:
slepian_sequences = slepian_sequences[ratios > 0.9]
ratios = ratios[ratios > 0.9]
if len(slepian_sequences) == 0:
raise ValueError('No tapers with concentration ratio > 0.9. Could not compute spectrum with low_bias=True.')
raise ValueError("No tapers with concentration ratio > 0.9. "
"Could not compute spectrum with low_bias=True.")

# Compute fourier on signal weighted by each slepian sequence
# Compute Fourier transform on signal weighted by each slepian sequence
freqs = np.fft.rfftfreq(sig_len, 1. /fs)
spectra = np.abs(np.fft.rfft(slepian_sequences[:, np.newaxis]*sig))**2
spectra = np.abs(np.fft.rfft(slepian_sequences[:, np.newaxis] * sig)) ** 2

# combine estimates to compute final spectrum
if eigenvalue_weighting:
Expand Down
117 changes: 117 additions & 0 deletions neurodsp/spectral/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Utility function for neurodsp.spectral."""

import numpy as np
from scipy.fft import next_fast_len

# Alias a function that has moved, for backwards compatibility
from neurodsp.sim.utils import rotate_spectrum as rotate_powerlaw
Expand Down Expand Up @@ -127,3 +128,119 @@ def trim_spectrogram(freqs, times, spg, f_range=None, t_range=None):
times_ext = times

return freqs_ext, times_ext, spg_ext


def window_pad(sig, nperseg, noverlap, npad, fast_len,
nwindows=None, nsamples=None, pad_left=None, pad_right=None):
"""Pads windows (for Welch's PSD) with zeros.
Parameters
----------
sig : 1d or 2d array
Time series.
nperseg : int
Length of each segment, in number of samples, at the beginning and end of each window.
noverlap : int
Number of points to overlap between segments, applied prior to zero padding.
npad : int
Number of samples to zero pad windows per side.
fast_len : bool, optional
Moves nperseg to the fastest length to reduce computation.
Adjusts zero-padding to account for the new nperseg.
See scipy.fft.next_fast_len for details.
nwindows, nsamples, pad_left, pad_right : int, optional, default: None
Prevents redundant computation when sig is 2d.
Returns
-------
sig_windowed : 1d or 2d array
Windowed signal, with zeros padded at the around each window.
"""

if sig.ndim == 2:
# Determine the number of samples and padding once,
# to prevent redundant computation in the loop
nwindows = int(np.ceil(len(sig[0])/nperseg))
if nsamples is None or pad_left is None or pad_right is None:
nsamples, pad_left, pad_right = _find_pad_size(
nperseg, npad, fast_len
)

# Recursively call window_pad on each signal
for sind, csig in enumerate(sig):

_sig_win, _nperseg, _noverlap = window_pad(
# Required arguments
csig, nperseg, noverlap, npad, fast_len,
# Optional arguments to prevent redundant computation
nwindows, nsamples, pad_left, pad_right,
)

if sind == 0:
# Initialize windowed array
sig_windowed = np.zeros((len(sig), len(_sig_win)))

sig_windowed[sind] = _sig_win

# Update nperseg and noverlap
nperseg, noverlap = _nperseg, _noverlap

else:

# Compute the number of windows, samples, and padding.
# Do not recompute if called from the 2d case
if nwindows is None:
nwindows = int(np.ceil(len(sig) / nperseg))

if nsamples is None or pad_left is None or pad_right is None:
# Skipped if called from the 2d case
nsamples, pad_left, pad_right = _find_pad_size(
nperseg, npad, fast_len
)

# Window signal
sig_windowed = np.zeros((nwindows, nsamples))

for wind in range(nwindows):

# Signal indices
start = max(0, (wind * nperseg) - noverlap)
end = min(len(sig), start + nperseg)

if end - start != nperseg:
# Stop if a full window can't be created at end of signal
break

# Pad
sig_windowed[wind] = np.pad(sig[start:end], (pad_left, pad_right))

# Removed incomplete windows and flatten
sig_windowed = sig_windowed[:wind].flatten()

# Update nperseg
nperseg += (pad_left + pad_right)

# Overlap is zero since overlapping segments was applied prior to padding each window
noverlap = 0

return sig_windowed, nperseg, noverlap


def _find_pad_size(nperseg, npad, fast_len):
"""Determine pad size and number of samples required."""

nsamples = nperseg + npad

pad_left = npad // 2
pad_right = npad - pad_left

if fast_len:
# Increase nsamples to the next fastest length and update for zero-padding size
nsamples = next_fast_len(nsamples)

# New padding
npad = nsamples - nperseg
pad_left = npad // 2
pad_right = npad - pad_left

return nsamples, pad_left, pad_right
6 changes: 6 additions & 0 deletions neurodsp/tests/spectral/test_power.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ def test_compute_spectrum_welch(tsig, tsig_sine):
expected_answer = np.zeros_like(psd_welch[0:FREQ_SINE])
assert np.allclose(psd_welch[0:FREQ_SINE], expected_answer, atol=EPS)

# Test zero padding
freqs, spectrum = compute_spectrum(
np.tile(tsig, (2, 1)), FS, nperseg=100, noverlap=0, nfft=1000, f_range=(1, 200)
)
assert np.all(spectrum[0] == spectrum[1])

def test_compute_spectrum_wavelet(tsig):

freqs, spectrum = compute_spectrum_wavelet(tsig, FS, freqs=FREQS_ARR, avg_type='mean')
Expand Down
26 changes: 25 additions & 1 deletion neurodsp/tests/spectral/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Tests for neurodsp.spectral.utils."""

import pytest

import numpy as np
from numpy.testing import assert_equal

from neurodsp.tests.settings import FS

from neurodsp.spectral.utils import *

###################################################################################################
Expand Down Expand Up @@ -40,3 +41,26 @@ def test_trim_spectrogram():
f_ext, t_ext, p_ext = trim_spectrogram(freqs, times, pows, f_range=[6, 8], t_range=None)
assert_equal(f_ext, np.array([6, 7, 8]))
assert_equal(t_ext, times)


@pytest.mark.parametrize("fast_len", [True, False])
def test_window_pad(fast_len):

nperseg = 100
noverlap = 10
npad = 1000

sig = np.random.rand(1000)

sig_windowed, _nperseg, _noverlap = window_pad(sig, nperseg, noverlap, npad, fast_len)

# Overlap was handled correctly b/w the first two windows
assert np.all(sig_windowed[npad:npad+nperseg][-noverlap:] ==
sig_windowed[(3*npad)+nperseg:(3*npad)+nperseg+noverlap])

# Updated nperseg has no remainder
nwin = (len(sig_windowed) / nperseg)
assert nwin == int(nwin)

# Ensure updated nperseg is correct
assert _nperseg == nperseg + npad

0 comments on commit b52f552

Please sign in to comment.