Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable directml #356

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
default=False,
help="Load the config only from environment variables. This is useful for running the worker in a container.",
)
# CLI flag mirroring run_worker.py's --directml option: an integer DirectML
# device ID (None = DirectML disabled). The value is forwarded to
# download_all_models(directml=...), which passes it to ComfyUI as
# f"--directml={directml}".
parser.add_argument(
"--directml",
type=int,
default=None,
help="Enable directml and specify device to use.",
)

args = parser.parse_args()

Expand All @@ -25,4 +31,5 @@
download_all_models(
purge_unused_loras=args.purge_unused_loras,
load_config_from_env_vars=args.load_config_from_env_vars,
directml=args.directml,
)
35 changes: 35 additions & 0 deletions horde-bridge-directml.cmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
@echo off
REM horde-bridge-directml.cmd - launch the reGen worker with DirectML enabled.
REM Installs/updates the pinned horde dependencies, downloads models, then
REM starts the worker; both python entry points receive --directml=0
REM (NOTE(review): device ID is hard-coded to 0 here - confirm that is intended
REM for multi-GPU machines).
cd /d %~dp0

: This first call to runtime activates the environment for the rest of the script
REM NOTE(review): the line above uses a single ':' label as a comment; '::' or
REM 'REM' is the conventional batch comment form - confirm this matches the
REM repo's other .cmd scripts.
call runtime python -s -m pip -V

REM NOTE(review): 'pip uninstall' without -y prompts interactively - verify
REM this is deliberate before the pinned reinstall below.
call python -s -m pip uninstall hordelib
call python -s -m pip install horde_sdk~=0.16.4 horde_model_reference~=0.9.1 horde_engine~=2.18.1 horde_safety~=0.2.3 -U

if %ERRORLEVEL% NEQ 0 (
echo "Please run update-runtime.cmd."
GOTO END
)

REM Verify the installed dependency set is self-consistent before starting.
call python -s -m pip check
if %ERRORLEVEL% NEQ 0 (
echo "Please run update-runtime.cmd."
GOTO END
)

:DOWNLOAD
REM Download all configured models; abort the launch if this fails.
call python -s download_models.py --directml=0
if %ERRORLEVEL% NEQ 0 GOTO ABORT
echo "Model Download OK. Starting worker..."
REM %* forwards any extra command-line arguments to the worker process.
call python -s run_worker.py --directml=0 %*

GOTO END

:ABORT
echo "download_models.py exited with error code. Aborting"

