Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated PyTorch version from 0.3 to 1.6 #38

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# pytorch-wavenet
This is an implementation of the WaveNet architecture, as described in the [original paper](https://arxiv.org/abs/1609.03499).
Updated to work on PyTorch 1.6.

## Features
- Automatic creation of a dataset (training and validation/test set) from all sound files (.wav, .aiff, .mp3) in a directory
Expand All @@ -9,7 +10,7 @@ This is an implementation of the WaveNet architecture, as described in the [orig

## Requirements
- python 3
- pytorch 0.3
- pytorch 1.6
- numpy
- librosa
- jupyter
Expand Down
18 changes: 10 additions & 8 deletions wavenet_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
import os.path
import time
from wavenet_modules import *
import torch.nn.functional as F
from audio_data import *
from torch import nn
from wavenet_modules import *


class WaveNetModel(nn.Module):
Expand All @@ -25,6 +27,7 @@ class WaveNetModel(nn.Module):
- Output: :math:`()`
L should be the length of the receptive field
"""

def __init__(self,
layers=10,
blocks=4,
Expand Down Expand Up @@ -109,9 +112,9 @@ def __init__(self,
new_dilation *= 2

self.end_conv_1 = nn.Conv1d(in_channels=skip_channels,
out_channels=end_channels,
kernel_size=1,
bias=True)
out_channels=end_channels,
kernel_size=1,
bias=True)

self.end_conv_2 = nn.Conv1d(in_channels=end_channels,
out_channels=classes,
Expand Down Expand Up @@ -153,7 +156,7 @@ def wavenet(self, input, dilation_func):
# parametrized skip connection
s = x
if x.size(2) != 1:
s = dilate(x, 1, init_dilation=dilation)
s = dilate(x, 1, init_dilation=dilation)
s = self.skip_convs[i](s)
try:
skip = skip[:, :, -s.size(2):]
Expand Down Expand Up @@ -222,7 +225,7 @@ def generate(self,
prob = prob.cpu()
np_prob = prob.data.numpy()
x = np.random.choice(self.classes, p=np_prob)
x = Variable(torch.LongTensor([x]))#np.array([x])
x = Variable(torch.LongTensor([x])) # np.array([x])
else:
x = torch.max(x, 0)[1].float()

Expand Down Expand Up @@ -301,7 +304,7 @@ def generate_fast(self,
input.zero_()
input = input.scatter_(1, x.view(1, -1, 1), 1.).view(1, self.classes, 1)

if (i+1) == 100:
if (i + 1) == 100:
toc = time.time()
print("one generating step does take approximately " + str((toc - tic) * 0.01) + " seconds)")

Expand All @@ -314,7 +317,6 @@ def generate_fast(self,
mu_gen = mu_law_expansion(generated, self.classes)
return mu_gen


def parameter_count(self):
par = list(self.parameters())
s = sum([np.prod(list(d.size())) for d in par])
Expand Down
64 changes: 17 additions & 47 deletions wavenet_modules.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from torch.autograd import Variable, Function
import numpy as np
from torch.nn import Parameter


def dilate(x, dilation, init_dilation=1, pad_start=True):
Expand All @@ -24,7 +24,7 @@ def dilate(x, dilation, init_dilation=1, pad_start=True):
new_l = int(np.ceil(l / dilation_factor) * dilation_factor)
if new_l != l:
l = new_l
x = constant_pad_1d(x, new_l, dimension=2, pad_start=pad_start)
x = constant_pad_1d(x, new_l, pad_start=pad_start)

l_old = int(round(l / dilation_factor))
n_old = int(round(n * dilation_factor))
Expand Down Expand Up @@ -77,51 +77,21 @@ def reset(self):
self.out_pos = 0


class ConstantPad1d(Function):
def __init__(self, target_size, dimension=0, value=0, pad_start=False):
super(ConstantPad1d, self).__init__()
self.target_size = target_size
self.dimension = dimension
self.value = value
self.pad_start = pad_start

def forward(self, input):
self.num_pad = self.target_size - input.size(self.dimension)
assert self.num_pad >= 0, 'target size has to be greater than input size'

self.input_size = input.size()

size = list(input.size())
size[self.dimension] = self.target_size
output = input.new(*tuple(size)).fill_(self.value)
c_output = output

# crop output
if self.pad_start:
c_output = c_output.narrow(self.dimension, self.num_pad, c_output.size(self.dimension) - self.num_pad)
else:
c_output = c_output.narrow(self.dimension, 0, c_output.size(self.dimension) - self.num_pad)

c_output.copy_(input)
return output

def backward(self, grad_output):
grad_input = grad_output.new(*self.input_size).zero_()
cg_output = grad_output

# crop grad_output
if self.pad_start:
cg_output = cg_output.narrow(self.dimension, self.num_pad, cg_output.size(self.dimension) - self.num_pad)
else:
cg_output = cg_output.narrow(self.dimension, 0, cg_output.size(self.dimension) - self.num_pad)

grad_input.copy_(cg_output)
return grad_input


def constant_pad_1d(input,
target_size,
dimension=0,
value=0,
pad_start=False):
return ConstantPad1d(target_size, dimension, value, pad_start)(input)
"""
Assumes that padded dim is the 2, based on pytorch specification.
Input: (N,C,Win)(N, C, W_{in})(N,C,Win​)
Output: (N,C,Wout)(N, C, W_{out})(N,C,Wout​) where
:param input:
:param target_size:
:param value:
:param pad_start:
:return:
"""
num_pad = target_size - input.size(2)
assert num_pad >= 0, 'target size has to be greater than input size'
padding = (num_pad, 0) if pad_start else (0, num_pad)
return torch.nn.ConstantPad1d(padding, value)(input)