
Commit

Fix ROPE extension issue and device mismatch (#840)
* When an exception has been assigned using an as target, it is cleared at the end of the except clause (https://docs.python.org/3/reference/compound_stmts.html#the-try-statement); see the sketch after this list.

* Update loader.py

* Round up when extending the RoPE size so the extended cache always covers seq_len

* inv_freq.device can change, so make sure t and inv_freq are created on the same device
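
A minimal sketch of the Python rule cited in the first bullet (standalone illustration, not the actual loader.py change; load_model and saved_error are made-up names):

def load_model():
    # Stand-in for a call that may fail; purely illustrative.
    raise RuntimeError("original failure")

saved_error = None
try:
    load_model()
except RuntimeError as e:
    # "e" is unbound once this clause ends, so keep a reference under another name.
    saved_error = e

print(repr(saved_error))   # RuntimeError('original failure')
# print(e)                 # NameError: name 'e' is not defined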

---------

Co-authored-by: xiaoyang <[email protected]>
Co-authored-by: Daniel Han <[email protected]>
3 people authored Jul 31, 2024
1 parent fdfe1f5 commit 2de1427
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions unsloth/models/llama.py
@@ -14,6 +14,7 @@

import torch
import gc
+ import math
from typing import Optional, Tuple, List, Union
from ._utils import *
from ._utils import __version__
@@ -1036,7 +1037,7 @@ def forward(self, x, position_ids=None, seq_len=None):
def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
- self.current_rope_size = int(round(seq_len / 8192)) * 8192
+ self.current_rope_size = math.ceil(seq_len / 8192) * 8192
self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
pass
pass
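
A quick check of why the rounding change above matters (toy value of seq_len, not from the diff): with round, a sequence length just past a multiple of 8192 rounds down and the extended cache ends up shorter than seq_len, while math.ceil always rounds up to the next multiple.

import math

seq_len = 8200                                  # just past one 8192 block
old_size = int(round(seq_len / 8192)) * 8192    # 8192  -> shorter than seq_len
new_size = math.ceil(seq_len / 8192) * 8192     # 16384 -> always >= seq_len
print(old_size, new_size)
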
@@ -1109,7 +1110,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
# in FP32. They are applied (multiplied) in FP32 as well.
self.current_rope_size = seq_len

- t = torch.arange(self.current_rope_size, device="cpu", dtype=torch.int64).float()
+ t = torch.arange(self.current_rope_size, device=self.inv_freq.device, dtype=torch.int64).float()

freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
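
A standalone illustration of the device mismatch this hunk avoids (assumes a CUDA device is available; toy tensors, not unsloth's actual values): torch.outer requires both operands on the same device, so t is now created on inv_freq.device instead of always on the CPU.

import torch

inv_freq = torch.ones(4, device="cuda:0")              # suppose inv_freq was moved to the GPU
t_cpu = torch.arange(8, device="cpu").float()
# torch.outer(t_cpu, inv_freq)                         # RuntimeError: tensors on different devices

t = torch.arange(8, device=inv_freq.device).float()    # match inv_freq, as in the fix
freqs = torch.outer(t, inv_freq)                       # works: shape (8, 4) on cuda:0
print(freqs.shape, freqs.device)
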
@@ -1158,7 +1159,7 @@ def forward(self, x, position_ids=None, seq_len=None):
def extend_rope_embedding(self, x, seq_len):
if seq_len <= self.current_rope_size: return
# Iteratively grow by increments of 8192
- self.current_rope_size = int(round(seq_len / 8192)) * 8192
+ self.current_rope_size = math.ceil(seq_len / 8192) * 8192
self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
pass
pass
