Skip to content

Commit

Permalink
fix encoding in the generate() function
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentherrmann committed Mar 23, 2018
1 parent 30ad855 commit 2b7bfb2
Showing 1 changed file with 12 additions and 14 deletions.
26 changes: 12 additions & 14 deletions wavenet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,31 +210,29 @@ def generate(self,
print("pad zero")

for i in range(num_samples):
input = generated[-self.receptive_field:].view(1, 1, -1)
input = Variable(torch.FloatTensor(1, self.classes, self.receptive_field).zero_())
input = input.scatter_(1, generated[-self.receptive_field:].view(1, -1, self.receptive_field), 1.)

x = self.wavenet(input,
dilation_func=self.wavenet_dilate)[-1, :].squeeze()
dilation_func=self.wavenet_dilate)[:, :, -1].squeeze()

if temperature > 0:
x /= temperature
prob = F.softmax(x, dim=0)
prob = prob.cpu()
np_prob = prob.data.numpy()
x = np.random.choice(self.classes, p=np_prob)
x = np.array([x])

soft_o = F.softmax(x)
soft_o = soft_o.cpu()
np_o = soft_o.data.numpy()
s = np.random.choice(self.num_classes, p=np_o)
s = Variable(torch.FloatTensor([s]))
s = (s / self.num_classes) * 2. - 1
x = Variable(torch.LongTensor([x]))#np.array([x])
else:
max = torch.max(x, 0)[1].float()
s = (max / self.num_classes) * 2. - 1 # new sample
x = torch.max(x, 0)[1].float()

generated = torch.cat((generated, x), 0)

generated = (generated / self.classes) * 2. - 1
mu_gen = mu_law_expansion(generated, self.classes)

generated = torch.cat((generated, s), 0)
self.train()
return generated.data.tolist()
return mu_gen

def generate_fast(self,
num_samples,
Expand Down

0 comments on commit 2b7bfb2

Please sign in to comment.