Skip to content

Commit

Permalink
hooks: update triton hook for compatibility with triton >= 3.0.0
Browse files Browse the repository at this point in the history
Update `triton` hook to collect backend-specific modules and data
files from `triton.backends` sub-package that was introduced in
triton v3.0.0.
  • Loading branch information
rokm committed Jan 5, 2025
1 parent c3b5792 commit f5d77ef
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 6 deletions.
25 changes: 20 additions & 5 deletions _pyinstaller_hooks_contrib/stdhooks/hook-triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
# SPDX-License-Identifier: GPL-2.0-or-later
# ---------------------------------------------------

from PyInstaller.utils.hooks import collect_data_files, collect_dynamic_libs
from PyInstaller.utils.hooks import collect_data_files, collect_dynamic_libs, collect_submodules, is_module_satisfies

hiddenimports = []
datas = []

# Ensure that triton/_C/libtriton.so is collected
binaries = collect_dynamic_libs('triton')
Expand All @@ -26,7 +29,19 @@
'triton.language': 'py',
}

# Collect ptxas compiler files from triton/third_party/cuda directory. Strictly speaking, the ptxas executable from bin
# directory should be collected as a binary, but in this case, it makes no difference (plus, PyInstaller >= 6.0 has
# automatic binary-vs-data reclassification).
datas = collect_data_files('triton.third_party.cuda')
# triton 3.0.0 introduced `triton.backends` sub-package with backend-specific files.
if is_module_satisfies('triton >= 3.0.0'):
    # Collect backend sub-modules/packages.
    # NOTE(review): presumably the backends are discovered/loaded dynamically at run-time (hence invisible to
    # PyInstaller's static import analysis) and therefore need to be listed as hidden imports — TODO confirm.
    hiddenimports += collect_submodules('triton.backends')

    # At the time of writing (triton v3.1.0), `triton.backends.amd` is a namespace package, and is not captured by the
    # above `collect_submodules` call.
    hiddenimports += collect_submodules('triton.backends.amd')

    # Collect ptxas compiler files from `triton/backends/nvidia`, and the HIP/ROCm files from `triton/backends/amd`.
    datas += collect_data_files('triton.backends')
else:
    # Collect ptxas compiler files from triton/third_party/cuda directory. Strictly speaking, the ptxas executable from
    # bin directory should be collected as a binary, but in this case, it makes no difference (plus, PyInstaller >= 6.0
    # has automatic binary-vs-data reclassification).
    datas += collect_data_files('triton.third_party.cuda')
3 changes: 3 additions & 0 deletions news/848.update.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Update ``triton`` hook for compatibility with ``triton`` >= 3.0.0; the
hook should now collect backend-specific modules and data files from
``triton.backends``.
34 changes: 33 additions & 1 deletion tests/test_deep_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

import pytest

from PyInstaller.utils.tests import importorskip
from PyInstaller import isolated
from PyInstaller.utils.tests import importorskip, requires


# Run the tests in onedir mode only
Expand Down Expand Up @@ -490,3 +491,34 @@ def test_sam2(pyi_builder):
print(predictor)
""")


# Check that backends are properly collected with triton >= 3.0.0
@requires('triton >= 3.0.0')
def test_triton_backends(pyi_builder, tmp_path):
    """
    Verify that the frozen application sees the same set of `triton` backends as the unfrozen python environment.

    The list of backends is obtained twice: once via an isolated (unfrozen) python process, and once from the frozen
    test application, which writes the backend names to a file (one per line) whose path is passed as a command-line
    argument. The two sorted lists must match exactly.
    """
    # Get the list of backends in unfrozen python.
    @isolated.decorate
    def _get_triton_backends():
        import triton.backends
        # `sorted()` accepts any iterable; no intermediate `list()` needed.
        return sorted(triton.backends.backends.keys())

    backends_unfrozen = _get_triton_backends()
    print(f"Unfrozen backends: {backends_unfrozen}")

    # Obtain list of backends in frozen application; have it write the names to a file, one per line.
    output_file = tmp_path / "output.txt"

    pyi_builder.test_source("""
        import sys
        import triton.backends
        with open(sys.argv[1], 'w') as fp:
            for backend_name in triton.backends.backends.keys():
                print(f"{backend_name}", file=fp)
        """, app_args=[str(output_file)])

    # Read the names back and compare against the unfrozen reference list.
    with open(output_file, "r") as fp:
        backends_frozen = sorted(line.strip() for line in fp)
    print(f"Frozen backends: {backends_frozen}")

    assert backends_frozen == backends_unfrozen

0 comments on commit f5d77ef

Please sign in to comment.