From 8b1de30162d553c5b387d877fd5f3bb72ff3247b Mon Sep 17 00:00:00 2001 From: Eyon Date: Fri, 19 Jan 2024 00:22:02 +0000 Subject: [PATCH] #4003: Pytest automatically picking up sweep tests --- tests/ttnn/sweep_tests/sweep.py | 14 ++++-- .../ttnn/sweep_tests/test_all_sweep_tests.py | 45 +++++++++++++++++++ 2 files changed, 55 insertions(+), 4 deletions(-) create mode 100644 tests/ttnn/sweep_tests/test_all_sweep_tests.py diff --git a/tests/ttnn/sweep_tests/sweep.py b/tests/ttnn/sweep_tests/sweep.py index fa2e191a426..93141fecee0 100644 --- a/tests/ttnn/sweep_tests/sweep.py +++ b/tests/ttnn/sweep_tests/sweep.py @@ -53,11 +53,15 @@ def get_parameter_names(parameters): return [name] + get_parameter_names(other_parameters) +def preprocess_parameter_value(parameter_value): + if callable(parameter_value): + parameter_value = parameter_value.__name__ + return parameter_value + + def get_parameter_values(parameter_names, permutation): for parameter_name in parameter_names: - parameter_value = permutation[parameter_name] - if callable(parameter_value): - parameter_value = parameter_value.__name__ + parameter_value = preprocess_parameter_value(permutation[parameter_name]) yield parameter_value @@ -94,7 +98,9 @@ def run_single_test(test_name, index, *, device): sweep_module = SourceFileLoader(f"sweep_module_{file_name.stem}", str(file_name)).load_module() permutation = list(permutations(sweep_module.parameters))[index] - pretty_printed_parameters = ",\n".join(f"\t{key}={value}" for key, value in permutation.items()) + pretty_printed_parameters = ",\n".join( + f"\t{key}={preprocess_parameter_value(value)}" for key, value in permutation.items() + ) logger.info(f"Running sweep test at index {index}:\n{{{pretty_printed_parameters}}}") return _run_single_test( sweep_module.run, sweep_module.skip, sweep_module.is_expected_to_fail, permutation, device=device diff --git a/tests/ttnn/sweep_tests/test_all_sweep_tests.py b/tests/ttnn/sweep_tests/test_all_sweep_tests.py new file mode 100644 index 00000000000..8c542124e13 --- /dev/null +++ b/tests/ttnn/sweep_tests/test_all_sweep_tests.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from importlib.machinery import SourceFileLoader +from tests.ttnn.sweep_tests.sweep import ( + SWEEP_SOURCES_DIR, + permutations, + run_single_test, +) +from loguru import logger +from dataclasses import dataclass +import pytest +import os + + +@dataclass +class SweepTest: + file_name: str + sweep_test_index: int + + def __str__(self): + return f"{os.path.basename(self.file_name)}-{self.sweep_test_index}" + + +sweep_tests = [] +for file_name in sorted(SWEEP_SOURCES_DIR.glob("*.py")): + logger.info(f"Running {file_name}") + base_name = os.path.basename(file_name) + base_name = os.path.splitext(base_name)[0] + sweep_module = SourceFileLoader(f"sweep_module_{base_name}", str(file_name)).load_module() + base_name = base_name + ".csv" + for sweep_test_index, parameter_list in enumerate(permutations(sweep_module.parameters)): + sweep_tests.append(SweepTest(file_name, sweep_test_index)) + + +@pytest.mark.parametrize("sweep_test", sweep_tests, ids=str) +def test_all_sweeps(device, sweep_test): + status, message = run_single_test( + sweep_test.file_name, + sweep_test.sweep_test_index, + device=device, + ) + + assert status not in {"failed", "crashed"}, f"{message}"