diff --git a/dalle_pytorch/dalle_pytorch.py b/dalle_pytorch/dalle_pytorch.py index 45becc33..5128d1a9 100644 --- a/dalle_pytorch/dalle_pytorch.py +++ b/dalle_pytorch/dalle_pytorch.py @@ -51,7 +51,7 @@ def inner(model, *args, **kwargs): # sampling helpers def log(t, eps = 1e-20): - return torch.log(t + eps) + return torch.log(t.clamp(min = eps)) def gumbel_noise(t): noise = torch.zeros_like(t).uniform_(0, 1) @@ -239,7 +239,7 @@ def forward( one_hot = one_hot.detach() π0 = logits.softmax(dim = 1) π1 = (one_hot + (logits / temp).softmax(dim = 1)) / 2 - π1 = ((π1.log() - logits).detach() + logits).softmax(dim = 1) + π1 = ((log(π1) - logits).detach() + logits).softmax(dim = 1) π2 = 2 * π1 - 0.5 * π0 one_hot = π2 - π2.detach() + one_hot diff --git a/dalle_pytorch/version.py b/dalle_pytorch/version.py index f3df7f04..008e8016 100644 --- a/dalle_pytorch/version.py +++ b/dalle_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.6.5' +__version__ = '1.6.6'