Skip to content

Commit

Permalink
add reinmax
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 10, 2023
1 parent daf30d0 commit 66b573b
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 3 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -751,4 +751,14 @@ $ python generate.py --chinese --text '追老鼠的猫'
}
```

```bibtex
@article{Liu2023BridgingDA,
title = {Bridging Discrete and Backpropagation: Straight-Through and Beyond},
author = {Liyuan Liu and Chengyu Dong and Xiaodong Liu and Bin Yu and Jianfeng Gao},
journal = {ArXiv},
year = {2023},
volume = {abs/2304.08612}
}
```

*Those who do not want to imitate anything, produce nothing.* - Dali
19 changes: 17 additions & 2 deletions dalle_pytorch/dalle_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def __init__(
smooth_l1_loss = False,
temperature = 0.9,
straight_through = False,
reinmax = False,
kl_div_loss_weight = 0.,
normalization = ((*((0.5,) * 3), 0), (*((0.5,) * 3), 1))
):
Expand All @@ -125,6 +126,8 @@ def __init__(
self.num_layers = num_layers
self.temperature = temperature
self.straight_through = straight_through
self.reinmax = reinmax

self.codebook = nn.Embedding(num_tokens, codebook_dim)

hdim = hidden_dim
Expand Down Expand Up @@ -227,8 +230,20 @@ def forward(
return logits # return logits for getting hard image indices for DALL-E training

temp = default(temp, self.temperature)
soft_one_hot = F.gumbel_softmax(logits, tau = temp, dim = 1, hard = self.straight_through)
sampled = einsum('b n h w, n d -> b d h w', soft_one_hot, self.codebook.weight)

one_hot = F.gumbel_softmax(logits, tau = temp, dim = 1, hard = self.straight_through)

if self.straight_through and self.reinmax:
# use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612
# algorithm 2
one_hot = one_hot.detach()
π0 = logits.softmax(dim = 1)
π1 = (one_hot + (logits / temp).softmax(dim = 1)) / 2
π1 = ((π1.log() - logits).detach() + logits).softmax(dim = 1)
π2 = 2 * π1 - 0.5 * π0
one_hot = π2 - π2.detach() + one_hot

sampled = einsum('b n h w, n d -> b d h w', one_hot, self.codebook.weight)
out = self.decoder(sampled)

if not return_loss:
Expand Down
2 changes: 1 addition & 1 deletion dalle_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.6.4'
__version__ = '1.6.5'

0 comments on commit 66b573b

Please sign in to comment.