Skip to content

Commit

Permalink
Update Test Files
Browse files Browse the repository at this point in the history
  • Loading branch information
mouliraj-mcw committed Nov 15, 2024
1 parent 727a7ee commit 087877d
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 14 deletions.
17 changes: 3 additions & 14 deletions tests/sweep_framework/sweeps/eltwise/unary/prelu/prelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,6 @@
from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
from models.utility_functions import torch_random

# Override the default timeout in seconds for hang detection.
TIMEOUT = 30

random.seed(0)


# Parameters provided to the test vector generator are defined here.
# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values.
Expand All @@ -45,12 +40,6 @@ def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]:
return False, None


def torch_prelu(x, *args, **kwargs):
    """Golden/reference PReLU.

    Applies ``torch.nn.functional.prelu`` to *x* using the required
    ``scalar`` keyword as the (single-channel) negative slope; the slope is
    materialized as a tensor in *x*'s dtype, as F.prelu requires a tensor
    weight. Extra positional/keyword arguments are accepted for
    sweep-framework call compatibility and ignored.
    """
    slope = kwargs.pop("scalar")
    weight_tensor = torch.tensor(slope, dtype=x.dtype)
    return torch.nn.functional.prelu(x, weight_tensor)


# This is the run instructions for the test, defined by the developer.
# The run function must take the above-defined parameters as inputs.
# The runner will call this run function with each test vector, and the returned results from this function will be stored.
Expand All @@ -65,14 +54,14 @@ def run(
*,
device,
) -> list:
data_seed = random.randint(0, 20000000)
torch.manual_seed(data_seed)
torch.manual_seed(0)

torch_input_tensor_a = gen_func_with_cast_tt(
partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype
)(input_shape)

torch_output_tensor = torch_prelu(torch_input_tensor_a, scalar=weight)
golden_function = ttnn.get_golden_function(ttnn.prelu)
torch_output_tensor = golden_function(torch_input_tensor_a, weight)

input_tensor_a = ttnn.from_torch(
torch_input_tensor_a,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,7 @@ def test_run_eltwise_leaky_relu_op(
)

@pytest.mark.parametrize("weight", [-0.5, 1.0, 0.5])
@skip_for_grayskull()
def test_run_eltwise_prelu(
self,
input_shapes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ def test_scalarB_leaky_relu(device, h, w, scalar):
run_activation_test_leaky_relu(device, h, w, scalar, ttnn.leaky_relu)


@skip_for_grayskull()
@pytest.mark.parametrize("weight", [-0.5, 1.0, 0.5])
@pytest.mark.parametrize("h", [64])
@pytest.mark.parametrize("w", [128])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -996,6 +996,7 @@ def test_binary_prelu_ttnn(input_shapes, device):
"scalar",
{-0.25, -2.7, 0.45, 6.4},
)
@skip_for_grayskull()
def test_binary_prelu_scalar_ttnn(input_shapes, scalar, device):
in_data1, input_tensor1 = data_gen_with_range(input_shapes, -100, 100, device)
output_tensor = ttnn.prelu(input_tensor1, scalar)
Expand Down

0 comments on commit 087877d

Please sign in to comment.