diff --git a/download_models.py b/download_models.py index ec611b7b..8d4a1e48 100644 --- a/download_models.py +++ b/download_models.py @@ -17,6 +17,12 @@ default=False, help="Load the config only from environment variables. This is useful for running the worker in a container.", ) + parser.add_argument( + "--directml", + type=int, + default=None, + help="Enable directml and specify device to use.", + ) args = parser.parse_args() @@ -25,4 +31,5 @@ download_all_models( purge_unused_loras=args.purge_unused_loras, load_config_from_env_vars=args.load_config_from_env_vars, + directml=args.directml, ) diff --git a/horde-bridge-directml.cmd b/horde-bridge-directml.cmd new file mode 100644 index 00000000..7cb987f1 --- /dev/null +++ b/horde-bridge-directml.cmd @@ -0,0 +1,35 @@ +@echo off +cd /d %~dp0 + +: This first call to runtime activates the environment for the rest of the script +call runtime python -s -m pip -V + +call python -s -m pip uninstall hordelib +call python -s -m pip install horde_sdk~=0.16.4 horde_model_reference~=0.9.1 horde_engine~=2.18.1 horde_safety~=0.2.3 -U + +if %ERRORLEVEL% NEQ 0 ( + echo "Please run update-runtime.cmd." + GOTO END +) + +call python -s -m pip check +if %ERRORLEVEL% NEQ 0 ( + echo "Please run update-runtime.cmd." + GOTO END +) + +:DOWNLOAD +call python -s download_models.py --directml=0 +if %ERRORLEVEL% NEQ 0 GOTO ABORT +echo "Model Download OK. Starting worker..." +call python -s run_worker.py --directml=0 %* + +GOTO END + +:ABORT +echo "download_models.py exited with error code. Aborting" + +:END +call micromamba deactivate >nul +call deactivate >nul +pause diff --git a/horde_worker_regen/download_models.py b/horde_worker_regen/download_models.py index 51432149..ae32af94 100644 --- a/horde_worker_regen/download_models.py +++ b/horde_worker_regen/download_models.py @@ -5,6 +5,7 @@ def download_all_models( *, load_config_from_env_vars: bool = False, purge_unused_loras: bool = False, + directml: int | None = None, ) -> None: """Download all models specified in the config file.""" from horde_worker_regen.load_env_vars import load_env_vars_from_config @@ -55,7 +56,11 @@ def download_all_models( _ = get_interrogator_no_blip() del _ - hordelib.initialise() + extra_comfyui_args = [] + if directml is not None: + extra_comfyui_args.append(f"--directml={directml}") + + hordelib.initialise(extra_comfyui_args=extra_comfyui_args) from hordelib.shared_model_manager import SharedModelManager SharedModelManager.load_model_managers() diff --git a/horde_worker_regen/process_management/main_entry_point.py b/horde_worker_regen/process_management/main_entry_point.py index 04e539f5..47e14ded 100644 --- a/horde_worker_regen/process_management/main_entry_point.py +++ b/horde_worker_regen/process_management/main_entry_point.py @@ -12,6 +12,7 @@ def start_working( horde_model_reference_manager: ModelReferenceManager, *, amd_gpu: bool = False, + directml: int | None = None, ) -> None: """Create and start process manager.""" process_manager = HordeWorkerProcessManager( @@ -19,6 +20,7 @@ def start_working( bridge_data=bridge_data, horde_model_reference_manager=horde_model_reference_manager, amd_gpu=amd_gpu, + directml=directml, ) process_manager.start() diff --git a/horde_worker_regen/process_management/process_manager.py b/horde_worker_regen/process_management/process_manager.py index 0c0169f6..6c84fcb4 100644 --- a/horde_worker_regen/process_management/process_manager.py +++ b/horde_worker_regen/process_management/process_manager.py @@ -1076,6 +1076,9 @@ def num_total_processes(self) 
-> int: _amd_gpu: bool """Whether or not the GPU is an AMD GPU.""" + _directml: int | None + """ID of the potential directml device.""" + def __init__( self, *, @@ -1087,6 +1090,7 @@ def __init__( max_safety_processes: int = 1, max_download_processes: int = 1, amd_gpu: bool = False, + directml: int | None = None, ) -> None: """Initialise the process manager. @@ -1103,6 +1107,7 @@ def __init__( max_download_processes (int, optional): The maximum number of download processes that can run at once. \ Defaults to 1. amd_gpu (bool, optional): Whether or not the GPU is an AMD GPU. Defaults to False. + directml (int, optional): ID of the potential directml device. Defaults to None. """ self.session_start_time = time.time() @@ -1124,6 +1129,7 @@ def __init__( self._lru = LRUCache(self.max_inference_processes) self._amd_gpu = amd_gpu + self._directml = directml # If there is only one model to load and only one inference process, then we can only run one job at a time # and there is no point in having more than one inference process @@ -1374,6 +1380,7 @@ def start_safety_processes(self) -> None: kwargs={ "high_memory_mode": self.bridge_data.high_memory_mode, "amd_gpu": self._amd_gpu, + "directml": self._directml, }, ) @@ -1436,6 +1443,7 @@ def _start_inference_process(self, pid: int) -> HordeProcessInfo: "very_high_memory_mode": self.bridge_data.very_high_memory_mode, "high_memory_mode": self.bridge_data.high_memory_mode, "amd_gpu": self._amd_gpu, + "directml": self._directml, }, ) process.start() diff --git a/horde_worker_regen/process_management/worker_entry_points.py b/horde_worker_regen/process_management/worker_entry_points.py index c7e77941..751352f6 100644 --- a/horde_worker_regen/process_management/worker_entry_points.py +++ b/horde_worker_regen/process_management/worker_entry_points.py @@ -24,6 +24,7 @@ def start_inference_process( high_memory_mode: bool = False, very_high_memory_mode: bool = False, amd_gpu: bool = False, + directml: int | None = None, ) -> None: """Start an inference process. @@ -40,6 +41,8 @@ def start_inference_process( Defaults to False. amd_gpu (bool, optional): If true, the process will attempt to use AMD GPU-specific optimisations. Defaults to False. + directml (int | None, optional): If not None, the process will attempt to use DirectML \ + with the specified device """ with contextlib.nullcontext(): # contextlib.redirect_stdout(None), contextlib.redirect_stderr(None): logger.remove() @@ -64,6 +67,9 @@ def start_inference_process( if amd_gpu: extra_comfyui_args.append("--use-pytorch-cross-attention") + if directml is not None: + extra_comfyui_args.append(f"--directml={directml}") + models_not_to_force_load = ["flux"] if very_high_memory_mode: @@ -120,6 +126,7 @@ def start_safety_process( *, high_memory_mode: bool = False, amd_gpu: bool = False, + directml: int | None = None, ) -> None: """Start a safety process. @@ -132,6 +139,8 @@ def start_safety_process( high_memory_mode (bool, optional): If true, the process will attempt to use more memory. Defaults to False. amd_gpu (bool, optional): If true, the process will attempt to use AMD GPU-specific optimisations. Defaults to False. 
+ directml (int | None, optional): If not None, the process will attempt to use DirectML \ + with the specified device """ with contextlib.nullcontext(): # contextlib.redirect_stdout(), contextlib.redirect_stderr(): logger.remove() @@ -153,6 +162,9 @@ def start_safety_process( if amd_gpu: extra_comfyui_args.append("--use-pytorch-cross-attention") + if directml is not None: + extra_comfyui_args.append(f"--directml={directml}") + with logger.catch(reraise=True): hordelib.initialise( setup_logging=None, diff --git a/horde_worker_regen/run_worker.py b/horde_worker_regen/run_worker.py index 8f48af80..4525a022 100644 --- a/horde_worker_regen/run_worker.py +++ b/horde_worker_regen/run_worker.py @@ -18,7 +18,13 @@ from loguru import logger -def main(ctx: BaseContext, load_from_env_vars: bool = False, *, amd_gpu: bool = False) -> None: +def main( + ctx: BaseContext, + load_from_env_vars: bool = False, + *, + amd_gpu: bool = False, + directml: int | None = None, +) -> None: """Check for a valid config and start the driver ('main') process for the reGen worker.""" from horde_model_reference.model_reference_manager import ModelReferenceManager from pydantic import ValidationError @@ -98,6 +104,7 @@ def ensure_model_db_downloaded() -> ModelReferenceManager: bridge_data=bridge_data, horde_model_reference_manager=horde_model_reference_manager, amd_gpu=amd_gpu, + directml=directml, ) @@ -165,6 +172,12 @@ def init() -> None: default=None, help="Override the worker name from the config file, for running multiple workers on one machine", ) + parser.add_argument( + "--directml", + type=int, + default=None, + help="Enable directml and specify device to use.", + ) args = parser.parse_args() @@ -206,7 +219,12 @@ def init() -> None: # We only need to download the legacy DBs once, so we do it here instead of in the worker processes - main(multiprocessing.get_context("spawn"), args.load_config_from_env_vars, amd_gpu=args.amd) + main( + multiprocessing.get_context("spawn"), + args.load_config_from_env_vars, + amd_gpu=args.amd, + directml=args.directml, + ) if __name__ == "__main__": diff --git a/requirements.directml.txt b/requirements.directml.txt new file mode 100644 index 00000000..bd06bc1e --- /dev/null +++ b/requirements.directml.txt @@ -0,0 +1,24 @@ +torch-directml==0.2.5.dev240914 +qrcode==7.4.2 # >8 breaks horde-engine 2.18.1 via the qr code generation nodes + +certifi # Required for SSL cert resolution + +horde_sdk~=0.16.4 +horde_safety~=0.2.3 +horde_engine~=2.18.1 +horde_model_reference>=0.9.1 + +python-dotenv +ruamel.yaml +semver +wheel + +python-Levenshtein + +pydantic>=2.9.2 +typing_extensions +requests +StrEnum +loguru + +babel diff --git a/tests/conftest.py b/tests/conftest.py index 69a1fe65..9651bb0d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,6 +17,7 @@ def init_hordelib() -> None: PRECOMMIT_FILE_PATH = Path(__file__).parent.parent / ".pre-commit-config.yaml" REQUIREMENTS_FILE_PATH = Path(__file__).parent.parent / "requirements.txt" ROCM_REQUIREMENTS_FILE_PATH = Path(__file__).parent.parent / "requirements.rocm.txt" +DIRECTML_REQUIREMENTS_FILE_PATH = Path(__file__).parent.parent / "requirements.directml.txt" TRACKED_DEPENDENCIES = [ "horde_sdk", @@ -34,10 +35,11 @@ def tracked_dependencies() -> list[str]: return TRACKED_DEPENDENCIES -@pytest.fixture(scope="session") -def horde_dependency_versions() -> dict[str, str]: - """Get the versions of horde dependencies from the requirements file.""" - with open(REQUIREMENTS_FILE_PATH) as f: +def 
get_dependency_versions(requirements_file_path: str | Path) -> dict[str, str]: + """Get the versions of horde dependencies from the given requirements file.""" + requirements_file_path = Path(requirements_file_path) + + with open(requirements_file_path) as f: requirements = f.readlines() dependencies = {} @@ -60,27 +62,19 @@ def horde_dependency_versions() -> dict[str, str]: return dependencies +@pytest.fixture(scope="session") +def horde_dependency_versions() -> dict[str, str]: + """Get the versions of horde dependencies from the requirements file.""" + return get_dependency_versions(REQUIREMENTS_FILE_PATH) + + @pytest.fixture(scope="session") def rocm_horde_dependency_versions() -> dict[str, str]: """Get the versions of horde dependencies from the ROCm requirements file.""" - with open(ROCM_REQUIREMENTS_FILE_PATH) as f: - requirements = f.readlines() - - dependencies = {} - for req in requirements: - for dep in TRACKED_DEPENDENCIES: - if req.startswith(dep): - if "==" in req: - version = req.split("==")[1].strip() - elif "~=" in req: - version = req.split("~=")[1].strip() - elif ">=" in req: - version = req.split(">=")[1].strip() - else: - raise ValueError(f"Unsupported version pin: {req}") + return get_dependency_versions(ROCM_REQUIREMENTS_FILE_PATH) - # Strip any info starting from the `+` character - version = version.split("+")[0] - dependencies[dep] = version - return dependencies +@pytest.fixture(scope="session") +def directml_horde_dependency_versions() -> dict[str, str]: + """Get the versions of horde dependencies from the DirectML requirements file.""" + return get_dependency_versions(DIRECTML_REQUIREMENTS_FILE_PATH) diff --git a/tests/test_horde_dep_updates.py b/tests/test_horde_dep_updates.py index 0122a470..3f3f9279 100644 --- a/tests/test_horde_dep_updates.py +++ b/tests/test_horde_dep_updates.py @@ -43,33 +43,54 @@ def test_horde_update_runtime_updating(horde_dependency_versions: dict[str, str] assert found_line, "No initial torch install command found" -def test_different_requirements_files_match( - horde_dependency_versions: dict[str, str], - rocm_horde_dependency_versions: list[tuple[str, str]], +def check_dependency_versions( + main_deps: dict[str, str], + other_deps: dict[str, str], + other_name: str, ) -> None: - """Check that the versions of horde deps. in the main and rocm requirements files match.""" - rocm_deps = dict(rocm_horde_dependency_versions) + """Check that the main requirements file is consistent with the other requirements file. - for dep in horde_dependency_versions: + Args: + main_deps (dict[str, str]): The versions of the dependencies in the main requirements file. + other_deps (dict[str, str]): The versions of the dependencies in the other requirements file. + other_name (str): The name of the other requirements file. + + Raises: + AssertionError: If the versions of the dependencies are inconsistent. 
+    """
+    for dep in main_deps:
         if dep == "torch":
             logger.warning(
-                f"Skipping torch version check (main: {horde_dependency_versions[dep]}, rocm: {rocm_deps[dep]})",
+                f"Skipping torch version check (main: {main_deps[dep]}, {other_name}: {other_deps[dep]})",
             )
             continue
 
-        assert dep in rocm_deps, f"Dependency {dep} not found in rocm requirements file"
+        assert dep in other_deps, f"Dependency {dep} not found in {other_name} requirements file"
         assert (
-            horde_dependency_versions[dep] == rocm_deps[dep]
-        ), f"Dependency {dep} has different versions in main and rocm requirements files"
+            main_deps[dep] == other_deps[dep]
+        ), f"Dependency {dep} has different versions in main and {other_name} requirements files"
 
-    for dep in rocm_deps:
+    for dep in other_deps:
         if dep == "torch":
             logger.warning(
-                f"Skipping torch version check (main: {horde_dependency_versions[dep]}, rocm: {rocm_deps[dep]})",
+                f"Skipping torch version check (main: {main_deps[dep]}, {other_name}: {other_deps[dep]})",
             )
             continue
 
-        assert dep in horde_dependency_versions, f"Dependency {dep} not found in main requirements file"
+        assert dep in main_deps, f"Dependency {dep} not found in main requirements file"
         assert (
-            rocm_deps[dep] == horde_dependency_versions[dep]
-        ), f"Dependency {dep} has different versions in main and rocm requirements files"
+            other_deps[dep] == main_deps[dep]
+        ), f"Dependency {dep} has different versions in main and {other_name} requirements files"
+
+
+def test_different_requirements_files_match(
+    horde_dependency_versions: dict[str, str],
+    rocm_horde_dependency_versions: dict[str, str],
+    directml_horde_dependency_versions: dict[str, str],
+) -> None:
+    """Check that the versions of horde deps. in all of the various requirements files are consistent."""
+    rocm_deps = dict(rocm_horde_dependency_versions)
+    directml_deps = dict(directml_horde_dependency_versions)
+
+    check_dependency_versions(horde_dependency_versions, rocm_deps, "rocm")
+    check_dependency_versions(horde_dependency_versions, directml_deps, "directml")
diff --git a/tests/test_sdk_models.py b/tests/test_sdk_models.py
index acb6dce9..8d71deec 100644
--- a/tests/test_sdk_models.py
+++ b/tests/test_sdk_models.py
@@ -7,7 +7,7 @@ def test_skipped_status_handles_unknown_fields() -> None:
     # printed without error.
     skipped_status = ImageGenerateJobPopSkippedStatus(
         max_pixels=100,
-        testing_field=1,
+        testing_field=1,  # type: ignore
     )
 
     assert skipped_status.max_pixels == 100
diff --git a/update-runtime-directml.cmd b/update-runtime-directml.cmd
new file mode 100644
index 00000000..331b1bbd
--- /dev/null
+++ b/update-runtime-directml.cmd
@@ -0,0 +1,86 @@
+@echo off
+cd /d "%~dp0"
+
+SET MAMBA_ROOT_PREFIX=%~dp0conda
+echo %MAMBA_ROOT_PREFIX%
+
+if exist "%MAMBA_ROOT_PREFIX%\condabin\micromamba.bat" (
+  echo Deleting micromamba.exe as it's out of date
+  del micromamba.exe
+  if errorlevel 1 (
+    echo Error: Failed to delete micromamba.exe. Please delete it manually.
+    exit /b 1
+  )
+  echo Deleting the conda directory as it's out of date
+  rmdir /s /q conda
+  if errorlevel 1 (
+    echo Error: Failed to delete the conda directory. Please delete it manually.
+ exit /b 1 + ) +) + +:Check if micromamba is already installed +if exist micromamba.exe goto Isolation + curl.exe -L -o micromamba.exe https://github.com/mamba-org/micromamba-releases/releases/latest/download/micromamba-win-64 + + +:Isolation +SET CONDA_SHLVL= +SET PYTHONNOUSERSITE=1 +SET PYTHONPATH= +echo %MAMBA_ROOT_PREFIX% + + + +setlocal EnableDelayedExpansion +for %%a in (%*) do ( + if /I "%%a"=="--hordelib" ( + set hordelib=true + ) else ( + set hordelib= + ) + if /I "%%a"=="--scribe" ( + set scribe=true + ) else ( + set scribe= + ) +) +endlocal + +if defined scribe ( + SET CONDA_ENVIRONMENT_FILE=environment_scribe.yaml + +) else ( + SET CONDA_ENVIRONMENT_FILE=environment.rocm.yaml +) + +Reg add "HKLM\SYSTEM\CurrentControlSet\Control\FileSystem" /v "LongPathsEnabled" /t REG_DWORD /d "1" /f 2>nul +:We do this twice the first time to workaround a conda bug where pip is not installed correctly the first time - Henk +IF EXIST CONDA GOTO WORKAROUND_END +.\micromamba.exe create --no-shortcuts -r conda -n windows -f %CONDA_ENVIRONMENT_FILE% -y +:WORKAROUND_END +.\micromamba.exe create --no-shortcuts -r conda -n windows -f %CONDA_ENVIRONMENT_FILE% -y + +REM Check if hordelib argument is defined + +micromamba.exe shell hook -s cmd.exe %MAMBA_ROOT_PREFIX% -v +call "%MAMBA_ROOT_PREFIX%\condabin\mamba_hook.bat" +call "%MAMBA_ROOT_PREFIX%\condabin\mamba.bat" activate windows + +python -s -m pip install torch-directml torchvision==0.19.1 + +if defined hordelib ( + python -s -m pip uninstall -y hordelib horde_engine horde_model_reference + python -s -m pip install horde_engine horde_model_reference +) else ( + if defined scribe ( + python -s -m pip install -r requirements-scribe.txt + ) else ( + python -s -m pip install -r requirements.directml.txt + ) +) +call deactivate + +echo If there are no errors above everything should be correctly installed (If not, try deleting the folder /conda/envs/ and try again). + +pause
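Usage sketch for the DirectML path added above (assuming a Windows install, and that DirectML device 0 is the adapter to use, as hard-coded in horde-bridge-directml.cmd): run update-runtime-directml.cmd once to set up the environment, then start the worker via horde-bridge-directml.cmd, or call the entry points directly from the activated environment:

    python -s download_models.py --directml=0
    python -s run_worker.py --directml=0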