diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py
index 9a2054c5..9cf999e2 100644
--- a/unsloth/kernels/flex_attention.py
+++ b/unsloth/kernels/flex_attention.py
@@ -47,7 +47,7 @@ pass
 
 
 # Logit softcapping
-@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
+@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
 def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
     n_heads = self.num_heads
     head_dim = self.head_dim
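
For context on what this one-line change does: with fullgraph = False, torch.compile silently splits the function at any graph break and runs the unsupported pieces eagerly, whereas fullgraph = True asks Dynamo to capture the whole function as a single graph and to raise an error if it cannot. Below is a minimal, self-contained sketch of the decorator usage, not part of the patch; torch_compile_options here is a stand-in for the dict defined elsewhere in flex_attention.py (the real options may differ), and toy_softcapping is a hypothetical example function standing in for slow_attention_softcapping.

import torch

# Stand-in for the torch_compile_options dict defined elsewhere in
# unsloth/kernels/flex_attention.py; the actual options in the repo may differ.
torch_compile_options = {
    "epilogue_fusion"   : True,
    "max_autotune"      : True,
    "shape_padding"     : True,
    "trace.enabled"     : False,
    "triton.cudagraphs" : False,
}

# Hypothetical toy function mirroring the decorator in the patch.
# fullgraph = True makes compilation fail loudly on a graph break instead of
# silently falling back to eager execution for part of the function.
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def toy_softcapping(scores, t = 30.0):
    # Logit softcapping: squash attention scores into (-t, t) via tanh.
    return t * torch.tanh(scores / t)

print(toy_softcapping(torch.randn(4, 8)))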