Skip to content

Commit

Permalink
fix generation with large sequences when flash2 is False (OpenNMT#2564)
Browse files Browse the repository at this point in the history
  • Loading branch information
l-k-11235 authored Feb 22, 2024
1 parent 0e72326 commit d9d8b77
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions onmt/modules/multi_headed_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,11 +464,11 @@ def forward(
or query.dtype != torch.float16
):
if self.max_relative_positions == -1: # Rotary Embeddings
if seqlen > self.rope.size(0):

if seqlen + start_pos > self.rope.size(0):
# Resize rotary embeddings.
self.rope, _, _ = rotaryembeddings(
self.rotary_dim,
maxseqlen=(seqlen + 2048),
maxseqlen=(seqlen + start_pos + 2048),
base=self.rotary_theta,
device=self.rope.device,
)
Expand Down

0 comments on commit d9d8b77

Please sign in to comment.