Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Jul 27, 2023
1 parent 17b708a commit 9c82dc5
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions zeta/utils/attention/multihead_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,9 +548,11 @@ def forward(
v = self.v_proj(value)
q *= self.scaling

q = q.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(bsz, src_len, self.num_heads, self.head_dim).transpose(1, 2)
q = q.view(bsz, self.num_heads, tgt_len, self.head_dim)
k = k.view(bsz, self.num_heads, src_len, self.head_dim)
v = v.view(bsz, self.num_heads, src_len, self.head_dim)


q = q.reshape(bsz * self.num_heads, tgt_len, self.head_dim)
k = k.reshape(bsz * self.num_heads, src_len, self.head_dim)
v = v.reshape(bsz * self.num_heads, src_len, self.head_dim)
Expand Down

0 comments on commit 9c82dc5

Please sign in to comment.