diff --git a/matsciml/__init__.py b/matsciml/__init__.py
index 045deb6e..38d2685b 100644
--- a/matsciml/__init__.py
+++ b/matsciml/__init__.py
@@ -3,22 +3,8 @@
 from logging import getLogger
 
 # determine if intel libraries are available
-from matsciml.common.packages import package_registry
+from matsciml.common.packages import package_registry  # noqa: F401
 
 __version__ = "1.1.0"
 
 logger = getLogger(__file__)
-
-
-if package_registry["ipex"]:
-    try:
-        import intel_extension_for_pytorch  # noqa: F401
-    except ImportError as e:
-        logger.warning(f"Unable to load IPEX because of {e} - XPU may not function.")
-if package_registry["ccl"]:
-    try:
-        import oneccl_bindings_for_pytorch  # noqa: F401
-    except ImportError as e:
-        logger.warning(
-            f"Unable to load CCL bindings because of {e} - DDP XPU may not function."
-        )
diff --git a/matsciml/common/packages.py b/matsciml/common/packages.py
index e2c3ae95..a654e526 100644
--- a/matsciml/common/packages.py
+++ b/matsciml/common/packages.py
@@ -13,25 +13,26 @@
 been installed.
 """
 package_registry = {}
-package_registry["ipex"] = (
-    True if util.find_spec("intel_extension_for_pytorch") else False
-)
-package_registry["ccl"] = (
-    True if util.find_spec("oneccl_bindings_for_pytorch") else False
-)
 
 # graph specific packages; slightly more involved because we should try import
-for package in ["torch_geometric", "torch_scatter", "torch_sparse", "dgl"]:
+for package in [
+    "torch_geometric",
+    "torch_scatter",
+    "torch_sparse",
+    "dgl",
+    "intel_extension_for_pytorch",
+    "oneccl_bindings_for_pytorch",
+]:
     success = False
     try:
         import_module(package)
         success = True
     except Exception:
-        logger.opt(exception=True).warning(
-            f"Could not import {package}, which may impact functionality."
-        )
+        logger.warning(f"Could not import {package}, which may impact functionality.")
     package_registry[package] = success
 
 # for backwards compatibility and looks better anyway
 package_registry["pyg"] = package_registry["torch_geometric"]
+package_registry["ipex"] = package_registry["intel_extension_for_pytorch"]
+package_registry["ccl"] = package_registry["oneccl_bindings_for_pytorch"]
 package_registry["codecarbon"] = True if util.find_spec("codecarbon") else False
diff --git a/matsciml/lightning/__init__.py b/matsciml/lightning/__init__.py
index c0a61a79..258b2d21 100644
--- a/matsciml/lightning/__init__.py
+++ b/matsciml/lightning/__init__.py
@@ -2,11 +2,9 @@
 # SPDX-License-Identifier: MIT License
 from __future__ import annotations
 
-from matsciml.common.packages import package_registry
+from matsciml.common.packages import package_registry  # noqa: F401
 from matsciml.lightning.ddp import *
 from matsciml.lightning.data_utils import *
-
-if package_registry["ipex"]:
-    from matsciml.lightning.xpu import *
+from matsciml.lightning.xpu import *
 
 __all__ = ["MatSciMLDataModule", "MultiDataModule"]
diff --git a/matsciml/lightning/xpu.py b/matsciml/lightning/xpu.py
index def5e66c..26d37c67 100644
--- a/matsciml/lightning/xpu.py
+++ b/matsciml/lightning/xpu.py
@@ -16,12 +16,8 @@
 
 logger = getLogger(__file__)
 
-if package_registry["ipex"]:
-    try:
-        import intel_extension_for_pytorch as ipex  # noqa: F401
-    except ImportError as e:
-        logger.warning(f"Unable to import IPEX due to {e} - XPU may not function.")
-
+# IPEX is not absolutely required for XPU usage for torch>=2.5.0
+if package_registry["ipex"] or torch.xpu.is_available():
     __all__ = ["XPUAccelerator", "SingleXPUStrategy"]
 
     class XPUAccelerator(Accelerator):
@@ -163,3 +159,7 @@ def register_strategies(cls, strategy_registry) -> None:
                 SingleXPUStrategy,
                 description="Strategy utilizing a single Intel GPU device or tile.",
             )
+else:
+    logger.warning(
+        "IPEX was not installed or XPU is not available. `matsciml.lightning.xpu` will be empty."
+    )
diff --git a/matsciml/models/pyg/__init__.py b/matsciml/models/pyg/__init__.py
index 2e001a44..71ead469 100644
--- a/matsciml/models/pyg/__init__.py
+++ b/matsciml/models/pyg/__init__.py
@@ -4,8 +4,11 @@
 This source code is licensed under the MIT license found in the
 LICENSE file in the root directory of this source tree.
 """
+
 from __future__ import annotations
 
+from loguru import logger
+
 from matsciml.common.packages import package_registry
 
 if "pyg" in package_registry:
@@ -16,10 +19,28 @@
 # load models if we have PyG installed
 if _has_pyg:
     from matsciml.models.pyg.cgcnn import CGCNN
-    from matsciml.models.pyg.dimenet import DimeNetWrap
-    from matsciml.models.pyg.dimenet_plus_plus import DimeNetPlusPlusWrap
     from matsciml.models.pyg.egnn import EGNN
-    from matsciml.models.pyg.faenet import FAENet
-    from matsciml.models.pyg.forcenet import ForceNet
     from matsciml.models.pyg.mace import MACE, ScaleShiftMACE
-    from matsciml.models.pyg.schnet import SchNetWrap
+
+    __all__ = ["CGCNN", "EGNN", "MACE", "ScaleShiftMACE"]
+
+    # these packages need additional pyg dependencies
+    if package_registry["torch_sparse"] and package_registry["torch_scatter"]:
+        from matsciml.models.pyg.dimenet import DimeNetWrap
+        from matsciml.models.pyg.dimenet_plus_plus import DimeNetPlusPlusWrap
+
+        __all__.extend(["DimeNetWrap", "DimeNetPlusPlusWrap"])
+    else:
+        logger.warning(
+            "Missing torch_sparse and torch_scatter; DimeNet models will not be available."
+        )
+    if package_registry["torch_scatter"]:
+        from matsciml.models.pyg.forcenet import ForceNet
+        from matsciml.models.pyg.schnet import SchNetWrap
+        from matsciml.models.pyg.faenet import FAENet
+
+        __all__.extend(["ForceNet", "SchNetWrap", "FAENet"])
+    else:
+        logger.warning(
+            "Missing torch_scatter; ForceNet, SchNet, and FAENet models will not be available."
+        )
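
The refactor standardizes a single consumption pattern: optional packages are probed once in matsciml.common.packages, and downstream code branches on package_registry instead of re-attempting imports. A minimal downstream sketch of that pattern, assuming a user script outside this diff (the guard and the None fallback are illustrative, not part of this change):

from loguru import logger

from matsciml.common.packages import package_registry

# Branch on the registry rather than importing optional backends directly;
# per the updated matsciml.models.pyg.__init__, SchNetWrap is only exported
# when torch_geometric and torch_scatter were importable at startup.
if package_registry["pyg"] and package_registry["torch_scatter"]:
    from matsciml.models.pyg import SchNetWrap
else:
    SchNetWrap = None  # illustrative fallback for this sketch
    logger.warning("Optional PyG dependencies missing; SchNetWrap is unavailable.")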