diff --git a/forest/oak/base.py b/forest/oak/base.py index f2751833..fcb02f81 100644 --- a/forest/oak/base.py +++ b/forest/oak/base.py @@ -184,8 +184,12 @@ def compute_interpolate_cwt(tapered_bout: np.ndarray, fs: int = 10, # interpolate coefficients freqs = out[2] freqs_interp = np.arange(0.5, 4.5, 0.05) - ip = interpolate.interp2d(range(coefs.shape[1]), freqs, coefs) - coefs_interp = ip(range(coefs.shape[1]), freqs_interp) + interpolator = interpolate.RegularGridInterpolator( + (freqs, range(coefs.shape[1])), coefs + ) + grid_x, grid_y = np.meshgrid(freqs_interp, range(coefs.shape[1]), + indexing='ij') + coefs_interp = interpolator((grid_x, grid_y)) # trim spectrogram from the coi coefs_interp = coefs_interp[:, 5*fs:-5*fs]