From f5d77ef53ede77e4941c38bbab612cb2c293f3c3 Mon Sep 17 00:00:00 2001 From: Rok Mandeljc Date: Sun, 5 Jan 2025 20:08:55 +0100 Subject: [PATCH] hooks: update triton hook for compatibility with triton >= 3.0.0 Update `triton` hook to collect backend-specific modules and data files from `triton.backends` sub-package that was introduced in triton v3.0.0. --- .../stdhooks/hook-triton.py | 25 +++++++++++--- news/848.update.rst | 3 ++ tests/test_deep_learning.py | 34 ++++++++++++++++++- 3 files changed, 56 insertions(+), 6 deletions(-) create mode 100644 news/848.update.rst diff --git a/_pyinstaller_hooks_contrib/stdhooks/hook-triton.py b/_pyinstaller_hooks_contrib/stdhooks/hook-triton.py index f36315fcd..a1e006cc6 100644 --- a/_pyinstaller_hooks_contrib/stdhooks/hook-triton.py +++ b/_pyinstaller_hooks_contrib/stdhooks/hook-triton.py @@ -10,7 +10,10 @@ # SPDX-License-Identifier: GPL-2.0-or-later # --------------------------------------------------- -from PyInstaller.utils.hooks import collect_data_files, collect_dynamic_libs +from PyInstaller.utils.hooks import collect_data_files, collect_dynamic_libs, collect_submodules, is_module_satisfies + +hiddenimports = [] +datas = [] # Ensure that triton/_C/libtriton.so is collected binaries = collect_dynamic_libs('triton') @@ -26,7 +29,19 @@ 'triton.language': 'py', } -# Collect ptxas compiler files from triton/third_party/cuda directory. Strictly speaking, the ptxas executable from bin -# directory should be collected as a binary, but in this case, it makes no difference (plus, PyInstaller >= 6.0 has -# automatic binary-vs-data reclassification). -datas = collect_data_files('triton.third_party.cuda') +# triton 3.0.0 introduced `triton.backends` sub-package with backend-specific files. +if is_module_satisfies('triton >= 3.0.0'): + # Collect backend sub-modules/packages. + hiddenimports += collect_submodules('triton.backends') + + # At the time of writing (triton v3.1.0), `triton.backends.amd` is a namespace package, and is not captured by the + # above `collect_submodules` call. + hiddenimports += collect_submodules('triton.backends.amd') + + # Collect ptxas compiler files from `triton/backends/nvidia`, and the HIP/ROCm files from `triton/backends/amd`. + datas += collect_data_files('triton.backends') +else: + # Collect ptxas compiler files from triton/third_party/cuda directory. Strictly speaking, the ptxas executable from + # bin directory should be collected as a binary, but in this case, it makes no difference (plus, PyInstaller >= 6.0 + # has automatic binary-vs-data reclassification). + datas += collect_data_files('triton.third_party.cuda') diff --git a/news/848.update.rst b/news/848.update.rst new file mode 100644 index 000000000..ff9ce8d5b --- /dev/null +++ b/news/848.update.rst @@ -0,0 +1,3 @@ +Update ``triton`` hook for compatibility with ``triton`` >= 3.0.0; the +hook should now collect backend-specific modules and data files from +``triton.backends``. diff --git a/tests/test_deep_learning.py b/tests/test_deep_learning.py index e635a04a3..8f35f8db6 100644 --- a/tests/test_deep_learning.py +++ b/tests/test_deep_learning.py @@ -12,7 +12,8 @@ import pytest -from PyInstaller.utils.tests import importorskip +from PyInstaller import isolated +from PyInstaller.utils.tests import importorskip, requires # Run the tests in onedir mode only @@ -490,3 +491,34 @@ def test_sam2(pyi_builder): print(predictor) """) + + +# Check that backends are properly collected with triton >= 3.0.0 +@requires('triton >= 3.0.0') +def test_triton_backends(pyi_builder, tmp_path): + # Get the list of backends in unfrozen python + @isolated.decorate + def _get_triton_backends(): + import triton.backends + return sorted(list(triton.backends.backends.keys())) + + backends_unfrozen = _get_triton_backends() + print(f"Unfrozen backends: {backends_unfrozen}") + + # Obtain list of backends in frozen application. + output_file = tmp_path / "output.txt" + + pyi_builder.test_source(""" + import sys + import triton.backends + + with open(sys.argv[1], 'w') as fp: + for backend_name in triton.backends.backends.keys(): + print(f"{backend_name}", file=fp) + """, app_args=[str(output_file)]) + + with open(output_file, "r") as fp: + backends_frozen = sorted(line.strip() for line in fp) + print(f"Frozen backends: {backends_frozen}") + + assert backends_frozen == backends_unfrozen