diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_sgd.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_sgd.py index 173f5644002..afd4d7d6c7f 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_sgd.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_sgd.py @@ -56,10 +56,10 @@ def create_tt_tensor(tensor, device): @pytest.mark.parametrize( "shape", - ( - (32, 32), # single - (12, 6, 64, 64), # multiple tiles - ), + [ + [32, 32], # single + [12, 6, 64, 64], # multiple tiles + ], ) @pytest.mark.parametrize("lr", [3.0]) @pytest.mark.parametrize("momentum", [0.0, 7.7]) @@ -195,7 +195,7 @@ def forward(self, x): @pytest.mark.parametrize( "shape", - ((32, 32),), # single + [[32, 32]], # single ) @pytest.mark.parametrize("lr", [3.0]) @pytest.mark.parametrize("momentum", [7.7])