"""btcvae.py: decomposed KL loss terms (mutual information, total correlation, dimension-wise KL) for a beta-TC-VAE."""
import math

import torch


def matrix_log_density_gaussian(x, mu, logvar):
    """Pairwise log densities: entry (i, j, d) is log N(x_i[d]; mu_j[d], exp(logvar_j[d]))."""
    batch_size, dim = x.shape
    x = x.view(batch_size, 1, dim)
    mu = mu.view(1, batch_size, dim)
    logvar = logvar.view(1, batch_size, dim)
    return log_density_gaussian(x, mu, logvar)


def log_density_gaussian(x, mu, logvar):
    """Element-wise log density of a diagonal Gaussian N(mu, exp(logvar)) evaluated at x."""
    normalization = -0.5 * (math.log(2 * math.pi) + logvar)
    inv_var = torch.exp(-logvar)
    log_density = normalization - 0.5 * ((x - mu) ** 2 * inv_var)
    return log_density
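
# The helper above should agree with torch.distributions; a quick cross-check one
# could run (illustrative only; x, mu, logvar are placeholders, not defined here):
#   dist = torch.distributions.Normal(mu, torch.exp(0.5 * logvar))
#   assert torch.allclose(dist.log_prob(x), log_density_gaussian(x, mu, logvar), atol=1e-6)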


def log_importance_weight_matrix(batch_size, dataset_size):
    """Log importance weights for minibatch stratified sampling (beta-TC-VAE, Chen et al., 2018)."""
    N = dataset_size
    M = batch_size - 1
    strat_weight = (N - M) / (N * M)
    W = torch.Tensor(batch_size, batch_size).fill_(1 / M)
    W.view(-1)[::M + 1] = 1 / N
    W.view(-1)[1::M + 1] = strat_weight
    W[M - 1, 0] = strat_weight
    return W.log()


def KL(latent_dist, latent_sample, alpha=1., beta=1.1, gamma=1., is_train=False):
    """Decomposed KL loss: alpha * mutual information + beta * total correlation + gamma * dimension-wise KL."""
    batch_size, latent_dim = latent_sample.shape
    log_pz, log_qz, log_prod_qzi, log_q_zCx = _get_log_pz_qz_prodzi_qzCx(latent_sample,
                                                                         latent_dist,
                                                                         is_mss=False)
    # I[z;x] = KL[q(z,x)||q(x)q(z)] = E_x[KL[q(z|x)||q(z)]]
    mi_loss = (log_q_zCx - log_qz).mean()
    # TC[z] = KL[q(z)||\prod_i q(z_i)]
    tc_loss = (log_qz - log_prod_qzi).mean()
    # dimension-wise KL: KL[\prod_i q(z_i)||p(z)] instead of the usual KL[q(z|x)||p(z)]
    dw_kl_loss = (log_prod_qzi - log_pz).mean()
    # annealing of the regulariser is left disabled; is_train is currently unused:
    # anneal_reg = (linear_annealing(0, 1, self.n_train_steps, self.steps_anneal)
    #               if is_train else 1)
    # total loss
    loss = alpha * mi_loss + beta * tc_loss + gamma * dw_kl_loss
    return loss


def _get_log_pz_qz_prodzi_qzCx(latent_sample, latent_dist, n_data=100, is_mss=True):
    """Return log p(z), log q(z), log prod_i q(z_i) and log q(z|x) for a batch of samples.

    n_data should be the training-set size; it is only used when is_mss=True.
    """
    batch_size, hidden_dim = latent_sample.shape

    # calculate log q(z|x)
    log_q_zCx = log_density_gaussian(latent_sample, *latent_dist).sum(dim=1)

    # calculate log p(z); the prior is a standard Gaussian (mean and log var are 0)
    zeros = torch.zeros_like(latent_sample)
    log_pz = log_density_gaussian(latent_sample, zeros, zeros).sum(1)

    # pairwise log q(z^(i) | x^(j)) over the batch
    mat_log_qz = matrix_log_density_gaussian(latent_sample, *latent_dist)

    if is_mss:
        # use stratification (minibatch stratified sampling)
        log_iw_mat = log_importance_weight_matrix(batch_size, n_data).to(latent_sample.device)
        mat_log_qz = mat_log_qz + log_iw_mat.view(batch_size, batch_size, 1)

    log_qz = torch.logsumexp(mat_log_qz.sum(2), dim=1, keepdim=False)
    log_prod_qzi = torch.logsumexp(mat_log_qz, dim=1, keepdim=False).sum(1)

    return log_pz, log_qz, log_prod_qzi, log_q_zCx
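

# --- Usage sketch (illustrative only) ----------------------------------------
# A minimal, hedged example of how KL() above could be called during training,
# assuming an encoder that outputs a (mu, logvar) pair and a latent sample drawn
# with the reparameterisation trick. The shapes and hyperparameter values below
# are hypothetical and not part of the original module.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size, latent_dim = 16, 10

    # Stand-ins for encoder outputs defining q(z|x) = N(mu, exp(logvar)).
    mu = torch.randn(batch_size, latent_dim)
    logvar = torch.randn(batch_size, latent_dim)

    # Reparameterised sample: z = mu + sigma * eps.
    latent_sample = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)

    loss = KL((mu, logvar), latent_sample, alpha=1., beta=1.1, gamma=1.)
    print("decomposed KL loss:", loss.item())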