diff --git a/_pyinstaller_hooks_contrib/stdhooks/hook-triton.py b/_pyinstaller_hooks_contrib/stdhooks/hook-triton.py index f36315fc..a1e006cc 100644 --- a/_pyinstaller_hooks_contrib/stdhooks/hook-triton.py +++ b/_pyinstaller_hooks_contrib/stdhooks/hook-triton.py @@ -10,7 +10,10 @@ # SPDX-License-Identifier: GPL-2.0-or-later # --------------------------------------------------- -from PyInstaller.utils.hooks import collect_data_files, collect_dynamic_libs +from PyInstaller.utils.hooks import collect_data_files, collect_dynamic_libs, collect_submodules, is_module_satisfies + +hiddenimports = [] +datas = [] # Ensure that triton/_C/libtriton.so is collected binaries = collect_dynamic_libs('triton') @@ -26,7 +29,19 @@ 'triton.language': 'py', } -# Collect ptxas compiler files from triton/third_party/cuda directory. Strictly speaking, the ptxas executable from bin -# directory should be collected as a binary, but in this case, it makes no difference (plus, PyInstaller >= 6.0 has -# automatic binary-vs-data reclassification). -datas = collect_data_files('triton.third_party.cuda') +# triton 3.0.0 introduced `triton.backends` sub-package with backend-specific files. +if is_module_satisfies('triton >= 3.0.0'): + # Collect backend sub-modules/packages. + hiddenimports += collect_submodules('triton.backends') + + # At the time of writing (triton v3.1.0), `triton.backends.amd` is a namespace package, and is not captured by the + # above `collect_submodules` call. + hiddenimports += collect_submodules('triton.backends.amd') + + # Collect ptxas compiler files from `triton/backends/nvidia`, and the HIP/ROCm files from `triton/backends/amd`. + datas += collect_data_files('triton.backends') +else: + # Collect ptxas compiler files from triton/third_party/cuda directory. Strictly speaking, the ptxas executable from + # bin directory should be collected as a binary, but in this case, it makes no difference (plus, PyInstaller >= 6.0 + # has automatic binary-vs-data reclassification). + datas += collect_data_files('triton.third_party.cuda') diff --git a/news/848.update.rst b/news/848.update.rst new file mode 100644 index 00000000..ff9ce8d5 --- /dev/null +++ b/news/848.update.rst @@ -0,0 +1,3 @@ +Update ``triton`` hook for compatibility with ``triton`` >= 3.0.0; the +hook should now collect backend-specific modules and data files from +``triton.backends``. diff --git a/tests/test_deep_learning.py b/tests/test_deep_learning.py index e635a04a..8f35f8db 100644 --- a/tests/test_deep_learning.py +++ b/tests/test_deep_learning.py @@ -12,7 +12,8 @@ import pytest -from PyInstaller.utils.tests import importorskip +from PyInstaller import isolated +from PyInstaller.utils.tests import importorskip, requires # Run the tests in onedir mode only @@ -490,3 +491,34 @@ def test_sam2(pyi_builder): print(predictor) """) + + +# Check that backends are properly collected with triton >= 3.0.0 +@requires('triton >= 3.0.0') +def test_triton_backends(pyi_builder, tmp_path): + # Get the list of backends in unfrozen python + @isolated.decorate + def _get_triton_backends(): + import triton.backends + return sorted(list(triton.backends.backends.keys())) + + backends_unfrozen = _get_triton_backends() + print(f"Unfrozen backends: {backends_unfrozen}") + + # Obtain list of backends in frozen application. + output_file = tmp_path / "output.txt" + + pyi_builder.test_source(""" + import sys + import triton.backends + + with open(sys.argv[1], 'w') as fp: + for backend_name in triton.backends.backends.keys(): + print(f"{backend_name}", file=fp) + """, app_args=[str(output_file)]) + + with open(output_file, "r") as fp: + backends_frozen = sorted(line.strip() for line in fp) + print(f"Frozen backends: {backends_frozen}") + + assert backends_frozen == backends_unfrozen