From 8bb63c0c72b1f77d43ad72f4ce25e80d583bc2c2 Mon Sep 17 00:00:00 2001 From: Duy Nguyen Date: Sun, 1 Sep 2024 21:23:03 +0100 Subject: [PATCH] Manually add back new code from main branch --- environment-dev-arm.yml | 5 +++- environment-dev.yml | 5 +++- pylops/signalprocessing/__init__.py | 3 +++ pylops/utils/deps.py | 36 ++++++++++++++++++++++++++--- pyproject.toml | 2 +- requirements-dev.txt | 5 +++- 6 files changed, 49 insertions(+), 7 deletions(-) diff --git a/environment-dev-arm.yml b/environment-dev-arm.yml index 5ba127d3..c711fe76 100755 --- a/environment-dev-arm.yml +++ b/environment-dev-arm.yml @@ -8,8 +8,10 @@ dependencies: - python>=3.6.4 - pip - numpy>=1.21.0 - - scipy>=1.4.0 + - scipy>=1.11.0 - pytorch>=1.2.0 + - cpuonly + - jax - pyfftw - pywavelets - sympy @@ -34,6 +36,7 @@ dependencies: - pydata-sphinx-theme - sphinx-gallery - nbsphinx + - sphinxemoji - image - flake8 - mypy diff --git a/environment-dev.yml b/environment-dev.yml index 24f38424..135319f7 100755 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -8,8 +8,10 @@ dependencies: - python>=3.6.4 - pip - numpy>=1.21.0 - - scipy>=1.4.0 + - scipy>=1.11.0 - pytorch>=1.2.0 + - cpuonly + - jax - pyfftw - pywavelets - sympy @@ -35,6 +37,7 @@ dependencies: - pydata-sphinx-theme - sphinx-gallery - nbsphinx + - sphinxemoji - image - flake8 - mypy diff --git a/pylops/signalprocessing/__init__.py b/pylops/signalprocessing/__init__.py index b3fa0881..b4f8fd7b 100755 --- a/pylops/signalprocessing/__init__.py +++ b/pylops/signalprocessing/__init__.py @@ -23,6 +23,7 @@ Shift Fractional Shift operator. DWT One dimensional Wavelet operator. DWT2D Two dimensional Wavelet operator. + DWTND N-dimensional Wavelet operator. DCT Discrete Cosine Transform. DTCWT Dual-Tree Complex Wavelet Transform. Radon2D Two dimensional Radon transform. @@ -62,6 +63,7 @@ from .fredholm1 import * from .dwt import * from .dwt2d import * +from .dwtnd import * from .seislet import * from .dct import * from .dtcwt import * @@ -95,6 +97,7 @@ "Fredholm1", "DWT", "DWT2D", + "DWTND", "Seislet", "DCT", "DTCWT", diff --git a/pylops/utils/deps.py b/pylops/utils/deps.py index 3497ce86..b320028b 100644 --- a/pylops/utils/deps.py +++ b/pylops/utils/deps.py @@ -1,5 +1,6 @@ __all__ = [ "cupy_enabled", + "jax_enabled", "devito_enabled", "dtcwt_enabled", "ucurv_enabled", @@ -51,6 +52,32 @@ def cupy_import(message: Optional[str] = None) -> str: return cupy_message +def jax_import(message: Optional[str] = None) -> str: + jax_test = ( + util.find_spec("jax") is not None and int(os.getenv("JAX_PYLOPS", 1)) == 1 + ) + if jax_test: + try: + import_module("jax") # noqa: F401 + + jax_message = None + except (ImportError, ModuleNotFoundError) as e: + jax_message = ( + f"Failed to import jax, Falling back to numpy (error: {e}). " + "Please ensure your environment is set up correctly " + "for more details visit 'https://jax.readthedocs.io/en/latest/installation.html'" + ) + print(UserWarning(jax_message)) + else: + jax_message = ( + "Jax package not installed or os.getenv('JAX_PYLOPS') == 0. " + f"In order to be able to use {message} " + "ensure 'os.getenv('JAX_PYLOPS') == 1' and run " + "'pip install jax'; " + "for more details visit 'https://jax.readthedocs.io/en/latest/installation.html'" + ) + + return jax_message def devito_import(message: Optional[str] = None) -> str: if devito_enabled: @@ -211,15 +238,18 @@ def sympy_import(message: Optional[str] = None) -> str: # Set package availability booleans -# cupy: the package is imported to check everything is working correctly, -# if not the package is disabled. We do this here as this library is used as drop-in -# replacement for many numpy and scipy routines when cupy arrays are provided. +# cupy and jax: the package is imported to check everything is working correctly, +# if not the package is disabled. We do this here as these libraries are used as drop-in +# replacement for many numpy and scipy routines when cupy/jax arrays are provided. # all other libraries: we simply check if the package is available and postpone its import # to check everything is working correctly when a user tries to create an operator that requires # such a package cupy_enabled: bool = ( True if (cupy_import() is None and int(os.getenv("CUPY_PYLOPS", 1)) == 1) else False ) +jax_enabled: bool = ( + True if (jax_import() is None and int(os.getenv("JAX_PYLOPS", 1)) == 1) else False +) devito_enabled = util.find_spec("devito") is not None dtcwt_enabled = util.find_spec("dtcwt") is not None ucurv_enabled = util.find_spec("ucurv") is not None diff --git a/pyproject.toml b/pyproject.toml index 5e435cc7..01b44d01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ classifiers = [ ] dependencies = [ "numpy >= 1.21.0", - "scipy >= 1.4.0", + "scipy >= 1.11.0", ] dynamic = ["version"] diff --git a/requirements-dev.txt b/requirements-dev.txt index cd4b5911..42477b9a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,8 @@ numpy>=1.21.0 -scipy>=1.4.0 +scipy>=1.11.0 +--extra-index-url https://download.pytorch.org/whl/cpu torch>=1.2.0 +jax numba pyfftw PyWavelets @@ -18,6 +20,7 @@ docutils<0.18 Sphinx pydata-sphinx-theme sphinx-gallery +sphinxemoji numpydoc nbsphinx image