fix prelayernorm in attention
lucidrains committed Apr 4, 2022
1 parent f8eb18f commit 49232e1
Showing 2 changed files with 6 additions and 3 deletions.
7 changes: 5 additions & 2 deletions palm_pytorch/palm_pytorch.py
@@ -99,14 +99,17 @@ def forward(self, x):
         """
 
         n, device, h = x.shape[1], x.device, self.heads
-        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
 
         # pre layernorm
 
         x = self.norm(x)
 
+        # queries, keys, values
+
+        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
+
         # split heads
-        # they use multi-query attention, yet another Noam Shazeer paper
+        # they use multi-query single-key-value attention, yet another Noam Shazeer paper
         # they found no performance loss past a certain scale, and more efficient decoding obviously
         # https://arxiv.org/abs/1911.02150
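The bug being fixed is an ordering one: with pre-layernorm, the query/key/value projections must be applied to the normalized input, but the old code projected the raw x before self.norm ran. Below is a minimal, self-contained sketch of a pre-layernorm, multi-query attention block in the corrected order. It is an illustration only: the class name MultiQueryAttention, the default sizes, and the omission of causal masking, rotary embeddings, and the rest of the PaLM block are assumptions for brevity, not the repository's exact code.

```python
import torch
from torch import nn, einsum
from einops import rearrange

class MultiQueryAttention(nn.Module):
    # hypothetical simplified module, not the repo's exact class
    def __init__(self, dim, dim_head = 64, heads = 8):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)
        # multi-query, single-key-value: many query heads, one shared k/v head
        self.to_q = nn.Linear(dim, dim_head * heads, bias = False)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
        self.to_out = nn.Linear(dim_head * heads, dim, bias = False)

    def forward(self, x):
        # pre-layernorm: normalize first, so every projection sees the normed x
        # (the commit's fix is moving the q/k/v projections below this line)
        x = self.norm(x)

        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

        # only the queries are split into heads; k and v stay single-headed
        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads) * self.scale

        sim = einsum('b h i d, b j d -> b h i j', q, k)
        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
```

Sharing one key/value projection across all query heads (https://arxiv.org/abs/1911.02150) shrinks the k/v weights and the decoding-time cache, reportedly without quality loss past a certain scale, which is what the reworded comment in the diff above refers to.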
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'PaLM-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.8',
+  version = '0.0.9',
   license='MIT',
   description = 'PaLM: Scaling Language Modeling with Pathways - Pytorch',
   author = 'Phil Wang',