:END
REM Best-effort environment teardown; output suppressed because only one of
REM these deactivate commands will exist in any given setup.
call micromamba deactivate >nul
call deactivate >nul
pause
7 changes: 6 additions & 1 deletion horde_worker_regen/download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ def download_all_models(
*,
load_config_from_env_vars: bool = False,
purge_unused_loras: bool = False,
directml: int | None = None,
) -> None:
"""Download all models specified in the config file."""
from horde_worker_regen.load_env_vars import load_env_vars_from_config
Expand Down Expand Up @@ -55,7 +56,11 @@ def download_all_models(
_ = get_interrogator_no_blip()
del _

hordelib.initialise()
extra_comfyui_args = []
if directml is not None:
extra_comfyui_args.append(f"--directml={directml}")

hordelib.initialise(extra_comfyui_args=extra_comfyui_args)
from hordelib.shared_model_manager import SharedModelManager

SharedModelManager.load_model_managers()
Expand Down
2 changes: 2 additions & 0 deletions horde_worker_regen/process_management/main_entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@ def start_working(
horde_model_reference_manager: ModelReferenceManager,
*,
amd_gpu: bool = False,
directml: int | None = None,
) -> None:
"""Create and start process manager."""
process_manager = HordeWorkerProcessManager(
ctx=ctx,
bridge_data=bridge_data,
horde_model_reference_manager=horde_model_reference_manager,
amd_gpu=amd_gpu,
directml=directml,
)

process_manager.start()
8 changes: 8 additions & 0 deletions horde_worker_regen/process_management/process_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,9 @@ def num_total_processes(self) -> int:
_amd_gpu: bool
"""Whether or not the GPU is an AMD GPU."""

_directml: int | None
"""ID of the potential directml device."""

def __init__(
self,
*,
Expand All @@ -1087,6 +1090,7 @@ def __init__(
max_safety_processes: int = 1,
max_download_processes: int = 1,
amd_gpu: bool = False,
directml: int | None = None,
) -> None:
"""Initialise the process manager.

Expand All @@ -1103,6 +1107,7 @@ def __init__(
max_download_processes (int, optional): The maximum number of download processes that can run at once. \
Defaults to 1.
amd_gpu (bool, optional): Whether or not the GPU is an AMD GPU. Defaults to False.
directml (int, optional): ID of the potential directml device. Defaults to None.
"""
self.session_start_time = time.time()

Expand All @@ -1124,6 +1129,7 @@ def __init__(
self._lru = LRUCache(self.max_inference_processes)

self._amd_gpu = amd_gpu
self._directml = directml

# If there is only one model to load and only one inference process, then we can only run one job at a time
# and there is no point in having more than one inference process
Expand Down Expand Up @@ -1374,6 +1380,7 @@ def start_safety_processes(self) -> None:
kwargs={
"high_memory_mode": self.bridge_data.high_memory_mode,
"amd_gpu": self._amd_gpu,
"directml": self._directml,
},
)

Expand Down Expand Up @@ -1436,6 +1443,7 @@ def _start_inference_process(self, pid: int) -> HordeProcessInfo:
"very_high_memory_mode": self.bridge_data.very_high_memory_mode,
"high_memory_mode": self.bridge_data.high_memory_mode,
"amd_gpu": self._amd_gpu,
"directml": self._directml,
},
)
process.start()
Expand Down
12 changes: 12 additions & 0 deletions horde_worker_regen/process_management/worker_entry_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def start_inference_process(
high_memory_mode: bool = False,
very_high_memory_mode: bool = False,
amd_gpu: bool = False,
directml: int | None = None,
) -> None:
"""Start an inference process.

Expand All @@ -40,6 +41,8 @@ def start_inference_process(
Defaults to False.
amd_gpu (bool, optional): If true, the process will attempt to use AMD GPU-specific optimisations.
Defaults to False.
directml (int | None, optional): If not None, the process will attempt to use DirectML \
with the specified device
"""
with contextlib.nullcontext(): # contextlib.redirect_stdout(None), contextlib.redirect_stderr(None):
logger.remove()
Expand All @@ -64,6 +67,9 @@ def start_inference_process(
if amd_gpu:
extra_comfyui_args.append("--use-pytorch-cross-attention")

if directml is not None:
extra_comfyui_args.append(f"--directml={directml}")

models_not_to_force_load = ["flux"]

if very_high_memory_mode:
Expand Down Expand Up @@ -120,6 +126,7 @@ def start_safety_process(
*,
high_memory_mode: bool = False,
amd_gpu: bool = False,
directml: int | None = None,
) -> None:
"""Start a safety process.

Expand All @@ -132,6 +139,8 @@ def start_safety_process(
high_memory_mode (bool, optional): If true, the process will attempt to use more memory. Defaults to False.
amd_gpu (bool, optional): If true, the process will attempt to use AMD GPU-specific optimisations.
Defaults to False.
directml (int | None, optional): If not None, the process will attempt to use DirectML \
with the specified device
"""
with contextlib.nullcontext(): # contextlib.redirect_stdout(), contextlib.redirect_stderr():
logger.remove()
Expand All @@ -153,6 +162,9 @@ def start_safety_process(
if amd_gpu:
extra_comfyui_args.append("--use-pytorch-cross-attention")

if directml is not None:
extra_comfyui_args.append(f"--directml={directml}")

with logger.catch(reraise=True):
hordelib.initialise(
setup_logging=None,
Expand Down
22 changes: 20 additions & 2 deletions horde_worker_regen/run_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@
from loguru import logger


def main(ctx: BaseContext, load_from_env_vars: bool = False, *, amd_gpu: bool = False) -> None:
def main(
ctx: BaseContext,
load_from_env_vars: bool = False,
*,
amd_gpu: bool = False,
directml: int | None = None,
) -> None:
"""Check for a valid config and start the driver ('main') process for the reGen worker."""
from horde_model_reference.model_reference_manager import ModelReferenceManager
from pydantic import ValidationError
Expand Down Expand Up @@ -98,6 +104,7 @@ def ensure_model_db_downloaded() -> ModelReferenceManager:
bridge_data=bridge_data,
horde_model_reference_manager=horde_model_reference_manager,
amd_gpu=amd_gpu,
directml=directml,
)


Expand Down Expand Up @@ -165,6 +172,12 @@ def init() -> None:
default=None,
help="Override the worker name from the config file, for running multiple workers on one machine",
)
parser.add_argument(
"--directml",
type=int,
default=None,
help="Enable directml and specify device to use.",
)

args = parser.parse_args()

Expand Down Expand Up @@ -206,7 +219,12 @@ def init() -> None:

# We only need to download the legacy DBs once, so we do it here instead of in the worker processes

main(multiprocessing.get_context("spawn"), args.load_config_from_env_vars, amd_gpu=args.amd)
main(
multiprocessing.get_context("spawn"),
args.load_config_from_env_vars,
amd_gpu=args.amd,
directml=args.directml,
)


if __name__ == "__main__":
Expand Down
24 changes: 24 additions & 0 deletions requirements.directml.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
torch-directml==0.2.5.dev240914
qrcode==7.4.2 # >8 breaks horde-engine 2.18.1 via the qr code generation nodes

certifi # Required for SSL cert resolution

horde_sdk~=0.16.4
horde_safety~=0.2.3
horde_engine~=2.18.1
horde_model_reference>=0.9.1

python-dotenv
ruamel.yaml
semver
wheel

python-Levenshtein

pydantic>=2.9.2
typing_extensions
requests
StrEnum
loguru

babel
40 changes: 17 additions & 23 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def init_hordelib() -> None:
PRECOMMIT_FILE_PATH = Path(__file__).parent.parent / ".pre-commit-config.yaml"
REQUIREMENTS_FILE_PATH = Path(__file__).parent.parent / "requirements.txt"
ROCM_REQUIREMENTS_FILE_PATH = Path(__file__).parent.parent / "requirements.rocm.txt"
DIRECTML_REQUIREMENTS_FILE_PATH = Path(__file__).parent.parent / "requirements.directml.txt"

TRACKED_DEPENDENCIES = [
"horde_sdk",
Expand All @@ -34,10 +35,11 @@ def tracked_dependencies() -> list[str]:
return TRACKED_DEPENDENCIES


@pytest.fixture(scope="session")
def horde_dependency_versions() -> dict[str, str]:
"""Get the versions of horde dependencies from the requirements file."""
with open(REQUIREMENTS_FILE_PATH) as f:
def get_dependency_versions(requirements_file_path: str | Path) -> dict[str, str]:
"""Get the versions of horde dependencies from the given requirements file."""
requirements_file_path = Path(requirements_file_path)

with open(requirements_file_path) as f:
requirements = f.readlines()

dependencies = {}
Expand All @@ -60,27 +62,19 @@ def horde_dependency_versions() -> dict[str, str]:
return dependencies


@pytest.fixture(scope="session")
def horde_dependency_versions() -> dict[str, str]:
    """Session-scoped fixture mapping tracked horde dependency names to their pinned versions, parsed from the main requirements file."""
    requirements_path = REQUIREMENTS_FILE_PATH
    return get_dependency_versions(requirements_path)


@pytest.fixture(scope="session")
def rocm_horde_dependency_versions() -> dict[str, str]:
"""Get the versions of horde dependencies from the ROCm requirements file."""
with open(ROCM_REQUIREMENTS_FILE_PATH) as f:
requirements = f.readlines()

dependencies = {}
for req in requirements:
for dep in TRACKED_DEPENDENCIES:
if req.startswith(dep):
if "==" in req:
version = req.split("==")[1].strip()
elif "~=" in req:
version = req.split("~=")[1].strip()
elif ">=" in req:
version = req.split(">=")[1].strip()
else:
raise ValueError(f"Unsupported version pin: {req}")
return get_dependency_versions(ROCM_REQUIREMENTS_FILE_PATH)

# Strip any info starting from the `+` character
version = version.split("+")[0]
dependencies[dep] = version

return dependencies
@pytest.fixture(scope="session")
def directml_horde_dependency_versions() -> dict[str, str]:
    """Session-scoped fixture mapping tracked horde dependency names to their pinned versions, parsed from the DirectML requirements file."""
    directml_requirements_path = DIRECTML_REQUIREMENTS_FILE_PATH
    return get_dependency_versions(directml_requirements_path)
Loading
Loading