forget about triton, optimize for clarity and education
lucidrains committed Apr 4, 2022
1 parent f5332eb commit 0832087
Showing 3 changed files with 4 additions and 9 deletions.
4 changes: 0 additions & 4 deletions README.md
````diff
@@ -29,10 +29,6 @@
 tokens = torch.randint(0, 20000, (1, 2048))
 logits = palm(tokens) # (1, 2048, 20000)
 ```
 
-## Todo
-
-- [ ] use Triton to add bias-less Layernorm + stable softmax
-
 ## Citations
 
 ```bibtex
````
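For context, the README instantiates the model just above the lines shown in this hunk. A minimal runnable sketch is below; the `PaLM(...)` hyperparameter values are illustrative assumptions, since the diff only pins down the vocab size of 20000 and the `(1, 2048)` token shape.

```python
# illustrative usage sketch -- the PaLM(...) hyperparameters below are
# assumptions for demonstration; only num_tokens = 20000 and the
# (1, 2048) token shape are fixed by the README lines in the diff
import torch
from palm_pytorch import PaLM

palm = PaLM(
    num_tokens = 20000, # vocab size, matches the torch.randint range below
    dim = 512,          # model width (assumed)
    depth = 12,         # number of layers (assumed)
    heads = 8,          # attention heads (assumed)
    dim_head = 64       # dimension per head (assumed)
)

tokens = torch.randint(0, 20000, (1, 2048))
logits = palm(tokens) # (1, 2048, 20000)
```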
7 changes: 3 additions & 4 deletions palm_pytorch/palm_pytorch.py
```diff
@@ -12,12 +12,11 @@ class LayerNorm(nn.Module):
     def __init__(self, dim, eps = 1e-5):
         super().__init__()
         self.eps = eps
-        self.g = nn.Parameter(torch.ones(dim))
+        self.gamma = nn.Parameter(torch.ones(dim))
+        self.register_buffer('beta', torch.zeros(dim))
 
     def forward(self, x):
-        var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
-        mean = torch.mean(x, dim = -1, keepdim = True)
-        return (x - mean) / (var + self.eps).sqrt() * self.g
+        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
 
 # parallel with residual
 # discovered by Wang et al + EleutherAI from GPT-J fame
```
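The net effect is a bias-less LayerNorm: `gamma` is learned while `beta` stays a frozen zero buffer, so the fused `F.layer_norm` call never applies a trainable bias, replacing the hand-rolled mean/variance computation. A minimal self-contained sketch follows, with an equivalence check against the removed manual code; the check itself is illustrative and not part of the repository.

```python
import torch
import torch.nn.functional as F
from torch import nn

class LayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim))     # learnable scale
        self.register_buffer('beta', torch.zeros(dim)) # frozen zero bias, never trained

    def forward(self, x):
        # fused layer norm; beta is all zeros, so no bias is ever applied
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# illustrative check: the fused call matches the manual computation it replaced
x = torch.randn(2, 16, 512)
norm = LayerNorm(512)

var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = -1, keepdim = True)
manual = (x - mean) / (var + norm.eps).sqrt() * norm.gamma

assert torch.allclose(norm(x), manual, atol = 1e-6)
```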
2 changes: 1 addition & 1 deletion setup.py
```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'PaLM-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.4',
+  version = '0.0.5',
   license='MIT',
   description = 'PaLM: Scaling Language Modeling with Pathways - Pytorch',
   author = 'Phil Wang',
```
