From 5c71e5b8902d45f47713021c909be21207bd9940 Mon Sep 17 00:00:00 2001 From: tazlin Date: Fri, 6 Dec 2024 08:46:33 -0500 Subject: [PATCH] tests: flexible reqs/deps checking for more backends With the introduction of directml and the current landscape, its likely we may yet increase the number of supported driving technologies, so I've refactored the relevant tests to be more modular --- tests/conftest.py | 38 ++++++++++-------------- tests/test_horde_dep_updates.py | 51 +++++++++++++++++++++++---------- 2 files changed, 51 insertions(+), 38 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 69a1fe65..ff81a1cf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,6 +17,7 @@ def init_hordelib() -> None: PRECOMMIT_FILE_PATH = Path(__file__).parent.parent / ".pre-commit-config.yaml" REQUIREMENTS_FILE_PATH = Path(__file__).parent.parent / "requirements.txt" ROCM_REQUIREMENTS_FILE_PATH = Path(__file__).parent.parent / "requirements.rocm.txt" +DIRECTML_REQUIREMENTS_FILE_PATH = Path(__file__).parent.parent / "requirements.directml.txt" TRACKED_DEPENDENCIES = [ "horde_sdk", @@ -34,10 +35,9 @@ def tracked_dependencies() -> list[str]: return TRACKED_DEPENDENCIES -@pytest.fixture(scope="session") -def horde_dependency_versions() -> dict[str, str]: - """Get the versions of horde dependencies from the requirements file.""" - with open(REQUIREMENTS_FILE_PATH) as f: +def get_dependency_versions(requirements_file_path: str) -> dict[str, str]: + """Get the versions of horde dependencies from the given requirements file.""" + with open(requirements_file_path) as f: requirements = f.readlines() dependencies = {} @@ -60,27 +60,19 @@ def horde_dependency_versions() -> dict[str, str]: return dependencies +@pytest.fixture(scope="session") +def horde_dependency_versions() -> dict[str, str]: + """Get the versions of horde dependencies from the requirements file.""" + return get_dependency_versions(REQUIREMENTS_FILE_PATH) + + @pytest.fixture(scope="session") def rocm_horde_dependency_versions() -> dict[str, str]: """Get the versions of horde dependencies from the ROCm requirements file.""" - with open(ROCM_REQUIREMENTS_FILE_PATH) as f: - requirements = f.readlines() + return get_dependency_versions(ROCM_REQUIREMENTS_FILE_PATH) - dependencies = {} - for req in requirements: - for dep in TRACKED_DEPENDENCIES: - if req.startswith(dep): - if "==" in req: - version = req.split("==")[1].strip() - elif "~=" in req: - version = req.split("~=")[1].strip() - elif ">=" in req: - version = req.split(">=")[1].strip() - else: - raise ValueError(f"Unsupported version pin: {req}") - # Strip any info starting from the `+` character - version = version.split("+")[0] - dependencies[dep] = version - - return dependencies +@pytest.fixture(scope="session") +def directml_horde_dependency_versions() -> dict[str, str]: + """Get the versions of horde dependencies from the DirectML requirements file.""" + return get_dependency_versions(DIRECTML_REQUIREMENTS_FILE_PATH) diff --git a/tests/test_horde_dep_updates.py b/tests/test_horde_dep_updates.py index 0122a470..3f3f9279 100644 --- a/tests/test_horde_dep_updates.py +++ b/tests/test_horde_dep_updates.py @@ -43,33 +43,54 @@ def test_horde_update_runtime_updating(horde_dependency_versions: dict[str, str] assert found_line, "No initial torch install command found" -def test_different_requirements_files_match( - horde_dependency_versions: dict[str, str], - rocm_horde_dependency_versions: list[tuple[str, str]], +def check_dependency_versions( + main_deps: dict[str, str], + other_deps: dict[str, str], + other_name: str, ) -> None: - """Check that the versions of horde deps. in the main and rocm requirements files match.""" - rocm_deps = dict(rocm_horde_dependency_versions) + """Check that the main requirements file is consistent with the other requirements file. - for dep in horde_dependency_versions: + Args: + main_deps (dict[str, str]): The versions of the dependencies in the main requirements file. + other_deps (dict[str, str]): The versions of the dependencies in the other requirements file. + other_name (str): The name of the other requirements file. + + Raises: + AssertionError: If the versions of the dependencies are inconsistent. + """ + for dep in main_deps: if dep == "torch": logger.warning( - f"Skipping torch version check (main: {horde_dependency_versions[dep]}, rocm: {rocm_deps[dep]})", + f"Skipping torch version check (main: {main_deps[dep]}, {other_name}: {other_deps[dep]})", ) continue - assert dep in rocm_deps, f"Dependency {dep} not found in rocm requirements file" + assert dep in other_deps, f"Dependency {dep} not found in {other_name} requirements file" assert ( - horde_dependency_versions[dep] == rocm_deps[dep] - ), f"Dependency {dep} has different versions in main and rocm requirements files" + main_deps[dep] == other_deps[dep] + ), f"Dependency {dep} has different versions in main and {other_name} requirements files" - for dep in rocm_deps: + for dep in other_deps: if dep == "torch": logger.warning( - f"Skipping torch version check (main: {horde_dependency_versions[dep]}, rocm: {rocm_deps[dep]})", + f"Skipping torch version check (main: {main_deps[dep]}, {other_name}: {other_deps[dep]})", ) continue - assert dep in horde_dependency_versions, f"Dependency {dep} not found in main requirements file" + assert dep in main_deps, f"Dependency {dep} not found in main requirements file" assert ( - rocm_deps[dep] == horde_dependency_versions[dep] - ), f"Dependency {dep} has different versions in main and rocm requirements files" + other_deps[dep] == main_deps[dep] + ), f"Dependency {dep} has different versions in main and {other_name} requirements files" + + +def test_different_requirements_files_match( + horde_dependency_versions: dict[str, str], + rocm_horde_dependency_versions: list[tuple[str, str]], + directml_horde_dependency_versions: list[tuple[str, str]], +) -> None: + """Check that the versions of horde deps. in the all of the various requirements files are consistent.""" + rocm_deps = dict(rocm_horde_dependency_versions) + directml_deps = dict(directml_horde_dependency_versions) + + check_dependency_versions(horde_dependency_versions, rocm_deps, "rocm") + check_dependency_versions(horde_dependency_versions, directml_deps, "directml")