diff --git a/tests/nn/modules/test_simple_mamba.py b/tests/nn/modules/test_simple_mamba.py index e8ceacec..c17773be 100644 --- a/tests/nn/modules/test_simple_mamba.py +++ b/tests/nn/modules/test_simple_mamba.py @@ -4,10 +4,11 @@ from zeta.nn.modules.simple_mamba import ( Mamba, MambaBlock, - ResidualBlock, RMSNorm, ) +from zeta.rl.vision_model_rl import ResidualBlock + def test_mamba_class_init(): model = Mamba(10000, 512, 6)