From e671a6f9ef6aa72328a5ab23f5720e54225fd02a Mon Sep 17 00:00:00 2001 From: RichieHakim Date: Tue, 16 Apr 2024 21:01:01 -0400 Subject: [PATCH] debug coherence linear detrend method --- bnpm/spectral.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/bnpm/spectral.py b/bnpm/spectral.py index b497a25..9642f31 100644 --- a/bnpm/spectral.py +++ b/bnpm/spectral.py @@ -7,6 +7,7 @@ import numpy as np import matplotlib.pyplot as plt import torch +import opt_einsum from . import circular from . import misc @@ -539,6 +540,10 @@ def torch_coherence( method uses less memory and is faster for large windows but is slower for small windows and there is a very small amount of numerical error due to the accumulation. \n + + Speed: The 'linear' detrending method is not fast on GPU, despite the + implementation being similar. 'constant' is roughly 3x as fast as 'linear' + on CPU. \n RH 2024 @@ -605,21 +610,23 @@ def torch_coherence( def detrend_constant(y, axis): y = y - torch.mean(y, axis=axis, keepdim=True) return y + + X_linearDetrendPrep = torch.ones(nfft, 2, dtype=y.dtype, device=y.device) + X_linearDetrendPrep[:, 1] = torch.arange(nfft, dtype=y.dtype, device=y.device) def detrend_linear(y, axis): """ Uses least squares approach to remove linear trend. """ ## Move axis to end y_dims_to = [ii for ii in range(len(y.shape)) if ii != axis] + [axis] - y = y.permute(*y_dims_to) - n = y.shape[-1] + y = y.permute(*y_dims_to)[..., None] ## Prepare the design matrix - X = torch.ones(n, 2, dtype=y.dtype, device=y.device) - X[:, 1] = torch.arange(n, dtype=y.dtype, device=y.device) + X = X_linearDetrendPrep[*([None] * (len(y.shape) - 2))] ## Compute the coefficients - beta = torch.linalg.lstsq(X, y)[0] + beta = torch.linalg.lstsq(X, y)[0] ## Remove the trend - y = y - X @ beta + y = y - opt_einsum.contract('...ij, ...jk -> ...ik', X, beta) + y = y[..., 0] ## Move axis back to original position (argsort y_dims_to) y_dims_from = [y_dims_to.index(ii) for ii in range(len(y.shape))] y = y.permute(*y_dims_from)