diff --git a/wavenet_model.py b/wavenet_model.py index 5e3c32cc7..3ecdd2761 100644 --- a/wavenet_model.py +++ b/wavenet_model.py @@ -210,9 +210,11 @@ def generate(self, print("pad zero") for i in range(num_samples): - input = generated[-self.receptive_field:].view(1, 1, -1) + input = Variable(torch.FloatTensor(1, self.classes, self.receptive_field).zero_()) + input = input.scatter_(1, generated[-self.receptive_field:].view(1, -1, self.receptive_field), 1.) + x = self.wavenet(input, - dilation_func=self.wavenet_dilate)[-1, :].squeeze() + dilation_func=self.wavenet_dilate)[:, :, -1].squeeze() if temperature > 0: x /= temperature @@ -220,21 +222,17 @@ def generate(self, prob = prob.cpu() np_prob = prob.data.numpy() x = np.random.choice(self.classes, p=np_prob) - x = np.array([x]) - - soft_o = F.softmax(x) - soft_o = soft_o.cpu() - np_o = soft_o.data.numpy() - s = np.random.choice(self.num_classes, p=np_o) - s = Variable(torch.FloatTensor([s])) - s = (s / self.num_classes) * 2. - 1 + x = Variable(torch.LongTensor([x]))#np.array([x]) else: - max = torch.max(x, 0)[1].float() - s = (max / self.num_classes) * 2. - 1 # new sample + x = torch.max(x, 0)[1].float() + + generated = torch.cat((generated, x), 0) + + generated = (generated / self.classes) * 2. - 1 + mu_gen = mu_law_expansion(generated, self.classes) - generated = torch.cat((generated, s), 0) self.train() - return generated.data.tolist() + return mu_gen def generate_fast(self, num_samples,