diff --git a/sparc/__init__.py b/sparc/__init__.py index 4a69f389..2a494a54 100644 --- a/sparc/__init__.py +++ b/sparc/__init__.py @@ -5,18 +5,27 @@ conda build and CI where not all dependencies are present """ + def _missing_deps_func(*args, **kwargs): raise ImportError("Importing fails for ase / numpy!") + class SPARCMissingDeps: def __init__(self, *args, **kwargs): - raise ImportError("Cannot initialize sparc.SPARC because the required dependencies (ase and numpy) are not available.") + raise ImportError( + "Cannot initialize sparc.SPARC because the required dependencies (ase and numpy) are not available." + ) def __getattr__(self, name): - raise ImportError(f"Cannot access '{name}' on sparc.SPARC because the required dependencies (ase and numpy) are not available.") + raise ImportError( + f"Cannot access '{name}' on sparc.SPARC because the required dependencies (ase and numpy) are not available." + ) + + try: import ase import numpy + _import_complete = True except ImportError: _import_complete = False @@ -25,6 +34,7 @@ def __getattr__(self, name): from .io import read_sparc, write_sparc from .io import register_ase_io_sparc from .calculator import SPARC + register_ase_io_sparc() else: # If importing is not complete, any code trying to directly import @@ -32,4 +42,3 @@ def __getattr__(self, name): read_sparc = _missing_deps_func write_sparc = _missing_deps_func SPARC = SPARCMissingDeps - diff --git a/sparc/utils.py b/sparc/utils.py index 31ff2572..7a29ae2d 100644 --- a/sparc/utils.py +++ b/sparc/utils.py @@ -11,9 +11,7 @@ def deprecated(message): def decorator(func): def new_func(*args, **kwargs): warn( - "Function {} is deprecated! {}".format( - func.__name__, message - ), + "Function {} is deprecated! {}".format(func.__name__, message), category=DeprecationWarning, ) return func(*args, **kwargs) diff --git a/tests/test_import.py b/tests/test_import.py index e4ccc7f5..d425c8a3 100644 --- a/tests/test_import.py +++ b/tests/test_import.py @@ -4,44 +4,25 @@ """ import pytest import sys +import ase -def test_download_data(): - import ase - original_ase = sys.modules.get('ase') - sys.modules['ase'] = None + +def test_download_data(monkeypatch): + monkeypatch.setitem(sys.modules, "ase", None) with pytest.raises(ImportError): import ase from sparc.download_data import download_psp - # Recover sys ase - sys.modules['ase'] = original_ase -def test_api(): - import ase - original_ase = sys.modules.get('ase') - sys.modules['ase'] = None + +def test_api(monkeypatch): + monkeypatch.setitem(sys.modules, "ase", None) with pytest.raises(ImportError): import ase from sparc.api import SparcAPI - # Recover sys ase - sys.modules['ase'] = original_ase -def test_docparser(): - import ase - original_ase = sys.modules.get('ase') - sys.modules['ase'] = None - with pytest.raises(ImportError): - import ase - from sparc.docparser import SPARCDocParser - # Recover sys ase - sys.modules['ase'] = original_ase -def test_normal(): - import ase - original_ase = sys.modules.get('ase') - sys.modules['ase'] = None +def test_docparser(monkeypatch): + monkeypatch.setitem(sys.modules, "ase", None) with pytest.raises(ImportError): import ase - with pytest.raises(ImportError): - from sparc.io import read_sparc - # Recover sys ase - sys.modules['ase'] = original_ase + from sparc.docparser import SPARCDocParser