From ddc17bd7908bbd473c467b1cf6a2cf452bc81055 Mon Sep 17 00:00:00 2001
From: mouliraj-mcw
Date: Wed, 13 Nov 2024 06:55:37 +0000
Subject: [PATCH] Update Test Files

---
 .../sweeps/eltwise/unary/prelu/prelu.py         | 17 +++--------------
 .../pytests/tt_dnn/test_eltwise_unary.py        |  1 +
 .../operations/eltwise/test_activation.py       |  1 +
 .../operations/eltwise/test_binary_composite.py |  1 +
 4 files changed, 6 insertions(+), 14 deletions(-)

diff --git a/tests/sweep_framework/sweeps/eltwise/unary/prelu/prelu.py b/tests/sweep_framework/sweeps/eltwise/unary/prelu/prelu.py
index 134825bdf04..b2940fcf817 100644
--- a/tests/sweep_framework/sweeps/eltwise/unary/prelu/prelu.py
+++ b/tests/sweep_framework/sweeps/eltwise/unary/prelu/prelu.py
@@ -14,11 +14,6 @@
 from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
 from models.utility_functions import torch_random
 
-# Override the default timeout in seconds for hang detection.
-TIMEOUT = 30
-
-random.seed(0)
-
 
 # Parameters provided to the test vector generator are defined here.
 # They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values.
@@ -45,12 +40,6 @@ def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
     return False, None
 
 
-def torch_prelu(x, *args, **kwargs):
-    weight = kwargs.pop("scalar")
-    result = torch.nn.functional.prelu(x, torch.tensor(weight, dtype=x.dtype))
-    return result
-
-
 # This is the run instructions for the test, defined by the developer.
 # The run function must take the above-defined parameters as inputs.
 # The runner will call this run function with each test vector, and the returned results from this function will be stored.
@@ -65,14 +54,14 @@ def run(
     *,
     device,
 ) -> list:
-    data_seed = random.randint(0, 20000000)
-    torch.manual_seed(data_seed)
+    torch.manual_seed(0)
 
     torch_input_tensor_a = gen_func_with_cast_tt(
         partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype
     )(input_shape)
 
-    torch_output_tensor = torch_prelu(torch_input_tensor_a, scalar=weight)
+    golden_function = ttnn.get_golden_function(ttnn.prelu)
+    torch_output_tensor = golden_function(torch_input_tensor_a, weight)
 
     input_tensor_a = ttnn.from_torch(
         torch_input_tensor_a,
diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_unary.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_unary.py
index ea9029a7cf6..1db66f53ced 100644
--- a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_unary.py
+++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_unary.py
@@ -871,6 +871,7 @@ def test_run_eltwise_leaky_relu_op(
     )
 
     @pytest.mark.parametrize("weight", [-0.5, 1.0, 0.5])
+    @skip_for_grayskull()
     def test_run_eltwise_prelu(
         self,
         input_shapes,
diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_activation.py b/tests/ttnn/unit_tests/operations/eltwise/test_activation.py
index 57f8ecf4284..4407dd2a306 100644
--- a/tests/ttnn/unit_tests/operations/eltwise/test_activation.py
+++ b/tests/ttnn/unit_tests/operations/eltwise/test_activation.py
@@ -307,6 +307,7 @@ def test_scalarB_leaky_relu(device, h, w, scalar):
     run_activation_test_leaky_relu(device, h, w, scalar, ttnn.leaky_relu)
 
 
+@skip_for_grayskull()
@pytest.mark.parametrize("weight", [-0.5, 1.0, 0.5])
 @pytest.mark.parametrize("h", [64])
 @pytest.mark.parametrize("w", [128])
diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py b/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
index 96050f92511..6f5ff325837 100644
--- a/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
+++ b/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py
@@ -996,6 +996,7 @@ def test_binary_prelu_ttnn(input_shapes, device):
     "scalar",
     {-0.25, -2.7, 0.45, 6.4},
 )
+@skip_for_grayskull()
 def test_binary_prelu_scalar_ttnn(input_shapes, scalar, device):
     in_data1, input_tensor1 = data_gen_with_range(input_shapes, -100, 100, device)
     output_tensor = ttnn.prelu(input_tensor1, scalar)
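For reference, the golden-function lookup that replaces the removed torch_prelu
helper can be exercised on its own, outside the sweep runner. A minimal sketch,
assuming a ttnn build where ttnn.get_golden_function(ttnn.prelu) returns a
callable taking (input_tensor, weight) that mirrors torch.nn.functional.prelu
(the shape and weight value below are illustrative only):

    import torch
    import ttnn

    torch.manual_seed(0)
    x = torch.randn(1, 1, 32, 32)
    weight = -0.5

    # Look up the reference (golden) implementation registered for ttnn.prelu.
    golden_function = ttnn.get_golden_function(ttnn.prelu)
    expected = golden_function(x, weight)

    # The removed torch_prelu helper computed the same reference directly:
    reference = torch.nn.functional.prelu(x, torch.tensor(weight, dtype=x.dtype))
    assert torch.allclose(expected, reference)

Using the registered golden function keeps the sweep's reference output in sync
with ttnn's own definition of prelu, rather than maintaining a test-local copy
of the PyTorch equivalent.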