Skip to content

Commit

Permalink
#4003: Pytest automatically picking up sweep tests
Browse files Browse the repository at this point in the history
  • Loading branch information
eyonland committed Jan 22, 2024
1 parent 119f31f commit 8b1de30
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 4 deletions.
14 changes: 10 additions & 4 deletions tests/ttnn/sweep_tests/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,15 @@ def get_parameter_names(parameters):
return [name] + get_parameter_names(other_parameters)


def preprocess_parameter_value(parameter_value):
if callable(parameter_value):
parameter_value = parameter_value.__name__
return parameter_value


def get_parameter_values(parameter_names, permutation):
for parameter_name in parameter_names:
parameter_value = permutation[parameter_name]
if callable(parameter_value):
parameter_value = parameter_value.__name__
parameter_value = preprocess_parameter_value(permutation[parameter_name])
yield parameter_value


Expand Down Expand Up @@ -94,7 +98,9 @@ def run_single_test(test_name, index, *, device):
sweep_module = SourceFileLoader(f"sweep_module_{file_name.stem}", str(file_name)).load_module()
permutation = list(permutations(sweep_module.parameters))[index]

pretty_printed_parameters = ",\n".join(f"\t{key}={value}" for key, value in permutation.items())
pretty_printed_parameters = ",\n".join(
f"\t{key}={preprocess_parameter_value(value)}" for key, value in permutation.items()
)
logger.info(f"Running sweep test at index {index}:\n{{{pretty_printed_parameters}}}")
return _run_single_test(
sweep_module.run, sweep_module.skip, sweep_module.is_expected_to_fail, permutation, device=device
Expand Down
45 changes: 45 additions & 0 deletions tests/ttnn/sweep_tests/test_all_sweep_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from importlib.machinery import SourceFileLoader
from tests.ttnn.sweep_tests.sweep import (
SWEEP_SOURCES_DIR,
permutations,
run_single_test,
)
from loguru import logger
from dataclasses import dataclass
import pytest
import os


@dataclass
class SweepTest:
file_name: str
sweep_test_index: int

def __str__(self):
return f"{os.path.basename(self.file_name)}-{self.sweep_test_index}"


sweep_tests = []
for file_name in sorted(SWEEP_SOURCES_DIR.glob("*.py")):
logger.info(f"Running {file_name}")
base_name = os.path.basename(file_name)
base_name = os.path.splitext(base_name)[0]
sweep_module = SourceFileLoader(f"sweep_module_{base_name}", str(file_name)).load_module()
base_name = base_name + ".csv"
for sweep_test_index, parameter_list in enumerate(permutations(sweep_module.parameters)):
sweep_tests.append(SweepTest(file_name, sweep_test_index))


@pytest.mark.parametrize("sweep_test", sweep_tests, ids=str)
def test_all_sweeps(device, sweep_test):
status, message = run_single_test(
sweep_test.file_name,
sweep_test.sweep_test_index,
device=device,
)

assert status not in {"failed", "crashed"}, f"{message}"

0 comments on commit 8b1de30

Please sign in to comment.