Skip to content

Commit

Permalink
Fix bug
Browse files (browse the repository at this point in the history)
  • Loading branch information
danielhanchen committed Sep 4, 2024
1 parent 480f2ef commit 3682672
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
2 changes: 2 additions & 0 deletions unsloth/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,8 @@ def patch_mistral_nemo_config(config):

# =============================================
# Work around the TypeError raised with newer Xformers versions (>= 0.0.27): "Multiple dispatch failed for 'torch._ops.aten.to.dtype_layout'"
accelerate_old_send_to_device = None
accelerate_new_send_to_device = None
if Version(xformers_version) >= Version("0.0.27"):
import accelerate.utils.operations
if hasattr(accelerate.utils.operations, "send_to_device") and \
Expand Down
10 changes: 7 additions & 3 deletions unsloth/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -1371,8 +1371,10 @@ def _fast_generate(*args, **kwargs):
internal_model._flag_for_generation = True

# Must patch accelerate for Xformers
import accelerate.utils.operations
accelerate.utils.operations.send_to_device = accelerate_new_send_to_device
if accelerate_new_send_to_device is not None:
import accelerate.utils.operations
accelerate.utils.operations.send_to_device = accelerate_new_send_to_device
pass

# For newer HF
kwargs["cache_implementation"] = "dynamic"
Expand Down Expand Up @@ -1411,7 +1413,9 @@ def _fast_generate(*args, **kwargs):
if hasattr(internal_model, "_flag_for_generation"): del internal_model._flag_for_generation

# Restore accelerate's original send_to_device
accelerate.utils.operations.send_to_device = accelerate_old_send_to_device
if accelerate_new_send_to_device is not None:
accelerate.utils.operations.send_to_device = accelerate_old_send_to_device
pass

return output
pass
Expand Down

0 comments on commit 3682672

Please sign in to comment.