From 6f7b35d09169a37fd7324e6c5f4fcddd7f1adcbb Mon Sep 17 00:00:00 2001 From: Thomas Roeblitz Date: Tue, 11 Jun 2024 05:36:27 +0200 Subject: [PATCH] set TORCHVISION_* env vars --- eb_hooks.py | 40 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/eb_hooks.py b/eb_hooks.py index 49d2c4ce29..530b362123 100644 --- a/eb_hooks.py +++ b/eb_hooks.py @@ -299,6 +299,44 @@ def parse_hook_pybind11_replace_catch2(ec, eprefix): build_deps[idx] = (catch2_name, catch2_version) +def parse_hook_pytorch_bundle_torchvision_setenv(ec, eprefix): + """ + Set TORCHVISION_{INCLUDE,LIBRARY}, initially for non-CUDA version only + """ + if ec.name == 'PyTorch-bundle' and ec.version in ['2.1.2']: + if not hasattr(ec, 'versionsuffix') or (ec.versionsuffix and not 'CUDA' in ec.versionsuffix): + print_msg("parse_hook for PyTorch-bundle without CUDA: extslist '%s'", ec['exts_list']) + print_msg("New exts_list: '%s'", ec['exts_list']) + ec_dict = ec.asdict() + libpng_root = os.getenv('EBROOTLIBPNG') + libpng_include = os.path.join(libpng_root, 'include') + libpng_lib = os.path.join(libpng_root, 'lib') + libjpeg_turbo_root = os.getenv('EBROOTLIBJPEGMINTURBO') + libjpeg_turbo_include = os.path.join(libjpeg_turbo_root, 'include') + libjpeg_turbo_lib = os.path.join(libjpeg_turbo_root, 'lib') + exts_list_new = [] + torchvision_include = 'export TORCHVISION_INCLUDE=%s:%s' % (libpng_include, libjpeg_turbo_include) + torchvision_library = 'export TORCHVISION_LIBRARY=%s:%s' % (libpng_lib, libjpeg_turbo_lib) + for item in ec_dict['exts_list']: + if item[0] != 'torchvision': + exts_list_new.append(item) + else: + ext_dict = item[2] + if 'preinstallopts' in ext_dict: + raise EasyBuildError("found value for 'preinstallopts' for extension 'torchvision'," + " but expected NONE") + else: + # add preinstallopts + ext_dict['preinstallopts'] = torchvision_include + '; ' + torchvision_library + exts_list_new.append((item[0], item[1], ext_dict)) + ec['exts_list'] = exts_list_new + print_msg("New exts_list: '%s'", ec['exts_list']) + else: + print_msg("parse_hook for PyTorch-bundle for CUDA -> leaving preinstallopts unchanged") + else: + raise EasyBuildError("PyTorch-bundle-specific hook triggered for non-PyTorch-bundle easyconfig?!") + + def parse_hook_qt5_check_qtwebengine_disable(ec, eprefix): """ Disable check for QtWebEngine in Qt5 as workaround for problem with determining glibc version. @@ -312,7 +350,6 @@ def parse_hook_qt5_check_qtwebengine_disable(ec, eprefix): raise EasyBuildError("Qt5-specific hook triggered for non-Qt5 easyconfig?!") - def parse_hook_sentencepiece_disable_tcmalloc_aarch64(ec, eprefix): """ Disable using TC_Malloc on 'aarch64/generic' @@ -943,6 +980,7 @@ def pre_module_hook_librosa_augment_modluafooter(self, *args, **kwargs): 'OpenBLAS': parse_hook_openblas_relax_lapack_tests_num_errors, 'Pillow-SIMD' : parse_hook_Pillow_SIMD_harcoded_paths, 'pybind11': parse_hook_pybind11_replace_catch2, + 'PyTorch-bundle': parse_hook_pytorch_bundle_torchvision_setenv, 'Qt5': parse_hook_qt5_check_qtwebengine_disable, 'SentencePiece': parse_hook_sentencepiece_disable_tcmalloc_aarch64, 'UCX': parse_hook_ucx_eprefix,