diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py
index 242d234d..6dd17e73 100644
--- a/unsloth/models/_utils.py
+++ b/unsloth/models/_utils.py
@@ -285,6 +285,8 @@ def patch_mistral_nemo_config(config):
 
 # =============================================
 # Fix new Xformers versions TypeError: Multiple dispatch failed for 'torch._ops.aten.to.dtype_layout'
+accelerate_old_send_to_device = None
+accelerate_new_send_to_device = None
 if Version(xformers_version) >= Version("0.0.27"):
     import accelerate.utils.operations
     if hasattr(accelerate.utils.operations, "send_to_device") and \
diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index 3fcb8a76..39998127 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -1371,8 +1371,10 @@ def _fast_generate(*args, **kwargs):
     internal_model._flag_for_generation = True
 
     # Must patch accelerate for Xformers
-    import accelerate.utils.operations
-    accelerate.utils.operations.send_to_device = accelerate_new_send_to_device
+    if accelerate_new_send_to_device is not None:
+        import accelerate.utils.operations
+        accelerate.utils.operations.send_to_device = accelerate_new_send_to_device
+    pass
 
     # For newer HF
     kwargs["cache_implementation"] = "dynamic"
@@ -1411,7 +1413,9 @@ def _fast_generate(*args, **kwargs):
     if hasattr(internal_model, "_flag_for_generation"): del internal_model._flag_for_generation
 
     # Return accelerate back
-    accelerate.utils.operations.send_to_device = accelerate_old_send_to_device
+    if accelerate_new_send_to_device is not None:
+        accelerate.utils.operations.send_to_device = accelerate_old_send_to_device
+    pass
 
     return output
 pass
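
For context, below is a minimal, self-contained sketch (not part of the diff) of the install-then-restore pattern this change guards: accelerate's send_to_device is swapped only when the xformers workaround was actually installed, and the original is always restored after generation. The helper names _maybe_install_fix, _fixed_send_to_device, and _generate_with_patch are hypothetical and used only for illustration; the real logic lives in unsloth/models/_utils.py and unsloth/models/llama.py as shown above.

# Minimal sketch, assuming only that accelerate is installed.
import accelerate.utils.operations as ops

accelerate_old_send_to_device = None  # original function, saved before patching
accelerate_new_send_to_device = None  # patched replacement; stays None if unused

def _maybe_install_fix(needs_fix: bool) -> None:
    # Hypothetical installer mirroring what _utils.py does for xformers >= 0.0.27.
    global accelerate_old_send_to_device, accelerate_new_send_to_device
    if needs_fix:
        accelerate_old_send_to_device = ops.send_to_device
        def _fixed_send_to_device(*args, **kwargs):
            # Stand-in wrapper; the real fix adjusts dtype handling before dispatch.
            return accelerate_old_send_to_device(*args, **kwargs)
        accelerate_new_send_to_device = _fixed_send_to_device
    pass

def _generate_with_patch(run_generation):
    # Swap in the patched function only if the fix was installed (the None check
    # added by this diff), then always restore the original afterwards.
    if accelerate_new_send_to_device is not None:
        ops.send_to_device = accelerate_new_send_to_device
    try:
        return run_generation()
    finally:
        if accelerate_new_send_to_device is not None:
            ops.send_to_device = accelerate_old_send_to_device

With this pattern, generation behaves the same whether or not the xformers workaround was ever installed, instead of referencing accelerate_new_send_to_device when it was never defined.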