feat: Add Variable-Q Transform #113

Merged: 1 commit, Dec 24, 2021
1 change: 1 addition & 0 deletions Installation/nnAudio/features/__init__.py
@@ -11,3 +11,4 @@
 from .griffin_lim import *
 from .mel import *
 from .stft import *
+from .vqt import *
202 changes: 202 additions & 0 deletions Installation/nnAudio/features/vqt.py
@@ -0,0 +1,202 @@
import torch
import torch.nn as nn
import numpy as np
from time import time
from ..librosa_functions import *
from ..utils import *


class VQT(torch.nn.Module):
def __init__(
self,
sr=22050,
hop_length=512,
fmin=32.70,
fmax=None,
n_bins=84,
filter_scale=1,
bins_per_octave=12,
norm=True,
basis_norm=1,
gamma=0,
window='hann',
pad_mode='reflect',
earlydownsample=True,
trainable=False,
output_format='Magnitude',
verbose=True
):

super().__init__()

self.norm = norm
self.hop_length = hop_length
self.pad_mode = pad_mode
self.n_bins = n_bins
self.earlydownsample = earlydownsample
self.trainable = trainable
self.output_format = output_format
self.filter_scale = filter_scale
self.bins_per_octave = bins_per_octave
self.sr = sr
self.gamma = gamma
self.basis_norm = basis_norm

# Q is used to calculate filter_cutoff and to create the CQT kernels
Q = float(filter_scale) / (2 ** (1 / bins_per_octave) - 1)

# Create the lowpass filter and make it a torch tensor
if verbose:
    print("Creating low pass filter ...", end='\r')
start = time()
lowpass_filter = torch.tensor(create_lowpass_filter(
    band_center=0.50,
    kernelLength=256,
    transitionBandwidth=0.001)
)

self.register_buffer('lowpass_filter', lowpass_filter[None, None, :])
if verbose:
    print("Low pass filter created, time used = {:.4f} seconds".format(time() - start))

n_filters = min(bins_per_octave, n_bins)
self.n_filters = n_filters
self.n_octaves = int(np.ceil(float(n_bins) / bins_per_octave))
if verbose:
print("num_octave = ", self.n_octaves)

self.fmin_t = fmin * 2 ** (self.n_octaves - 1)
remainder = n_bins % bins_per_octave

if remainder == 0:
    # Calculate the top bin frequency
    fmax_t = self.fmin_t * 2 ** ((bins_per_octave - 1) / bins_per_octave)
else:
    # Calculate the top bin frequency
    fmax_t = self.fmin_t * 2 ** ((remainder - 1) / bins_per_octave)

# Adjust the minimum frequency of the top octave accordingly
self.fmin_t = fmax_t / 2 ** (1 - 1 / bins_per_octave)
if fmax_t > sr / 2:
    raise ValueError('The top bin {} Hz exceeds the Nyquist frequency; '
                     'please reduce n_bins'.format(fmax_t))

if self.earlydownsample:  # Do early downsampling if this argument is True
    if verbose:
        print("Creating early downsampling filter ...", end='\r')
    start = time()
    sr, self.hop_length, self.downsample_factor, early_downsample_filter, \
        self.earlydownsample = get_early_downsample_params(sr,
                                                           hop_length,
                                                           fmax_t,
                                                           Q,
                                                           self.n_octaves,
                                                           verbose)
    self.register_buffer('early_downsample_filter', early_downsample_filter)

    if verbose:
        print("Early downsampling filter created, "
              "time used = {:.4f} seconds".format(time() - start))
else:
    self.downsample_factor = 1.

# For the final normalization: the freqs returned by create_cqt_kernels
# cannot be used here, since it only covers the top-octave bins.
# We need the kernel lengths for all frequency bins.
alpha = 2.0 ** (1.0 / bins_per_octave) - 1.0
freqs = fmin * 2.0 ** (np.r_[0:n_bins] / float(bins_per_octave))
self.frequencies = freqs
lengths = np.ceil(Q * sr / (freqs + gamma / alpha))

# Get the max window length depending on the gamma value
max_len = int(max(lengths))
self.n_fft = int(2 ** (np.ceil(np.log2(max_len))))

lengths = torch.tensor(lengths).float()
self.register_buffer('lengths', lengths)


def forward(self, x, output_format=None, normalization_type='librosa'):
"""
Convert a batch of waveforms to VQT spectrograms.

Parameters
----------
x : torch tensor
Input signal should be in either of the following shapes.\n
1. ``(len_audio)``\n
2. ``(num_audio, len_audio)``\n
3. ``(num_audio, 1, len_audio)``
It will be automatically broadcast to the right shape
"""
output_format = output_format or self.output_format

x = broadcast_dim(x)
if self.earlydownsample:
x = downsampling_by_n(x, self.early_downsample_filter, self.downsample_factor)
hop = self.hop_length
vqt = []

x_down = x # Preparing a new variable for downsampling
my_sr = self.sr

for i in range(self.n_octaves):
    if i > 0:
        x_down = downsampling_by_2(x_down, self.lowpass_filter)
        my_sr /= 2
        hop //= 2
    else:
        x_down = x
Q = float(self.filter_scale) / (2 ** (1 / self.bins_per_octave) - 1)
Review comment (Owner):

Isn't this the same as line 46? Both self.bins_per_octave and self.filter_scale seem to be unchanged inside the for loop, so maybe there is no need to recompute Q here? Or am I misunderstanding something?

Reply (@gudgud96, Contributor, Author — Dec 18, 2021):

Sorry for the late reply. You can understand this Q as an "initial value of Q". Following librosa's implementation, the "variable" part actually lies in the lengths computed inside create_cqt_kernels:

Ours: https://github.com/KinWaiCheuk/nnAudio/pull/113/files#diff-8fea6e5f3e058527d6bbfe52b2dd2d9e756425b37d39bac07541f7f5cb1ccce5R424

Librosa's: https://github.com/librosa/librosa/blob/381efbd684c01ae372220526352d34fd732d3b1d/librosa/filters.py#L802

The gamma changes the filter lengths compared to the CQT. Since filter length and Q are interrelated, changing the lengths can also be viewed as making Q "variable".
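To make the reply concrete, here is a minimal standalone NumPy sketch (an illustration using the same length formula as the utils.py diff below, not code from this PR) of how gamma reshapes the per-bin kernel lengths:

import numpy as np

# Illustrative parameter values matching the VQT defaults above
fs, fmin, n_bins, bins_per_octave, filter_scale = 22050, 32.70, 84, 12, 1
Q = filter_scale / (2 ** (1 / bins_per_octave) - 1)  # the "initial" Q
alpha = 2.0 ** (1.0 / bins_per_octave) - 1.0
freqs = fmin * 2.0 ** (np.arange(n_bins) / bins_per_octave)

for gamma in (0, 5):
    # Same formula as in create_cqt_kernels: gamma = 0 recovers the constant-Q
    # lengths, while gamma > 0 mainly shortens the long low-frequency kernels,
    # which is the "variable-Q" behaviour.
    lengths = np.ceil(Q * fs / (freqs + gamma / alpha))
    print(gamma, int(lengths[0]), int(lengths[-1]))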


basis, self.n_fft, lengths, _ = create_cqt_kernels(Q,
my_sr,
self.fmin_t * 2 ** -i,
self.n_filters,
self.bins_per_octave,
norm=self.basis_norm,
topbin_check=False,
gamma=self.gamma)

cqt_kernels_real = torch.tensor(basis.real.astype(np.float32)).unsqueeze(1)
cqt_kernels_imag = torch.tensor(basis.imag.astype(np.float32)).unsqueeze(1)

if self.pad_mode == 'constant':
    my_padding = nn.ConstantPad1d(cqt_kernels_real.shape[-1] // 2, 0)
elif self.pad_mode == 'reflect':
    my_padding = nn.ReflectionPad1d(cqt_kernels_real.shape[-1] // 2)

cur_vqt = get_cqt_complex(x_down, cqt_kernels_real, cqt_kernels_imag, hop, my_padding)
vqt.insert(0, cur_vqt)

vqt = torch.cat(vqt, dim=1)
vqt = vqt[:,-self.n_bins:,:] # Removing unwanted bottom bins
vqt = vqt * self.downsample_factor

# Normalize again to get the same result as librosa
if normalization_type == 'librosa':
    vqt = vqt * torch.sqrt(self.lengths.view(-1, 1, 1))
elif normalization_type == 'convolutional':
    pass
elif normalization_type == 'wrap':
    vqt *= 2
else:
    raise ValueError("The normalization_type %r is not part of our current options." % normalization_type)

if output_format == 'Magnitude':
    if not self.trainable:
        # Getting the VQT magnitude
        return torch.sqrt(vqt.pow(2).sum(-1))
    else:
        # A small epsilon keeps the gradient finite when the kernels are trainable
        return torch.sqrt(vqt.pow(2).sum(-1) + 1e-8)

elif output_format == 'Complex':
    return vqt

elif output_format == 'Phase':
    phase_real = torch.cos(torch.atan2(vqt[:, :, :, 1], vqt[:, :, :, 0]))
    phase_imag = torch.sin(torch.atan2(vqt[:, :, :, 1], vqt[:, :, :, 0]))
    return torch.stack((phase_real, phase_imag), -1)
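For reference, a minimal usage sketch of the new layer (the random waveform and parameter values below are illustrative placeholders, not from this PR):

import torch
from nnAudio.features import VQT

vqt_layer = VQT(sr=22050, hop_length=512, n_bins=84,
                bins_per_octave=12, gamma=5,
                output_format='Magnitude', verbose=False)

waveform = torch.randn(1, 22050 * 2)  # dummy batch: one 2-second signal
spec = vqt_layer(waveform)            # -> (num_audio, n_bins, n_frames)
print(spec.shape)

Since VQT is a torch.nn.Module, the layer can also be moved to a GPU with .to('cuda') and composed with other layers in a model.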
20 changes: 12 additions & 8 deletions Installation/nnAudio/utils.py
@@ -39,7 +39,6 @@ def rfft_fn(x, n=None, onesided=False):
     else:
         return torch.rfft(x, n, onesided=onesided)
 
-
 ## --------------------------- Filter Design ---------------------------##
 def torch_window_sumsquare(w, n_frames, stride, n_fft, power=2):
     w_stacks = w.unsqueeze(-1).repeat((1, n_frames)).unsqueeze(0)
@@ -387,6 +386,8 @@ def create_cqt_kernels(
     window="hann",
     fmax=None,
     topbin_check=True,
+    gamma=0,
+    pad_fft=True
 ):
     """
     Automatically create CQT kernels in time domain
@@ -419,25 +420,28 @@
         )
     )
 
+    alpha = 2.0 ** (1.0 / bins_per_octave) - 1.0
+    lengths = np.ceil(Q * fs / (freqs + gamma / alpha))
+
+    # get max window length depending on gamma value
+    max_len = int(max(lengths))
+    fftLen = int(2 ** (np.ceil(np.log2(max_len))))
 
     tempKernel = np.zeros((int(n_bins), int(fftLen)), dtype=np.complex64)
     specKernel = np.zeros((int(n_bins), int(fftLen)), dtype=np.complex64)
 
-    lengths = np.ceil(Q * fs / freqs)
     for k in range(0, int(n_bins)):
         freq = freqs[k]
-        l = np.ceil(Q * fs / freq)
+        l = lengths[k]
 
         # Centering the kernels
         if l % 2 == 1:  # pad more zeros on RHS
             start = int(np.ceil(fftLen / 2.0 - l / 2.0)) - 1
         else:
             start = int(np.ceil(fftLen / 2.0 - l / 2.0))
 
-        sig = (
-            get_window_dispatch(window, int(l), fftbins=True)
-            * np.exp(np.r_[-l // 2 : l // 2] * 1j * 2 * np.pi * freq / fs)
-            / l
-        )
+        window_dispatch = get_window_dispatch(window, int(l), fftbins=True)
+        sig = window_dispatch * np.exp(np.r_[-l // 2 : l // 2] * 1j * 2 * np.pi * freq / fs) / l
 
         if norm:  # Normalizing the filter # Trying to normalize like librosa
             tempKernel[k, start : start + int(l)] = sig / np.linalg.norm(sig, norm)
75 changes: 75 additions & 0 deletions Installation/tests/test_vqt.py
@@ -0,0 +1,75 @@
import pytest
import librosa
import torch
import sys

sys.path.insert(0, "./")

import os

dir_path = os.path.dirname(os.path.realpath(__file__))

from nnAudio.features import CQT2010v2, VQT
import numpy as np
from parameters import *
import warnings

gpu_idx = 0  # Choose which GPU to use

# If a GPU is available, also test on GPU
if torch.cuda.is_available():
    device_args = ["cpu", f"cuda:{gpu_idx}"]
else:
    warnings.warn("GPU is not available, testing only on CPU")
    device_args = ["cpu"]

# librosa example audio for testing
y, sr = librosa.load(librosa.ex('choice'), duration=5)

@pytest.mark.parametrize("device", [*device_args])
def test_vqt_gamma_zero(device):
    # nnAudio CQT as the reference: with gamma=0 the VQT should reduce to the CQT
    spec = CQT2010v2(sr=sr, verbose=False).to(device)
    C2 = spec(torch.tensor(y).unsqueeze(0).to(device), output_format="Magnitude", normalization_type='librosa')
    C2 = C2.cpu().numpy().squeeze()

    # nnAudio vqt
    spec = VQT(sr=sr, gamma=0, verbose=False).to(device)
    V2 = spec(torch.tensor(y).unsqueeze(0).to(device), output_format="Magnitude", normalization_type='librosa')
    V2 = V2.cpu().numpy().squeeze()

    assert (C2 == V2).all()


@pytest.mark.parametrize("device", [*device_args])
def test_vqt(device):
    for gamma in [0, 1, 2, 5, 10]:

        # librosa vqt as the reference
        V1 = np.abs(librosa.vqt(y, sr=sr, gamma=gamma))

        # nnAudio vqt
        spec = VQT(sr=sr, gamma=gamma, verbose=False).to(device)
        V2 = spec(torch.tensor(y).unsqueeze(0).to(device), output_format="Magnitude", normalization_type='librosa')
        V2 = V2.cpu().numpy().squeeze()

        # NOTE: there will still be some difference between the librosa and nnAudio VQT values (as for the CQT),
        # mainly due to how the kernel lengths are computed - librosa keeps them as floats while nnAudio casts them to int.
        # This test aims to keep the difference within a baseline threshold.
        vqt_diff = np.abs(V1 - V2)

        if gamma == 0:
            assert np.amin(vqt_diff) < 1e-8
            assert np.amax(vqt_diff) < 0.6785
        elif gamma == 1:
            assert np.amin(vqt_diff) < 1e-8
            assert np.amax(vqt_diff) < 0.6510
        elif gamma == 2:
            assert np.amin(vqt_diff) < 1e-8
            assert np.amax(vqt_diff) < 0.5962
        elif gamma == 5:
            assert np.amin(vqt_diff) < 1e-8
            assert np.amax(vqt_diff) < 0.3695
        else:
            assert np.amin(vqt_diff) < 1e-8
            assert np.amax(vqt_diff) < 0.1