I am unable to run Mamba2 on a 2080 Ti GPU due to Triton errors; I have no problems with Ampere cards. I installed mamba from source inside the NGC container, and I am using the latest 560 driver. The failure occurs while compiling the _mamba_chunk_scan_combined_fwd kernel; the full traceback is below. I ran with MLIR_ENABLE_DUMP=1 at one point, but I am not sure how to read the dump to get any insight; I can re-run and attach it to the issue if that helps. I assume this is related to the `IndexError: map::at` issues others are hitting on GPUs below SM 8.0.
I am using fp32; no AMP or anything similar is enabled.
FROM nvcr.io/nvidia/pytorch:24.08-py3 AS train-image
...
RUN pip3 install git+https://github.com/state-spaces/[email protected] \
git+https://github.com/Dao-AILab/[email protected]
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/modules/mamba2.py", line 183, in forward
out = mamba_split_conv1d_scan_combined(
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 930, in mamba_split_conv1d_scan_combined
return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 574, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 456, in decorate_fwd
return fwd(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 795, in forward
out_x, _, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 313, in _mamba_chunk_scan_combined_fwd
states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_chunk_state.py", line 746, in _chunk_state_fwd
_chunk_state_fwd_kernel[grid](
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 345, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 156, in run
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 156, in <dictcomp>
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 133, in _bench
return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
File "/usr/local/lib/python3.10/dist-packages/triton/testing.py", line 103, in do_bench
fn()
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 114, in kernel_call
self.fn.run(
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 662, in run
kernel = self.compile(
File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 282, in compile
next_module = compile_ir(module, metadata)
File "/usr/local/lib/python3.10/dist-packages/triton/backends/nvidia/compiler.py", line 318, in <lambda>
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
File "/usr/local/lib/python3.10/dist-packages/triton/backends/nvidia/compiler.py", line 216, in make_llir
pm.run(mod)
IndexError: map::at
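Not a fix, but a sketch of a guard one could add before invoking the Mamba2 Triton path. The helper name is hypothetical (not part of mamba_ssm), and the `(8, 0)` threshold is an assumption based on the pattern above: 2080 Ti (SM 7.5) fails, Ampere (SM 8.0+) works.

```python
# Hypothetical capability guard, assuming the Triton `IndexError: map::at`
# failures only occur below SM 8.0 (2080 Ti is SM 7.5; Ampere cards work).

def is_ampere_or_newer(capability):
    """capability: a (major, minor) tuple, e.g. torch.cuda.get_device_capability()."""
    major, minor = capability
    return (major, minor) >= (8, 0)

# Example: on a 2080 Ti, torch.cuda.get_device_capability() returns (7, 5),
# so one could fail early with a clear message instead of hitting the compiler:
#
#     if not is_ampere_or_newer(torch.cuda.get_device_capability()):
#         raise RuntimeError("Mamba2 Triton kernels appear to require SM >= 8.0")
```

Failing early like this at least replaces the opaque `IndexError: map::at` from deep inside the Triton backend with an actionable message.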