Skip to content

Commit

Permalink
debug coherence linear detrend method
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed Apr 17, 2024
1 parent 229b460 commit e671a6f
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions bnpm/spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import matplotlib.pyplot as plt
import torch
import opt_einsum

from . import circular
from . import misc
Expand Down Expand Up @@ -539,6 +540,10 @@ def torch_coherence(
method uses less memory and is faster for large windows but is slower for
small windows and there is a very small amount of numerical error due to the
accumulation. \n
Speed: The 'linear' detrending method is not fast on GPU, despite the
implementation being similar. 'constant' is roughly 3x as fast as 'linear'
on CPU. \n
RH 2024
Expand Down Expand Up @@ -605,21 +610,23 @@ def torch_coherence(
def detrend_constant(y, axis):
y = y - torch.mean(y, axis=axis, keepdim=True)
return y

X_linearDetrendPrep = torch.ones(nfft, 2, dtype=y.dtype, device=y.device)
X_linearDetrendPrep[:, 1] = torch.arange(nfft, dtype=y.dtype, device=y.device)
def detrend_linear(y, axis):
"""
Uses least squares approach to remove linear trend.
"""
## Move axis to end
y_dims_to = [ii for ii in range(len(y.shape)) if ii != axis] + [axis]
y = y.permute(*y_dims_to)
n = y.shape[-1]
y = y.permute(*y_dims_to)[..., None]
## Prepare the design matrix
X = torch.ones(n, 2, dtype=y.dtype, device=y.device)
X[:, 1] = torch.arange(n, dtype=y.dtype, device=y.device)
X = X_linearDetrendPrep[*([None] * (len(y.shape) - 2))]
## Compute the coefficients
beta = torch.linalg.lstsq(X, y)[0]
beta = torch.linalg.lstsq(X, y)[0]
## Remove the trend
y = y - X @ beta
y = y - opt_einsum.contract('...ij, ...jk -> ...ik', X, beta)
y = y[..., 0]
## Move axis back to original position (argsort y_dims_to)
y_dims_from = [y_dims_to.index(ii) for ii in range(len(y.shape))]
y = y.permute(*y_dims_from)
Expand Down

0 comments on commit e671a6f

Please sign in to comment.