From 5ec43b8a2802008abd8ca1581124d11a857ba55f Mon Sep 17 00:00:00 2001 From: Rok Mandeljc Date: Thu, 21 Dec 2023 12:09:47 +0100 Subject: [PATCH] hooks: tensorflow: collect plugins from tensorflow-plugins Have the `tensorflow` standard hook collect binaries from the `tensorflow-plugins` package; this contains plugins for tensorflow's pluggable device architecture (such as `tensorflow-metal` for macOS and `tensorflow-directml-plugin` for Windows). Have the `tensorflow` run-time hook override the `site.getsitepackages()` with custom implementation that allows us to trick `tensorflow` into loading the plugins. --- news/676.update.7.rst | 5 ++ .../hooks/rthooks/pyi_rth_tensorflow.py | 50 ++++++++++++++++--- .../hooks/stdhooks/hook-tensorflow.py | 7 +++ 3 files changed, 54 insertions(+), 8 deletions(-) create mode 100644 news/676.update.7.rst diff --git a/news/676.update.7.rst b/news/676.update.7.rst new file mode 100644 index 000000000..9b6b456d4 --- /dev/null +++ b/news/676.update.7.rst @@ -0,0 +1,5 @@ +Extend ``tensorflow`` hook to collect plugins installed in the +``tensorflow-plugins`` directory/package. Have the run-time ``tensorflow`` +hook provide an override for ``site.getsitepackages()`` that allows us +to work around a broken module file location check and trick ``tensorflow`` +into loading the collected plugins. diff --git a/src/_pyinstaller_hooks_contrib/hooks/rthooks/pyi_rth_tensorflow.py b/src/_pyinstaller_hooks_contrib/hooks/rthooks/pyi_rth_tensorflow.py index c967f4895..5941013e7 100644 --- a/src/_pyinstaller_hooks_contrib/hooks/rthooks/pyi_rth_tensorflow.py +++ b/src/_pyinstaller_hooks_contrib/hooks/rthooks/pyi_rth_tensorflow.py @@ -9,11 +9,45 @@ # SPDX-License-Identifier: Apache-2.0 #----------------------------------------------------------------------------- -# `tensorflow` versions prior to 2.3.0 attempt to use `site.USER_SITE` in path/string manipulation functions. -# As frozen application runs with disabled `site`, the value of this variable is `None`, and causes path/string -# manipulation functions to raise an error. As a work-around, we set `site.USER_SITE` to an empty string, which is -# also what the fake `site` module available in PyInstaller prior to v5.5 did. -import site - -if site.USER_SITE is None: - site.USER_SITE = '' +def _pyi_rthook(): + import sys + + # `tensorflow` versions prior to 2.3.0 attempt to use `site.USER_SITE` in path/string manipulation functions. + # As frozen application runs with disabled `site`, the value of this variable is `None`, and causes path/string + # manipulation functions to raise an error. As a work-around, we set `site.USER_SITE` to an empty string, which is + # also what the fake `site` module available in PyInstaller prior to v5.5 did. + import site + + if site.USER_SITE is None: + site.USER_SITE = '' + + # The issue described about with site.USER_SITE being None has largely been resolved in contemporary `tensorflow` + # versions, which now check that `site.ENABLE_USER_SITE` is set and that `site.USER_SITE` is not None before + # trying to use it. + # + # However, `tensorflow` will attempt to search and load its plugins only if it believes that it is running from + # "a pip-based installation" - if the package's location is rooted in one of the "site-packages" directories. See + # https://github.com/tensorflow/tensorflow/blob/6887368d6d46223f460358323c4b76d61d1558a8/tensorflow/api_template.__init__.py#L110C76-L156 + # Unfortunately, they "cleverly" infer the module's location via `inspect.getfile(inspect.currentframe())`, which + # in the frozen application returns anonymized relative source file name (`tensorflow/__init__.py`) - so we need one + # of the "site directories" to be just "tensorflow" (to fool the `_running_from_pip_package()` check), and we also + # need `sys._MEIPASS` to be among them (to load the plugins from the actual `sys._MEIPASS/tensorflow-plugins`). + # Therefore, we monkey-patch `site.getsitepackages` to add those two entries to the list of "site directories". + + _orig_getsitepackages = getattr(site, 'getsitepackages') + + def _pyi_getsitepackages(): + return [ + sys._MEIPASS, + "tensorflow", + *(_orig_getsitepackages() if _orig_getsitepackages is not None else []), + ] + + site.getsitepackages = _pyi_getsitepackages + + # NOTE: instead of the above override, we could also set TF_PLUGGABLE_DEVICE_LIBRARY_PATH, but that works only + # for tensorflow >= 2.12. + + +_pyi_rthook() +del _pyi_rthook diff --git a/src/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-tensorflow.py b/src/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-tensorflow.py index 97dfe6b5f..64da3b94b 100644 --- a/src/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-tensorflow.py +++ b/src/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-tensorflow.py @@ -16,6 +16,7 @@ from PyInstaller.compat import is_linux from PyInstaller.utils.hooks import ( collect_data_files, + collect_dynamic_libs, collect_submodules, get_module_attribute, is_module_satisfies, @@ -125,6 +126,7 @@ def _submodules_filter(x): if version >= Version("2.14.0"): hiddenimports += ['ml_dtypes'] +binaries = [] excludedimports = excluded_submodules # Suppress warnings for missing hidden imports generated by this hook. @@ -165,3 +167,8 @@ def _infer_nvidia_hiddenimports(): nvidia_hiddenimports = [] logger.info("hook-tensorflow: inferred hidden imports for CUDA libraries: %r", nvidia_hiddenimports) hiddenimports += nvidia_hiddenimports + + +# Collect the tensorflow-plugins (pluggable device plugins) +hiddenimports += ['tensorflow-plugins'] +binaries += collect_dynamic_libs('tensorflow-plugins')