Skip to content

Commit

Permalink
make VQT a torch nn module and play nice with jit script
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed Feb 8, 2024
1 parent 42c5f6d commit a253147
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 123 deletions.
8 changes: 6 additions & 2 deletions vqt/helpers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import math

import torch
import numpy as np
import scipy.signal
Expand Down Expand Up @@ -48,7 +50,7 @@ def make_batches(
l = length

if batch_size is None:
batch_size = np.int64(np.ceil(l / num_batches))
batch_size = int(math.ceil(l / num_batches))

for start in range(idx_start, l, batch_size):
end = min(start + batch_size, l)
Expand Down Expand Up @@ -332,7 +334,7 @@ def make_VQT_filters(
mother_wave = scipy.signal.windows.get_window(window=window_type, Nx=resolution, fftbins=False)

wins, xs = make_scaled_wave_basis(mother_wave, lens_waves=scales, lens_windows=win_size)
wins = torch.tensor(np.stack(wins, axis=0), dtype=torch.float32)
wins = torch.as_tensor(np.stack(wins, axis=0), dtype=torch.float32)

elif isinstance(window_type, (np.ndarray, list, tuple)):
mother_wave = np.array(window_type, dtype=np.float32)
Expand Down Expand Up @@ -361,6 +363,8 @@ def make_VQT_filters(
## Normalize filters to have unit magnitude
filts_complex = filts_complex / torch.linalg.norm(filts_complex, ord=2, dim=1, keepdim=True)

freqs = torch.as_tensor(freqs, dtype=torch.float32)

## Plot
if plot_pref:
plt.figure()
Expand Down
Loading

0 comments on commit a253147

Please sign in to comment.