diff --git a/tests/conftest.py b/tests/conftest.py index 69a1fe65..ff81a1cf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,6 +17,7 @@ def init_hordelib() -> None: PRECOMMIT_FILE_PATH = Path(__file__).parent.parent / ".pre-commit-config.yaml" REQUIREMENTS_FILE_PATH = Path(__file__).parent.parent / "requirements.txt" ROCM_REQUIREMENTS_FILE_PATH = Path(__file__).parent.parent / "requirements.rocm.txt" +DIRECTML_REQUIREMENTS_FILE_PATH = Path(__file__).parent.parent / "requirements.directml.txt" TRACKED_DEPENDENCIES = [ "horde_sdk", @@ -34,10 +35,9 @@ def tracked_dependencies() -> list[str]: return TRACKED_DEPENDENCIES -@pytest.fixture(scope="session") -def horde_dependency_versions() -> dict[str, str]: - """Get the versions of horde dependencies from the requirements file.""" - with open(REQUIREMENTS_FILE_PATH) as f: +def get_dependency_versions(requirements_file_path: str) -> dict[str, str]: + """Get the versions of horde dependencies from the given requirements file.""" + with open(requirements_file_path) as f: requirements = f.readlines() dependencies = {} @@ -60,27 +60,19 @@ def horde_dependency_versions() -> dict[str, str]: return dependencies +@pytest.fixture(scope="session") +def horde_dependency_versions() -> dict[str, str]: + """Get the versions of horde dependencies from the requirements file.""" + return get_dependency_versions(REQUIREMENTS_FILE_PATH) + + @pytest.fixture(scope="session") def rocm_horde_dependency_versions() -> dict[str, str]: """Get the versions of horde dependencies from the ROCm requirements file.""" - with open(ROCM_REQUIREMENTS_FILE_PATH) as f: - requirements = f.readlines() + return get_dependency_versions(ROCM_REQUIREMENTS_FILE_PATH) - dependencies = {} - for req in requirements: - for dep in TRACKED_DEPENDENCIES: - if req.startswith(dep): - if "==" in req: - version = req.split("==")[1].strip() - elif "~=" in req: - version = req.split("~=")[1].strip() - elif ">=" in req: - version = req.split(">=")[1].strip() - else: - raise ValueError(f"Unsupported version pin: {req}") - # Strip any info starting from the `+` character - version = version.split("+")[0] - dependencies[dep] = version - - return dependencies +@pytest.fixture(scope="session") +def directml_horde_dependency_versions() -> dict[str, str]: + """Get the versions of horde dependencies from the DirectML requirements file.""" + return get_dependency_versions(DIRECTML_REQUIREMENTS_FILE_PATH) diff --git a/tests/test_horde_dep_updates.py b/tests/test_horde_dep_updates.py index 0122a470..3f3f9279 100644 --- a/tests/test_horde_dep_updates.py +++ b/tests/test_horde_dep_updates.py @@ -43,33 +43,54 @@ def test_horde_update_runtime_updating(horde_dependency_versions: dict[str, str] assert found_line, "No initial torch install command found" -def test_different_requirements_files_match( - horde_dependency_versions: dict[str, str], - rocm_horde_dependency_versions: list[tuple[str, str]], +def check_dependency_versions( + main_deps: dict[str, str], + other_deps: dict[str, str], + other_name: str, ) -> None: - """Check that the versions of horde deps. in the main and rocm requirements files match.""" - rocm_deps = dict(rocm_horde_dependency_versions) + """Check that the main requirements file is consistent with the other requirements file. - for dep in horde_dependency_versions: + Args: + main_deps (dict[str, str]): The versions of the dependencies in the main requirements file. + other_deps (dict[str, str]): The versions of the dependencies in the other requirements file. + other_name (str): The name of the other requirements file. + + Raises: + AssertionError: If the versions of the dependencies are inconsistent. + """ + for dep in main_deps: if dep == "torch": logger.warning( - f"Skipping torch version check (main: {horde_dependency_versions[dep]}, rocm: {rocm_deps[dep]})", + f"Skipping torch version check (main: {main_deps[dep]}, {other_name}: {other_deps[dep]})", ) continue - assert dep in rocm_deps, f"Dependency {dep} not found in rocm requirements file" + assert dep in other_deps, f"Dependency {dep} not found in {other_name} requirements file" assert ( - horde_dependency_versions[dep] == rocm_deps[dep] - ), f"Dependency {dep} has different versions in main and rocm requirements files" + main_deps[dep] == other_deps[dep] + ), f"Dependency {dep} has different versions in main and {other_name} requirements files" - for dep in rocm_deps: + for dep in other_deps: if dep == "torch": logger.warning( - f"Skipping torch version check (main: {horde_dependency_versions[dep]}, rocm: {rocm_deps[dep]})", + f"Skipping torch version check (main: {main_deps[dep]}, {other_name}: {other_deps[dep]})", ) continue - assert dep in horde_dependency_versions, f"Dependency {dep} not found in main requirements file" + assert dep in main_deps, f"Dependency {dep} not found in main requirements file" assert ( - rocm_deps[dep] == horde_dependency_versions[dep] - ), f"Dependency {dep} has different versions in main and rocm requirements files" + other_deps[dep] == main_deps[dep] + ), f"Dependency {dep} has different versions in main and {other_name} requirements files" + + +def test_different_requirements_files_match( + horde_dependency_versions: dict[str, str], + rocm_horde_dependency_versions: list[tuple[str, str]], + directml_horde_dependency_versions: list[tuple[str, str]], +) -> None: + """Check that the versions of horde deps. in the all of the various requirements files are consistent.""" + rocm_deps = dict(rocm_horde_dependency_versions) + directml_deps = dict(directml_horde_dependency_versions) + + check_dependency_versions(horde_dependency_versions, rocm_deps, "rocm") + check_dependency_versions(horde_dependency_versions, directml_deps, "directml")