Skip to content

Commit

Permalink
Manually add back new code from main branch
Browse files Browse the repository at this point in the history
  • Loading branch information
yud08 committed Sep 1, 2024
1 parent c45c7bf commit 8bb63c0
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 7 deletions.
5 changes: 4 additions & 1 deletion environment-dev-arm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ dependencies:
- python>=3.6.4
- pip
- numpy>=1.21.0
- scipy>=1.4.0
- scipy>=1.11.0
- pytorch>=1.2.0
- cpuonly
- jax
- pyfftw
- pywavelets
- sympy
Expand All @@ -34,6 +36,7 @@ dependencies:
- pydata-sphinx-theme
- sphinx-gallery
- nbsphinx
- sphinxemoji
- image
- flake8
- mypy
Expand Down
5 changes: 4 additions & 1 deletion environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ dependencies:
- python>=3.6.4
- pip
- numpy>=1.21.0
- scipy>=1.4.0
- scipy>=1.11.0
- pytorch>=1.2.0
- cpuonly
- jax
- pyfftw
- pywavelets
- sympy
Expand All @@ -35,6 +37,7 @@ dependencies:
- pydata-sphinx-theme
- sphinx-gallery
- nbsphinx
- sphinxemoji
- image
- flake8
- mypy
3 changes: 3 additions & 0 deletions pylops/signalprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Shift Fractional Shift operator.
DWT One dimensional Wavelet operator.
DWT2D Two dimensional Wavelet operator.
DWTND N-dimensional Wavelet operator.
DCT Discrete Cosine Transform.
DTCWT Dual-Tree Complex Wavelet Transform.
Radon2D Two dimensional Radon transform.
Expand Down Expand Up @@ -62,6 +63,7 @@
from .fredholm1 import *
from .dwt import *
from .dwt2d import *
from .dwtnd import *
from .seislet import *
from .dct import *
from .dtcwt import *
Expand Down Expand Up @@ -95,6 +97,7 @@
"Fredholm1",
"DWT",
"DWT2D",
"DWTND",
"Seislet",
"DCT",
"DTCWT",
Expand Down
36 changes: 33 additions & 3 deletions pylops/utils/deps.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
__all__ = [
"cupy_enabled",
"jax_enabled",
"devito_enabled",
"dtcwt_enabled",
"ucurv_enabled",
Expand Down Expand Up @@ -51,6 +52,32 @@ def cupy_import(message: Optional[str] = None) -> str:

return cupy_message

def jax_import(message: Optional[str] = None) -> str:
jax_test = (
util.find_spec("jax") is not None and int(os.getenv("JAX_PYLOPS", 1)) == 1
)
if jax_test:
try:
import_module("jax") # noqa: F401

jax_message = None
except (ImportError, ModuleNotFoundError) as e:
jax_message = (
f"Failed to import jax, Falling back to numpy (error: {e}). "
"Please ensure your environment is set up correctly "
"for more details visit 'https://jax.readthedocs.io/en/latest/installation.html'"
)
print(UserWarning(jax_message))
else:
jax_message = (
"Jax package not installed or os.getenv('JAX_PYLOPS') == 0. "
f"In order to be able to use {message} "
"ensure 'os.getenv('JAX_PYLOPS') == 1' and run "
"'pip install jax'; "
"for more details visit 'https://jax.readthedocs.io/en/latest/installation.html'"
)

return jax_message

def devito_import(message: Optional[str] = None) -> str:
if devito_enabled:
Expand Down Expand Up @@ -211,15 +238,18 @@ def sympy_import(message: Optional[str] = None) -> str:


# Set package availability booleans
# cupy: the package is imported to check everything is working correctly,
# if not the package is disabled. We do this here as this library is used as drop-in
# replacement for many numpy and scipy routines when cupy arrays are provided.
# cupy and jax: the package is imported to check everything is working correctly,
# if not the package is disabled. We do this here as these libraries are used as drop-in
# replacement for many numpy and scipy routines when cupy/jax arrays are provided.
# all other libraries: we simply check if the package is available and postpone its import
# to check everything is working correctly when a user tries to create an operator that requires
# such a package
cupy_enabled: bool = (
True if (cupy_import() is None and int(os.getenv("CUPY_PYLOPS", 1)) == 1) else False
)
jax_enabled: bool = (
True if (jax_import() is None and int(os.getenv("JAX_PYLOPS", 1)) == 1) else False
)
devito_enabled = util.find_spec("devito") is not None
dtcwt_enabled = util.find_spec("dtcwt") is not None
ucurv_enabled = util.find_spec("ucurv") is not None
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ classifiers = [
]
dependencies = [
"numpy >= 1.21.0",
"scipy >= 1.4.0",
"scipy >= 1.11.0",
]
dynamic = ["version"]

Expand Down
5 changes: 4 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
numpy>=1.21.0
scipy>=1.4.0
scipy>=1.11.0
--extra-index-url https://download.pytorch.org/whl/cpu
torch>=1.2.0
jax
numba
pyfftw
PyWavelets
Expand All @@ -18,6 +20,7 @@ docutils<0.18
Sphinx
pydata-sphinx-theme
sphinx-gallery
sphinxemoji
numpydoc
nbsphinx
image
Expand Down

0 comments on commit 8bb63c0

Please sign in to comment.