Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revising optional package imports #312

Merged
merged 8 commits into from
Oct 29, 2024
16 changes: 1 addition & 15 deletions matsciml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,8 @@
from logging import getLogger

# determine if intel libraries are available
from matsciml.common.packages import package_registry
from matsciml.common.packages import package_registry # noqa: F401

__version__ = "1.1.0"

logger = getLogger(__file__)


if package_registry["ipex"]:
try:
import intel_extension_for_pytorch # noqa: F401
except ImportError as e:
logger.warning(f"Unable to load IPEX because of {e} - XPU may not function.")
if package_registry["ccl"]:
try:
import oneccl_bindings_for_pytorch # noqa: F401
except ImportError as e:
logger.warning(
f"Unable to load CCL bindings because of {e} - DDP XPU may not function."
)
21 changes: 11 additions & 10 deletions matsciml/common/packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,26 @@
been installed.
"""
package_registry = {}
package_registry["ipex"] = (
True if util.find_spec("intel_extension_for_pytorch") else False
)
package_registry["ccl"] = (
True if util.find_spec("oneccl_bindings_for_pytorch") else False
)
# graph specific packages; slightly more involved because we should try import
for package in ["torch_geometric", "torch_scatter", "torch_sparse", "dgl"]:
for package in [
"torch_geometric",
"torch_scatter",
"torch_sparse",
"dgl",
"intel_extension_for_pytorch",
"oneccl_bindings_for_pytorch",
]:
success = False
try:
import_module(package)
success = True
except Exception:
logger.opt(exception=True).warning(
f"Could not import {package}, which may impact functionality."
)
logger.warning(f"Could not import {package}, which may impact functionality.")
package_registry[package] = success
# for backwards compatibility and looks better anyway
package_registry["pyg"] = package_registry["torch_geometric"]
package_registry["ipex"] = package_registry["intel_extension_for_pytorch"]
package_registry["ccl"] = package_registry["oneccl_bindings_for_pytorch"]
package_registry["codecarbon"] = True if util.find_spec("codecarbon") else False


Expand Down
6 changes: 2 additions & 4 deletions matsciml/lightning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
# SPDX-License-Identifier: MIT License
from __future__ import annotations

from matsciml.common.packages import package_registry
from matsciml.common.packages import package_registry # noqa: F401
from matsciml.lightning.ddp import *
from matsciml.lightning.data_utils import *

if package_registry["ipex"]:
from matsciml.lightning.xpu import *
from matsciml.lightning.xpu import *

__all__ = ["MatSciMLDataModule", "MultiDataModule"]
12 changes: 6 additions & 6 deletions matsciml/lightning/xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,8 @@

logger = getLogger(__file__)

if package_registry["ipex"]:
try:
import intel_extension_for_pytorch as ipex # noqa: F401
except ImportError as e:
logger.warning(f"Unable to import IPEX due to {e} - XPU may not function.")

# IPEX is not absolutely required for XPU usage for torch>=2.5.0
if package_registry["ipex"] or torch.xpu.is_available():
__all__ = ["XPUAccelerator", "SingleXPUStrategy"]

class XPUAccelerator(Accelerator):
Expand Down Expand Up @@ -163,3 +159,7 @@ def register_strategies(cls, strategy_registry) -> None:
SingleXPUStrategy,
description="Strategy utilizing a single Intel GPU device or tile.",
)
else:
logger.warning(
"IPEX was not installed or XPU is not available. `matsciml.lightning.xpu` will be empty."
)
31 changes: 26 additions & 5 deletions matsciml/models/pyg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

from __future__ import annotations

from loguru import logger

from matsciml.common.packages import package_registry

if "pyg" in package_registry:
Expand All @@ -16,10 +19,28 @@
# load models if we have PyG installed
if _has_pyg:
from matsciml.models.pyg.cgcnn import CGCNN
from matsciml.models.pyg.dimenet import DimeNetWrap
from matsciml.models.pyg.dimenet_plus_plus import DimeNetPlusPlusWrap
from matsciml.models.pyg.egnn import EGNN
from matsciml.models.pyg.faenet import FAENet
from matsciml.models.pyg.forcenet import ForceNet
from matsciml.models.pyg.mace import MACE, ScaleShiftMACE
from matsciml.models.pyg.schnet import SchNetWrap

__all__ = ["CGCNN", "EGNN", "FAENet", "MACE", "ScaleShiftMACE"]

# these packages need additional pyg dependencies
if package_registry["torch_sparse"] and package_registry["torch_scatter"]:
from matsciml.models.pyg.dimenet import DimeNetWrap
from matsciml.models.pyg.dimenet_plus_plus import DimeNetPlusPlusWrap

__all__.extend(["DimeNetWrap", "DimeNetPlusPlusWrap"])
else:
logger.warning(
"Missing torch_sparse and torch_scatter; DimeNet models will not be available."
)
if package_registry["torch_scatter"]:
from matsciml.models.pyg.forcenet import ForceNet
from matsciml.models.pyg.schnet import SchNetWrap
from matsciml.models.pyg.faenet import FAENet

__all__.extend(["ForceNet", "SchNetWrap", "FAENet"])
else:
logger.warning(
"Missing torch_scatter; ForceNet, SchNet, and FAENet models will not be available."
)
Loading