From 0832087f78006c10d52c0600c7377c5929568e0b Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 4 Apr 2022 13:47:42 -0700
Subject: [PATCH] forget about triton, optimize for clarity and education

---
 README.md                    | 4 ----
 palm_pytorch/palm_pytorch.py | 7 +++----
 setup.py                     | 2 +-
 3 files changed, 4 insertions(+), 9 deletions(-)

diff --git a/README.md b/README.md
index 9c6035d..6284978 100644
--- a/README.md
+++ b/README.md
@@ -29,10 +29,6 @@ tokens = torch.randint(0, 20000, (1, 2048))
 logits = palm(tokens) # (1, 2048, 20000)
 ```
 
-## Todo
-
-- [ ] use Triton to add bias-less Layernorm + stable softmax
-
 ## Citations
 
 ```bibtex
diff --git a/palm_pytorch/palm_pytorch.py b/palm_pytorch/palm_pytorch.py
index c550b84..6570a42 100644
--- a/palm_pytorch/palm_pytorch.py
+++ b/palm_pytorch/palm_pytorch.py
@@ -12,12 +12,11 @@ class LayerNorm(nn.Module):
     def __init__(self, dim, eps = 1e-5):
         super().__init__()
         self.eps = eps
-        self.g = nn.Parameter(torch.ones(dim))
+        self.gamma = nn.Parameter(torch.ones(dim))
+        self.register_buffer('beta', torch.zeros(dim))
 
     def forward(self, x):
-        var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
-        mean = torch.mean(x, dim = -1, keepdim = True)
-        return (x - mean) / (var + self.eps).sqrt() * self.g
+        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
 
 # parallel with residual
 # discovered by Wang et al + EleutherAI from GPT-J fame
diff --git a/setup.py b/setup.py
index ce4a441..7d8b37b 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'PaLM-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.4',
+  version = '0.0.5',
   license='MIT',
   description = 'PaLM: Scaling Language Modeling with Pathways - Pytorch',
   author = 'Phil Wang',
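
For reference, a minimal sketch of what `LayerNorm` looks like once this patch is applied: `gamma` stays a learned scale, while `beta` is registered as a non-trainable zero buffer so `F.layer_norm` effectively performs a bias-less layer norm. The comparison against `nn.LayerNorm` with its bias held at zero is an illustrative sanity check added here, not part of the repository.

```python
import torch
import torch.nn.functional as F
from torch import nn

class LayerNorm(nn.Module):
    # LayerNorm as it reads after this patch: learned gamma, frozen zero beta
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps  # kept for interface parity; F.layer_norm's default eps is also 1e-5
        self.gamma = nn.Parameter(torch.ones(dim))       # learned scale
        self.register_buffer('beta', torch.zeros(dim))   # zero bias, never trained

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# illustrative check (assumption, not in the repo): the frozen-zero beta matches
# a standard nn.LayerNorm whose bias is held at zero
x = torch.randn(1, 2048, 512)
ln, ref = LayerNorm(512), nn.LayerNorm(512)
with torch.no_grad():
    ref.bias.zero_()
assert torch.allclose(ln(x), ref(x), atol = 1e-6)
```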