Cross Attention does not work on CPU and older GPUs #6

Open
Vectorrent opened this issue Dec 18, 2024 · 3 comments

Comments

@Vectorrent
Contributor

The current CrossAttention code has a hardcoded dependency on FlexAttention. This is a problem for people like me, who need to use older consumer GPUs.

I'm not sure how feasible it is, but it would be great if we could fall back to eager execution.

On CPU, FlexAttention doesn't work at all:

Error type: <class 'torch._dynamo.exc.BackendCompilerFailed'>
Traceback (most recent call last):
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1446, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/__init__.py", line 2234, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1521, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/common.py", line 72, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1071, in aot_module_simplified
    compiled_fn = dispatch_and_compile()
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1056, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 759, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
                               ^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 588, in aot_dispatch_autograd
    compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1350, in fw_compiler_base
    return _fw_compiler_base(model, example_inputs, is_inference)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1421, in _fw_compiler_base
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 475, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 661, in _compile_fx_inner
    compiled_graph = FxGraphCache.load(
                     ^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_inductor/codecache.py", line 1370, in load
    compiled_graph = compile_fx_fn(
                     ^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 570, in codegen_and_compile
    compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 859, in fx_codegen_and_compile
    graph.run(*example_inputs)
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 780, in run
    return super().run(*args)
           ^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 146, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1319, in run_node
    result = super().run_node(n)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 203, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1024, in call_function
    raise LoweringException(e, target, args, kwargs).with_traceback(
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1021, in call_function
    out = lowerings[target](*args, **kwargs)  # type: ignore[index]
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_inductor/lowering.py", line 361, in wrapped
    out = decomp_fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_inductor/kernel/flex_attention.py", line 895, in flex_attention
    autotune_select_algorithm(
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 1724, in autotune_select_algorithm
    return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 1209, in __call__
    raise NoValidChoicesError(
torch._inductor.exc.LoweringException: NoValidChoicesError: No choices to select, please consider adding ATEN into max_autotune_gemm_backends config (defined in torch/_inductor/config.py) to allow at least one choice. 
  target: flex_attention
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='primals_1', layout=FixedLayout('cpu', torch.float32, size=[2, 1, 128, 256], stride=[32768, 256, 256, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='primals_2', layout=FixedLayout('cpu', torch.float32, size=[2, 1, 128, 256], stride=[32768, 256, 256, 1]))
  ))
  args[2]: TensorBox(StorageBox(
    InputBuffer(name='primals_3', layout=FixedLayout('cpu', torch.float32, size=[2, 1, 128, 256], stride=[32768, 256, 256, 1]))
  ))
  args[3]: Subgraph(name='sdpa_score0', graph_module=<lambda>(), graph=None)
  args[4]: (TensorBox(StorageBox(
    ComputedBuffer(name='buf0', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]), data=Pointwise(
      'cpu',
      torch.int32,
      def inner_fn(index):
          _, _, _ = index
          tmp0 = ops.constant(1, torch.int32)
          return tmp0
      ,
      ranges=[1, 1, 1],
      origin_node=full,
      origins=OrderedSet([full])
    ))
  )), TensorBox(StorageBox(
    ComputedBuffer(name='buf1', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]), data=Pointwise(
      'cpu',
      torch.int32,
      def inner_fn(index):
          _, _, _, _ = index
          tmp0 = ops.constant(0, torch.int32)
          return tmp0
      ,
      ranges=[1, 1, 1, 1],
      origin_node=full_default,
      origins=OrderedSet([full_default])
    ))
  )), None, None, TensorBox(StorageBox(
    ComputedBuffer(name='buf7', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]), data=Pointwise(
      'cpu',
      torch.int32,
      def inner_fn(index):
          _, _, _ = index
          tmp0 = ops.load(buf2, 0)
          tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int32)
          tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)
          return tmp2
      ,
      ranges=[1, 1, 1],
      origin_node=convert_element_type,
      origins=OrderedSet([convert_element_type, sum_1])
    ))
  )), TensorBox(StorageBox(
    ComputedBuffer(name='buf8', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]), data=Pointwise(
      'cpu',
      torch.int32,
      def inner_fn(index):
          _, _, _, _ = index
          tmp0 = ops.load(buf6, 0)
          tmp1 = ops.to_dtype(tmp0, torch.int32, src_dtype=torch.int64)
          return tmp1
      ,
      ranges=[1, 1, 1, 1],
      origin_node=convert_element_type_1,
      origins=OrderedSet([convert_element_type_1])
    ))
  )), None, None, 1073741824, 1073741824, Subgraph(name='sdpa_mask0', graph_module=<lambda>(), graph=None))
  args[5]: 0.0625
  args[6]: {'ROWS_GUARANTEED_SAFE': False, 'PRESCALE_QK': False, 'OUTPUT_LOGSUMEXP': True}
  args[7]: ()
  args[8]: ()

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/crow/repos/praxis/praxis/modules/encoder.py", line 197, in run_test
    decoder_output = model.decode(
                     ^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/praxis/modules/encoder.py", line 90, in decode
    output, _ = self.decoder(
                ^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/bytelatent/model/local_models.py", line 343, in forward
    h_cross = self.cross_attn_layers[i](
              ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/bytelatent/model/transformer.py", line 98, in forward
    output = flex_attention_comp(xq, xk, xv, block_mask=mask)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1064, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
    return _compile(
           ^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
    tracer.run()
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
    super().run()
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2987, in RETURN_VALUE
    self._return(inst)
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2972, in _return
    self.output.compile_subgraph(
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1117, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1369, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1416, in call_user_compiler
    return self._call_user_compiler(gm)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1465, in _call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: NoValidChoicesError: No choices to select, please consider adding ATEN into max_autotune_gemm_backends config (defined in torch/_inductor/config.py) to allow at least one choice. 
  target: flex_attention
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='primals_1', layout=FixedLayout('cpu', torch.float32, size=[2, 1, 128, 256], stride=[32768, 256, 256, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='primals_2', layout=FixedLayout('cpu', torch.float32, size=[2, 1, 128, 256], stride=[32768, 256, 256, 1]))
  ))
  args[2]: TensorBox(StorageBox(
    InputBuffer(name='primals_3', layout=FixedLayout('cpu', torch.float32, size=[2, 1, 128, 256], stride=[32768, 256, 256, 1]))
  ))
  args[3]: Subgraph(name='sdpa_score0', graph_module=<lambda>(), graph=None)
  args[4]: (TensorBox(StorageBox(
    ComputedBuffer(name='buf0', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]), data=Pointwise(
      'cpu',
      torch.int32,
      def inner_fn(index):
          _, _, _ = index
          tmp0 = ops.constant(1, torch.int32)
          return tmp0
      ,
      ranges=[1, 1, 1],
      origin_node=full,
      origins=OrderedSet([full])
    ))
  )), TensorBox(StorageBox(
    ComputedBuffer(name='buf1', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]), data=Pointwise(
      'cpu',
      torch.int32,
      def inner_fn(index):
          _, _, _, _ = index
          tmp0 = ops.constant(0, torch.int32)
          return tmp0
      ,
      ranges=[1, 1, 1, 1],
      origin_node=full_default,
      origins=OrderedSet([full_default])
    ))
  )), None, None, TensorBox(StorageBox(
    ComputedBuffer(name='buf7', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 1], stride=[1, 1, 1]), data=Pointwise(
      'cpu',
      torch.int32,
      def inner_fn(index):
          _, _, _ = index
          tmp0 = ops.load(buf2, 0)
          tmp1 = ops.to_dtype(tmp0, torch.int64, src_dtype=torch.int32)
          tmp2 = ops.to_dtype(tmp1, torch.int32, src_dtype=torch.int64)
          return tmp2
      ,
      ranges=[1, 1, 1],
      origin_node=convert_element_type,
      origins=OrderedSet([convert_element_type, sum_1])
    ))
  )), TensorBox(StorageBox(
    ComputedBuffer(name='buf8', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 1, 1], stride=[1, 1, 1, 1]), data=Pointwise(
      'cpu',
      torch.int32,
      def inner_fn(index):
          _, _, _, _ = index
          tmp0 = ops.load(buf6, 0)
          tmp1 = ops.to_dtype(tmp0, torch.int32, src_dtype=torch.int64)
          return tmp1
      ,
      ranges=[1, 1, 1, 1],
      origin_node=convert_element_type_1,
      origins=OrderedSet([convert_element_type_1])
    ))
  )), None, None, 1073741824, 1073741824, Subgraph(name='sdpa_mask0', graph_module=<lambda>(), graph=None))
  args[5]: 0.0625
  args[6]: {'ROWS_GUARANTEED_SAFE': False, 'PRESCALE_QK': False, 'OUTPUT_LOGSUMEXP': True}
  args[7]: ()
  args[8]: ()

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

On older GPUs:

RuntimeError: Found NVIDIA GeForce GTX 1070 which is too old to be supported by the triton GPU compiler, which is used as the backend. Triton only supports devices of CUDA Capability >= 7.0, but your device is of CUDA capability 6.1

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

If this could be fixed easily, I would greatly appreciate a patch! If not, I'll likely look into implementing it myself.
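
In case it helps, here's the shape of the fallback I'm imagining. This is a minimal sketch, not the repo's actual API: the flex_attention_supported and make_attention names are mine, and the BlockMask handling in the fallback is glossed over.

import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention


def flex_attention_supported(device: torch.device) -> bool:
    # FlexAttention is lowered through Triton, which needs a CUDA GPU of
    # compute capability >= 7.0; on CPU the inductor lowering fails outright
    # (see the traceback above).
    if device.type != "cuda":
        return False
    major, _ = torch.cuda.get_device_capability(device)
    return major >= 7


def make_attention(device: torch.device):
    if flex_attention_supported(device):
        # flex_attention needs to be compiled to be performant.
        return torch.compile(flex_attention)

    def sdpa_fallback(q, k, v, block_mask=None):
        # block_mask is ignored here; a real fallback would need to
        # materialize the BlockMask as a boolean attn_mask for SDPA.
        return F.scaled_dot_product_attention(q, k, v)

    return sdpa_fallback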

@Vectorrent
Contributor Author

Vectorrent commented Dec 18, 2024

I created a monkey patch for CrossAttention, which replaces the FlexAttention part with standard SDPA. This seems to fix the problem, though it would be nice to integrate a fix upstream somehow. The patch is roughly the following shape; treat this as a sketch rather than the exact code, since the sdpa_cross_attention name and the tensor layout annotations are my own, and it assumes xq/xk/xv are already in the (batch, heads, seq, head_dim) layout that flex_attention expects.
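
import torch
import torch.nn.functional as F


def sdpa_cross_attention(
    xq: torch.Tensor,  # (batch, heads, q_len, head_dim)
    xk: torch.Tensor,  # (batch, heads, kv_len, head_dim)
    xv: torch.Tensor,  # (batch, heads, kv_len, head_dim)
    attn_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    # Stands in for: output = flex_attention_comp(xq, xk, xv, block_mask=mask)
    # Note that FlexAttention's BlockMask is not a tensor, so the patch has to
    # rebuild an equivalent boolean attn_mask (or pass None when full cross
    # attention is acceptable) before calling SDPA.
    return F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask)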

(Note: this particular patch seems to have an unchecked memory leak. I have no idea why cross attention causes it here, but something about cross_attn_encoder=True leads to a gradual increase in VRAM.)

@Vectorrent
Contributor Author

After more investigation, I found that we can stop the unchecked memory growth by detaching x from the computation graph here.

    def forward(self, x, kv, mask=None) -> torch.Tensor:
        # x: (batch, seq_len, dim); kv: (batch, kv_len, dim)
        bsz, seq_len, _ = x.shape
        _, slen_kv, _ = kv.shape
        # Detaching x cuts the queries out of the autograd graph, which stops
        # the VRAM growth (but also blocks gradients through this path).
        x = x.detach()
        x = self.cross_attn_norm_q(x)

I don't know how much of a "solution" this really is, since detaching also blocks gradients from flowing through this path, but it does indicate that the problem relates to gradient growth. Again, this only happens when using cross attention in the encoder, not the decoder. So I'm thinking there must be a huge gradient matrix here, which comes from comparing inputs via cross attention, and FlexAttention somehow computes it much more efficiently.
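
For what it's worth, this is the kind of generic check I've been using to watch the growth. Nothing repo-specific; the log_vram name is my own.

import torch


def log_vram(tag: str) -> None:
    # Print currently-allocated and peak CUDA memory in MiB.
    if torch.cuda.is_available():
        alloc = torch.cuda.memory_allocated() / 2**20
        peak = torch.cuda.max_memory_allocated() / 2**20
        print(f"{tag}: allocated={alloc:.1f} MiB, peak={peak:.1f} MiB")


# Hypothetical usage: call once per training step to see the gradual increase.
# log_vram("after step")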

@JohnlNguyen

We ran our experiments using FlexAttention and didn't test any other variants. I'm not sure whether you will reproduce the results with SDPA, but it's worth a shot.
