diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index 16278ba7..a0887beb 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -6,8 +6,8 @@ jobs:
build:
strategy:
matrix:
- platform: [ ubuntu-latest, macos-latest ]
- python-version: ["3.8", "3.9", "3.10", "3.11"]
+ platform: [ ubuntu-latest, macos-13 ]
+ python-version: ["3.9", "3.10", "3.11"]
runs-on: ${{ matrix.platform }}
steps:
diff --git a/.github/workflows/codacy-coverage-reporter.yaml b/.github/workflows/codacy-coverage-reporter.yaml
index c9a8596f..0e610bb1 100644
--- a/.github/workflows/codacy-coverage-reporter.yaml
+++ b/.github/workflows/codacy-coverage-reporter.yaml
@@ -9,7 +9,7 @@ jobs:
strategy:
matrix:
platform: [ ubuntu-latest, ]
- python-version: ["3.8", ]
+ python-version: ["3.9", ]
runs-on: ${{ matrix.platform }}
steps:
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index a74e3480..202c29ca 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -18,6 +18,6 @@ sphinx:
# Declare the Python requirements required to build your docs
python:
install:
- - requirements: requirements-dev.txt
+ - requirements: requirements-doc.txt
- method: pip
path: .
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 93ce3982..09d6c06b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,33 @@
+Changelog
+=========
+
+# 2.3.1
+* Fixed bug in :py:mod:`pylops.utils.backend` (see [Issue #606](https://github.com/PyLops/pylops/issues/606))
+
+# 2.3.0
+
+* Added `pylops.JaxOperator`, `pylops.signalprocessing.DWTND`, and `pylops.signalprocessing.DTCWT` operators.
+* Added `updatesrc` method to `pylops.waveeqprocessing.AcousticWave2D`.
+* Added `verb` to `pylops.signalprocessing.Sliding1D.sliding1d_design`, `pylops.signalprocessing.Sliding2D.sliding2d_design`, `pylops.signalprocessing.Sliding3D.sliding3d_design`, `pylops.signalprocessing.Patch2D.patch2d_design`, and `pylops.signalprocessing.Patch3D.patch3d_design`.
+* Added `kwargs_fft` to `pylops.signalprocessing.FFTND`.
+* Added `cosinetaper` to `pylops.utils.tapers.cosinetaper`.
+* Added `kind` to `pylops.waveeqprocessing.Deghosting`.
+* Modified all methods in `pylops.utils.backend` to enable jax integration.
+* Modified implementations of `pylops.signalprocessing.Sliding1D`, `pylops.signalprocessing.Sliding2D`,
+`pylops.signalprocessing.Sliding3D`, `pylops.signalprocessing.Patch2D`, and
+`pylops.signalprocessing.Patch3D` to being directly implemented instead of relying on other PyLops operators. Added also `savetaper` parameter and an option to apply the operator `Op` simultaneously to all windows.
+* Modified `pylops.waveeqprocessing.AcousticWave2D._born_oneshot` and
+`pylops.waveeqprocessing.AcousticWave2D._born_allshots` to avoid recreating the devito solver for each shot (and enabling internal caching...)
+* Modified `dtype` of `pylops.signalprocessing.Shift` to be that of the input vector.
+* Modified `pylops.waveeqprocessing.BlendingContinuous` to use `matvec/rmatvec` instead of `@/.H @` for compatibility with pylops solvers.
+* Removed `cusignal` as optional dependency and `cupy`'s equivalent methods (since the library
+is now unmantained and merged into `cupy`).
+* Fixed ImportError of optional dependencies when installed but not correctly functioning (see [Issue #548](https://github.com/PyLops/pylops/issues/548))
+* Fixed bug in `pylops.utils.deps.to_cupy_conditional` (see [Issue #579](https://github.com/PyLops/pylops/issues/579))
+* Fixed bug in the definition of `nttot` in `pylops.waveeqprocessing.BlendingContinuous`
+* Fixed bug in `pylops.utils.signalprocessing.dip_estimate` (see [Issue #572](https://github.com/PyLops/pylops/issues/572))
+
+
# 2.2.0
* Added `pylops.signalprocessing.NonStationaryConvolve3D` operator
@@ -287,7 +317,7 @@ To aid users in navigating the breaking changes, we provide the following docume
``pylops.waveeqprocessing.UpDownComposition3Doperator``, and
``pylops.waveeqprocessing.PhaseShift`` operators
* Fix bug in ``pylops.basicoperators.Kronecker``
- (see [Issue #125](https://github.com/Statoil/pylops/issues/125))
+ (see [Issue #125](https://github.com/PyLops/pylops/issues/125))
# 1.7.0
* Added ``pylops.basicoperators.Gradient``,
diff --git a/Makefile b/Makefile
index 341c53be..35cf0e1e 100755
--- a/Makefile
+++ b/Makefile
@@ -1,7 +1,7 @@
PIP := $(shell command -v pip3 2> /dev/null || command which pip 2> /dev/null)
PYTHON := $(shell command -v python3 2> /dev/null || command which python 2> /dev/null)
-.PHONY: install dev-install install_conda dev-install_conda tests doc docupdate
+.PHONY: install dev-install install_conda dev-install_conda tests doc docupdate servedoc lint typeannot coverage
pipcheck:
ifndef PIP
diff --git a/README.md b/README.md
index 17512736..71399610 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@
[![PyPI version](https://badge.fury.io/py/pylops.svg)](https://badge.fury.io/py/pylops)
[![Anaconda-Server Badge](https://anaconda.org/conda-forge/pylops/badges/version.svg)](https://anaconda.org/conda-forge/pylops)
[![AzureDevOps Status](https://dev.azure.com/matteoravasi/PyLops/_apis/build/status/PyLops.pylops?branchName=dev)](https://dev.azure.com/matteoravasi/PyLops/_build/latest?definitionId=9&branchName=dev)
-[![GithubAction Status](https://github.com/mrava87/pylops/actions/workflows/build.yaml/badge.svg)](https://github.com/mrava87/pylops/actions/workflows/build.yaml)
+[![GithubAction Status](https://github.com/PyLops/pylops/actions/workflows/build.yaml/badge.svg?branch=dev)](https://github.com/PyLops/pylops/actions/workflows/build.yaml)
[![Documentation Status](https://readthedocs.org/projects/pylops/badge/?version=stable)](https://pylops.readthedocs.io/en/stable/?badge=stable)
[![Codacy Badge](https://app.codacy.com/project/badge/Grade/17fd60b4266347d8890dd6b64f2c0807)](https://www.codacy.com/gh/PyLops/pylops/dashboard?utm_source=github.com&utm_medium=referral&utm_content=PyLops/pylops&utm_campaign=Badge_Grade)
[![Codacy Badge](https://app.codacy.com/project/badge/Coverage/17fd60b4266347d8890dd6b64f2c0807)](https://www.codacy.com/gh/PyLops/pylops/dashboard?utm_source=github.com&utm_medium=referral&utm_content=PyLops/pylops&utm_campaign=Badge_Coverage)
@@ -150,3 +150,5 @@ A list of video tutorials to learn more about PyLops:
* Wei Zhang, ZhangWeiGeo
* Fedor Goncharov, fedor-goncharov
* Alex Rakowski, alex-rakowski
+* David Sollberger, solldavid
+* Gustavo Coelho, guaacoelho
diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index ef5306bb..00db3746 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -26,7 +26,7 @@ jobs:
# steps:
# - task: UsePythonVersion@0
# inputs:
-# versionSpec: '3.7'
+# versionSpec: '3.9'
# architecture: 'x64'
#
# - script: |
@@ -55,7 +55,7 @@ jobs:
steps:
- task: UsePythonVersion@0
inputs:
- versionSpec: '3.8'
+ versionSpec: '3.9'
architecture: 'x64'
- script: |
@@ -84,7 +84,7 @@ jobs:
steps:
- task: UsePythonVersion@0
inputs:
- versionSpec: '3.8'
+ versionSpec: '3.9'
architecture: 'x64'
- script: |
diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst
index 6879b247..1d77dc15 100755
--- a/docs/source/api/index.rst
+++ b/docs/source/api/index.rst
@@ -29,6 +29,7 @@ Templates
FunctionOperator
MemoizeOperator
TorchOperator
+ JaxOperator
Basic operators
~~~~~~~~~~~~~~~
@@ -102,6 +103,7 @@ Signal processing
Shift
DWT
DWT2D
+ DWTND
DCT
DTCWT
Seislet
diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst
index ad3d94b1..b5854475 100644
--- a/docs/source/changelog.rst
+++ b/docs/source/changelog.rst
@@ -3,20 +3,61 @@
Changelog
=========
+
+Version 2.3.1
+-------------
+
+*Released on: 17/08/2024*
+
+* Fixed bug in :py:mod:`pylops.utils.backend` (see https://github.com/PyLops/pylops/issues/606)
+
+
+Version 2.3.0
+-------------
+
+*Released on: 16/08/2024*
+
+* Added :py:class:`pylops.JaxOperator`, :py:class:`pylops.signalprocessing.DWTND`, and :py:class:`pylops.signalprocessing.DTCWT` operators.
+* Added `updatesrc` method to :py:class:`pylops.waveeqprocessing.AcousticWave2D`
+* Added `verb` to :py:func:`pylops.signalprocessing.Sliding1D.sliding1d_design`, :py:func:`pylops.signalprocessing.Sliding2D.sliding2d_design`,
+ :py:func:`pylops.signalprocessing.Sliding3D.sliding3d_design`, :py:func:`pylops.signalprocessing.Patch2D.patch2d_design`,
+ and :py:func:`pylops.signalprocessing.Patch3D.patch3d_design`
+* Added `kwargs_fft` to :py:class:`pylops.signalprocessing.FFTND`
+* Added `cosinetaper` to :py:class:`pylops.utils.tapers.cosinetaper`
+* Added `kind` to :py:class:`pylops.waveeqprocessing.Deghosting`.
+* Modified all methods in :py:mod:`pylops.utils.backend` to enable jax integration
+* Modified implementations of :py:class:`pylops.signalprocessing.Sliding1D`, :py:class:`pylops.signalprocessing.Sliding2D`,
+ :py:class:`pylops.signalprocessing.Sliding3D`, :py:class:`pylops.signalprocessing.Patch2D`, and
+ :py:class:`pylops.signalprocessing.Patch3D` to being directly implemented instead of relying on
+ other PyLops operators. Added also `savetaper` parameter and an option to apply the operator `Op`
+ simultaneously to all windows
+* Modified :py:func:`pylops.waveeqprocessing.AcousticWave2D._born_oneshot`
+ and :py:func:`pylops.waveeqprocessing.AcousticWave2D._born_allshots` to avoid
+ recreating the devito solver for each shot (and enabling internal caching...)
+* Modified `dtype` of :py:class:`pylops.signalprocessing.Shift` to be that of the input vector.
+* Modified :py:class:`pylops.waveeqprocessing.BlendingContinuous` to use `matvec/rmatvec` instead of `@/.H @`
+ for compatibility with pylops solvers
+* Removed `cusignal` as optional dependency and `cupy`'s equivalent methods (since the library
+ is now unmantained and merged into `cupy`)
+* Fixed ImportError of optional dependencies when installed but not correctly functioning (see https://github.com/PyLops/pylops/issues/548)
+* Fixed bug in :py:func:`pylops.utils.deps.to_cupy_conditional` (see https://github.com/PyLops/pylops/issues/579)
+* Fixed bug in the definition of `nttot` in :py:class:`pylops.waveeqprocessing.BlendingContinuous`
+* Fixed bug in :py:func:`pylops.utils.signalprocessing.dip_estimate` (see https://github.com/PyLops/pylops/issues/572)
+
Version 2.2.0
-------------
*Released on: 11/11/2023*
-* Added :class:`pylops.signalprocessing.NonStationaryConvolve3D` operator
-* Added nd-array capabilities to :class:`pylops.basicoperators.Identity` and :class:`pylops.basicoperators.Zero`
-* Added second implementation in :class:`pylops.waveeqprocessing.BlendingContinuous` which is more
+* Added :py:class:`pylops.signalprocessing.NonStationaryConvolve3D` operator
+* Added nd-array capabilities to :py:class:`pylops.basicoperators.Identity` and :py:class:`pylops.basicoperators.Zero`
+* Added second implementation in :py:class:`pylops.waveeqprocessing.BlendingContinuous` which is more
performant when dealing with small number of receivers
-* Added `forceflat` property to operators with ambiguous `rmatvec` (:class:`pylops.basicoperators.Block`,
- :class:`pylops.basicoperators.Bilinear`, :class:`pylops.basicoperators.BlockDiag`, :class:`pylops.basicoperators.HStack`,
- :class:`pylops.basicoperators.MatrixMult`, :class:`pylops.basicoperators.VStack`, and :class:`pylops.basicoperators.Zero`)
-* Improved `dynamic` mode of :class:`pylops.waveeqprocessing.Kirchhoff` operator
-* Modified :class:`pylops.signalprocessing.Convolve1D` to allow both filters that are both shorter and longer of the
+* Added `forceflat` property to operators with ambiguous `rmatvec` (:py:class:`pylops.basicoperators.Block`,
+ :py:class:`pylops.basicoperators.Bilinear`, :py:class:`pylops.basicoperators.BlockDiag`, :py:class:`pylops.basicoperators.HStack`,
+ :py:class:`pylops.basicoperators.MatrixMult`, :py:class:`pylops.basicoperators.VStack`, and :py:class:`pylops.basicoperators.Zero`)
+* Improved `dynamic` mode of :py:class:`pylops.waveeqprocessing.Kirchhoff` operator
+* Modified :py:class:`pylops.signalprocessing.Convolve1D` to allow both filters that are both shorter and longer of the
input vector
* Modified all solvers to use `matvec/rmatvec` instead of `@/.H @` to improve performance
@@ -26,19 +67,19 @@ Version 2.1.0
*Released on: 17/03/2023*
-* Added :class:`pylops.signalprocessing.DCT`, :class:`pylops.signalprocessing.NonStationaryConvolve1D`,
- :class:`pylops.signalprocessing.NonStationaryConvolve2D`, :class:`pylops.signalprocessing.NonStationaryFilters1D`, and
- :class:`pylops.signalprocessing.NonStationaryFilters2D` operators
-* Added :class:`pylops.waveeqprocessing.BlendingContinuous`, :class:`pylops.waveeqprocessing.BlendingGroup`, and
- :class:`pylops.waveeqprocessing.BlendingHalf` operators
-* Added `kind='datamodel'` to :class:`pylops.optimization.cls_sparsity.IRLS`
-* Improved inner working of :class:`pylops.waveeqprocessing.Kirchhoff` operator significantly
+* Added :py:class:`pylops.signalprocessing.DCT`, :py:class:`pylops.signalprocessing.NonStationaryConvolve1D`,
+ :py:class:`pylops.signalprocessing.NonStationaryConvolve2D`, :py:class:`pylops.signalprocessing.NonStationaryFilters1D`, and
+ :py:class:`pylops.signalprocessing.NonStationaryFilters2D` operators
+* Added :py:class:`pylops.waveeqprocessing.BlendingContinuous`, :py:class:`pylops.waveeqprocessing.BlendingGroup`, and
+ :py:class:`pylops.waveeqprocessing.BlendingHalf` operators
+* Added `kind='datamodel'` to :py:class:`pylops.optimization.cls_sparsity.IRLS`
+* Improved inner working of :py:class:`pylops.waveeqprocessing.Kirchhoff` operator significantly
reducing the memory usage related to storing traveltime, angle, and amplitude tables.
-* Improved handling of `haxes` in :class:`pylops.signalprocessing.Radon2D` and :class:`pylops.signalprocessing.Radon3D` operators
-* Added possibility to feed ND-arrays to :class:`pylops.TorchOperator`
-* Removed :class:`pylops.LinearOperator` inheritance and added `__call__` method to :class:`pylops.TorchOperator`
-* Removed `scipy.sparse.linalg.LinearOperator` and added :class:`abc.ABC` inheritance to :class:`pylops.LinearOperator`
-* All operators are now classes of `:class:`pylops.LinearOperator` type
+* Improved handling of `haxes` in :py:class:`pylops.signalprocessing.Radon2D` and :py:class:`pylops.signalprocessing.Radon3D` operators
+* Added possibility to feed ND-arrays to :py:class:`pylops.TorchOperator`
+* Removed :py:class:`pylops.LinearOperator` inheritance and added `__call__` method to :py:class:`pylops.TorchOperator`
+* Removed `scipy.sparse.linalg.LinearOperator` and added :py:class:`abc.ABC` inheritance to :py:class:`pylops.LinearOperator`
+* All operators are now classes of `:py:class:`pylops.LinearOperator` type
Version 2.0.0
@@ -56,25 +97,25 @@ To aid users in navigating the breaking changes, we provide the following docume
Users do not need to use ``.ravel`` and ``.reshape`` as often anymore. See the migration guide for more information.
* Typing annotations for several submodules (``avo``, ``basicoperators``, ``signalprocessing``, ``utils``, ``optimization``,
``waveeqprocessing``)
-* New :class:`pylops.TorchOperator` wraps a Pylops operator into a PyTorch function
-* New :class:`pylops.signalprocessing.Patch3D` applies a linear operator repeatedly to patches of the model vector
-* Each of :class:`pylops.signalprocessing.Sliding1D`, :class:`pylops.signalprocessing.Sliding2D`,
- :class:`pylops.signalprocessing.Sliding3D`, :class:`pylops.signalprocessing.Patch2D` and :class:`pylops.signalprocessing.Patch3D`
+* New :py:class:`pylops.TorchOperator` wraps a Pylops operator into a PyTorch function
+* New :py:class:`pylops.signalprocessing.Patch3D` applies a linear operator repeatedly to patches of the model vector
+* Each of :py:class:`pylops.signalprocessing.Sliding1D`, :py:class:`pylops.signalprocessing.Sliding2D`,
+ :py:class:`pylops.signalprocessing.Sliding3D`, :py:class:`pylops.signalprocessing.Patch2D` and :py:class:`pylops.signalprocessing.Patch3D`
have an associated ``slidingXd_design`` or ``patchXd_design`` functions associated with them to aid the user in designing the windows
-* :class:`pylops.FirstDerivative` and :class:`pylops.SecondDerivative`, and therefore other derivative operators which rely on the
- (e.g., :class:`pylops.Gradient`) support higher order stencils
-* :class:`pylops.waveeqprocessing.Kirchhoff` substitutes :class:`pylops.waveeqprocessing.Demigration` and incorporates a variety of
+* :py:class:`pylops.FirstDerivative` and :py:class:`pylops.SecondDerivative`, and therefore other derivative operators which rely on the
+ (e.g., :py:class:`pylops.Gradient`) support higher order stencils
+* :py:class:`pylops.waveeqprocessing.Kirchhoff` substitutes :py:class:`pylops.waveeqprocessing.Demigration` and incorporates a variety of
new functionalities
-* New :class:`pylops.waveeqprocessing.AcousticWave2D` wraps the `Devito `_ acoutic wave propagator
+* New :py:class:`pylops.waveeqprocessing.AcousticWave2D` wraps the `Devito `_ acoutic wave propagator
providing a wave-equation based Born modeling operator with a reverse-time migration adjoint
-* Solvers can now be implemented via the :class:`pylops.optimization.basesolver.Solver` class. They can now be used through a
- functional interface with lowercase name (e.g., :func:`pylops.optimization.sparsity.splitbregman`) or via class interface with CamelCase name
- (e.g., :class:`pylops.optimization.cls_sparsity.SplitBregman`. Moreover, solvers now accept callbacks defined by the
- :class:`pylops.optimization.callback.Callbacks` interface (see e.g., :class:`pylops.optimization.callback.MetricsCallback`).
-* Metrics such as :func:`pylops.utils.metrics.mae` and :func:`pylops.utils.metrics.mse` and others
-* New :func:`pylops.utils.signalprocessing.dip_estimate` estimates local dips in an image (measured in radians) in a stabler way than the old :func:`pylops.utils.signalprocessing.dip_estimate` did for slopes.
-* New :func:`pylops.utils.tapers.tapernd` for N-dimensional tapers
-* New wavelets :func:`pylops.utils.wavelets.klauder` and :func:`pylops.utils.wavelets.ormsby`
+* Solvers can now be implemented via the :py:class:`pylops.optimization.basesolver.Solver` class. They can now be used through a
+ functional interface with lowercase name (e.g., :py:func:`pylops.optimization.sparsity.splitbregman`) or via class interface with CamelCase name
+ (e.g., :py:class:`pylops.optimization.cls_sparsity.SplitBregman`. Moreover, solvers now accept callbacks defined by the
+ :py:class:`pylops.optimization.callback.Callbacks` interface (see e.g., :py:class:`pylops.optimization.callback.MetricsCallback`)
+* Metrics such as :py:func:`pylops.utils.metrics.mae` and :py:func:`pylops.utils.metrics.mse` and others
+* New :py:func:`pylops.utils.signalprocessing.dip_estimate` estimates local dips in an image (measured in radians) in a stabler way than the old :py:func:`pylops.utils.signalprocessing.dip_estimate` did for slopes.
+* New :py:func:`pylops.utils.tapers.tapernd` for N-dimensional tapers
+* New wavelets :py:func:`pylops.utils.wavelets.klauder` and :py:func:`pylops.utils.wavelets.ormsby`
**Documentation**
@@ -210,7 +251,7 @@ Version 1.15.0
``full``, ``half``, or ``trapezoidal`` integration.
* Fixed `_hardthreshold_percentile` in
:py:mod:`pylops.optimization.sparsity`
- - `Issue #249 `_.
+ (see https://github.com/PyLops/pylops/issues/249).
* Fixed r2norm in :py:func:`pylops.optimization.solver.cgls`.
@@ -261,7 +302,7 @@ Version 1.13.0
* Fixed bug in data reshaping in check in
:py:class:`pylops.avo.prestack.PrestackInversion`
* Fixed loading error when using old cupy and/or cusignal
- (see `Issue #201 `_)
+ (see https://github.com/PyLops/pylops/issues/201)
Version 1.12.0
@@ -380,7 +421,7 @@ Version 1.8.0
:py:class:`pylops.waveeqprocessing.UpDownComposition3Doperator`, and
:py:class:`pylops.waveeqprocessing.PhaseShift` operators
* Fix bug in :py:class:`pylops.basicoperators.Kronecker`
- (see `Issue #125 `_)
+ (see https://github.com/PyLops/pylops/issues/125)
Version 1.7.0
diff --git a/docs/source/conf.py b/docs/source/conf.py
index caf745e5..c5e6536d 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -21,6 +21,7 @@
"numpydoc",
"nbsphinx",
"sphinx_gallery.gen_gallery",
+ "sphinxemoji.sphinxemoji",
# 'sphinx.ext.napoleon',
]
@@ -29,6 +30,8 @@
"python": ("https://docs.python.org/3/", None),
"numpy": ("https://docs.scipy.org/doc/numpy/", None),
"scipy": ("https://docs.scipy.org/doc/scipy/reference", None),
+ "cupy": ("https://docs.cupy.dev/en/stable/", None),
+ "jax": ("https://jax.readthedocs.io/en/latest", None),
"sklearn": ("http://scikit-learn.org/stable/", None),
"pandas": ("http://pandas.pydata.org/pandas-docs/stable/", None),
"matplotlib": ("https://matplotlib.org/", None),
diff --git a/docs/source/credits.rst b/docs/source/credits.rst
index 6310549d..c9c8e129 100755
--- a/docs/source/credits.rst
+++ b/docs/source/credits.rst
@@ -22,3 +22,5 @@ Contributors
* `Wei Zhang `_, ZhangWeiGeo
* `Fedor Goncharov `_, fedor-goncharov
* `Alex Rakowski `_, alex-rakowski
+* `David Sollberger `_, solldavid
+* `Gustavo Coelho `_, guaacoelho
\ No newline at end of file
diff --git a/docs/source/gpu.rst b/docs/source/gpu.rst
index 864451c4..a2d7d9f1 100755
--- a/docs/source/gpu.rst
+++ b/docs/source/gpu.rst
@@ -1,55 +1,404 @@
.. _gpu:
-GPU Support
-===========
+GPU / TPU Support
+=================
Overview
--------
-PyLops supports computations on GPUs powered by `CuPy `_ (``cupy-cudaXX>=10.6.0``).
+From ``v1.12.0``, PyLops supports computations on GPUs powered by
+`CuPy `_ (``cupy-cudaXX>=13.0.0``).
This library must be installed *before* PyLops is installed.
-.. note::
-
- Set environment variable ``CUPY_PYLOPS=0`` to force PyLops to ignore the ``cupy`` backend.
- This can be also used if a previous (or faulty) version of ``cupy`` is installed in your system,
- otherwise you will get an error when importing PyLops.
+From ``v2.3.0``, PyLops supports also computations on GPUs/TPUs powered by
+`JAX `_.
+This library must be installed *before* PyLops is installed.
+.. note::
+ Set environment variables ``CUPY_PYLOPS=0`` and/or ``JAX_PYLOPS=0`` to force PyLops to ignore
+ ``cupy`` and ``jax`` backends. This can be also used if a previous version of ``cupy``
+ or ``jax`` is installed in your system, otherwise you will get an error when importing PyLops.
Apart from a few exceptions, all operators and solvers in PyLops can
-seamlessly work with ``numpy`` arrays on CPU as well as with ``cupy`` arrays
-on GPU. Users do simply need to consistently create operators and
+seamlessly work with ``numpy`` arrays on CPU as well as with ``cupy/jax`` arrays
+on GPU. For CuPy, users simply need to consistently create operators and
provide data vectors to the solvers, e.g., when using
:class:`pylops.MatrixMult` the input matrix must be a
``cupy`` array if the data provided to a solver is also ``cupy`` array.
+For JAX, apart from following the same procedure described for CuPy, the PyLops operator must
+be also wrapped into a :class:`pylops.JaxOperator`.
-.. warning::
- Some :class:`pylops.LinearOperator` methods are currently on GPU:
+In the following, we provide a list of methods in :class:`pylops.LinearOperator` with their current status (available on CPU,
+GPU with CuPy, and GPU with JAX):
- - :meth:`pylops.LinearOperator.eigs`
- - :meth:`pylops.LinearOperator.cond`
- - :meth:`pylops.LinearOperator.tosparse`
- - :meth:`pylops.LinearOperator.estimate_spectral_norm`
+.. list-table::
+ :widths: 50 25 25 25
+ :header-rows: 1
-.. warning::
+ * - Operator/method
+ - CPU
+ - GPU with CuPy
+ - GPU/TPU with JAX
+ * - :meth:`pylops.LinearOperator.cond`
+ - |:white_check_mark:|
+ - |:red_circle:|
+ - |:red_circle:|
+ * - :meth:`pylops.LinearOperator.conj`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :meth:`pylops.LinearOperator.div`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :meth:`pylops.LinearOperator.eigs`
+ - |:white_check_mark:|
+ - |:red_circle:|
+ - |:red_circle:|
+ * - :meth:`pylops.LinearOperator.todense`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :meth:`pylops.LinearOperator.tosparse`
+ - |:white_check_mark:|
+ - |:red_circle:|
+ - |:red_circle:|
+ * - :meth:`pylops.LinearOperator.trace`
+ - |:white_check_mark:|
+ - |:red_circle:|
+ - |:red_circle:|
+
+Similarly, we provide a list of operators with their current status.
+
+Basic operators:
+
+.. list-table::
+ :widths: 50 25 25 25
+ :header-rows: 1
+
+ * - Operator/method
+ - CPU
+ - GPU with CuPy
+ - GPU/TPU with JAX
+ * - :class:`pylops.basicoperators.MatrixMult`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Identity`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Zero`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Diagonal`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :meth:`pylops.basicoperators.Transpose`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Flip`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Roll`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Pad`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Sum`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Symmetrize`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Restriction`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Regression`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.LinearRegression`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.CausalIntegration`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Spread`
+ - |:white_check_mark:|
+ - |:red_circle:|
+ - |:red_circle:|
+ * - :class:`pylops.basicoperators.VStack`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.HStack`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Block`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.BlockDiag`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+
+
+Smoothing and derivatives:
+
+.. list-table::
+ :widths: 50 25 25 25
+ :header-rows: 1
+
+ * - Operator/method
+ - CPU
+ - GPU with CuPy
+ - GPU/TPU with JAX
+ * - :class:`pylops.basicoperators.FirstDerivative`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.SecondDerivative`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Laplacian`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.Gradient`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.FirstDirectionalDerivative`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.basicoperators.SecondDirectionalDerivative`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+
+Signal processing:
+
+.. list-table::
+ :widths: 50 25 25 25
+ :header-rows: 1
+
+ * - Operator/method
+ - CPU
+ - GPU with CuPy
+ - GPU/TPU with JAX
+ * - :class:`pylops.signalprocessing.Convolve1D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:warning:|
+ * - :class:`pylops.signalprocessing.Convolve2D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.signalprocessing.ConvolveND`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.signalprocessing.NonStationaryConvolve1D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.signalprocessing.NonStationaryFilters1D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.signalprocessing.NonStationaryConvolve2D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.NonStationaryFilters2D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.Interp`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.signalprocessing.Bilinear`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.FFT`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.signalprocessing.FFT2D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.signalprocessing.FFTND`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.signalprocessing.Shift`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.signalprocessing.DWT`
+ - |:white_check_mark:|
+ - |:red_circle:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.DWT2D`
+ - |:white_check_mark:|
+ - |:red_circle:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.DCT`
+ - |:white_check_mark:|
+ - |:red_circle:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.Seislet`
+ - |:white_check_mark:|
+ - |:red_circle:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.Radon2D`
+ - |:white_check_mark:|
+ - |:red_circle:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.Radon3D`
+ - |:white_check_mark:|
+ - |:red_circle:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.ChirpRadon2D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.ChirpRadon3D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.Sliding1D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.Sliding2D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.Sliding3D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.Patch2D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.Patch3D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:red_circle:|
+ * - :class:`pylops.signalprocessing.Fredholm1`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
- Some operators are currently not available on GPU:
+Wave-Equation processing
- - :class:`pylops.Spread`
- - :class:`pylops.signalprocessing.Radon2D`
- - :class:`pylops.signalprocessing.Radon3D`
- - :class:`pylops.signalprocessing.DWT`
- - :class:`pylops.signalprocessing.DWT2D`
- - :class:`pylops.signalprocessing.Seislet`
- - :class:`pylops.waveeqprocessing.Demigration`
- - :class:`pylops.waveeqprocessing.LSM`
+.. list-table::
+ :widths: 50 25 25 25
+ :header-rows: 1
+
+ * - Operator/method
+ - CPU
+ - GPU with CuPy
+ - GPU/TPU with JAX
+ * - :class:`pylops.avo.avo.PressureToVelocity`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.avo.avo.UpDownComposition2D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.avo.avo.UpDownComposition3D`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.avo.avo.BlendingContinuous`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.avo.avo.BlendingGroup`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.avo.avo.BlendingHalf`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.avo.avo.MDC`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.avo.avo.Kirchhoff`
+ - |:white_check_mark:|
+ - |:red_circle:|
+ - |:red_circle:|
+ * - :class:`pylops.avo.avo.AcousticWave2D`
+ - |:white_check_mark:|
+ - |:red_circle:|
+ - |:red_circle:|
+
+Geophysical subsurface characterization:
+
+.. list-table::
+ :widths: 50 25 25 25
+ :header-rows: 1
+
+ * - Operator/method
+ - CPU
+ - GPU with CuPy
+ - GPU/TPU with JAX
+ * - :class:`pylops.avo.avo.AVOLinearModelling`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.avo.poststack.PoststackLinearModelling`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ * - :class:`pylops.avo.prestack.PrestackLinearModelling`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:warning:|
+ * - :class:`pylops.avo.prestack.PrestackWaveletModelling`
+ - |:white_check_mark:|
+ - |:white_check_mark:|
+ - |:warning:|
.. warning::
- Some solvers are currently not available on GPU:
- - :class:`pylops.optimization.sparsity.SPGL1`
+ 1. The JAX backend of the :class:`pylops.signalprocessing.Convolve1D` operator
+ currently works only with 1d-arrays due to a different behaviour of
+ :meth:`scipy.signal.convolve` and :meth:`jax.scipy.signal.convolve` with
+ nd-arrays.
+
+ 2. The JAX backend of the :class:`pylops.avo.prestack.PrestackLinearModelling`
+ operator currently works only with ``explicit=True`` due to the same issue as
+ in point 1 for the :class:`pylops.signalprocessing.Convolve1D` operator employed
+ when ``explicit=False``.
Example
@@ -68,8 +417,7 @@ Finally, let's briefly look at an example. First we write a code snippet using
y = Gop * x
xest = Gop / y
-
-Now we write a code snippet using ``cupy`` arrays which PyLops will run on
+Now we write a code snippet using ``cupy`` arrays which PyLops will run on
your GPU:
.. code-block:: python
@@ -83,9 +431,28 @@ your GPU:
xest = Gop / y
The code is almost unchanged apart from the fact that we now use ``cupy`` arrays,
-PyLops will figure this out!
+PyLops will figure this out.
+
+Similarly, we write a code snippet using ``jax`` arrays which PyLops will run on
+your GPU/TPU:
+
+.. code-block:: python
+
+ ny, nx = 400, 400
+ G = jnp.array(np.random.normal(0, 1, (ny, nx)).astype(np.float32))
+ x = jnp.ones(nx, dtype=np.float32)
+
+ Gop = JaxOperator(MatrixMult(G, dtype='float32'))
+ y = Gop * x
+ xest = Gop / y
+
+ # Adjoint via AD
+ xadj = Gop.rmatvecad(x, y)
+
+
+Again, the code is almost unchanged apart from the fact that we now use ``jax`` arrays,
.. note::
- The CuPy backend is in active development, with many examples not yet in the docs.
- You can find many `other examples `_ from the `PyLops Notebooks repository `_.
+ More examples for the CuPy and JAX backends be found `here `_
+ and `here `_.
\ No newline at end of file
diff --git a/docs/source/installation.rst b/docs/source/installation.rst
index a9c2d52a..094f09bd 100755
--- a/docs/source/installation.rst
+++ b/docs/source/installation.rst
@@ -9,7 +9,7 @@ The PyLops project strives to create a library that is easy to install in
any environment and has a very limited number of dependencies.
Required dependencies are limited to:
-* Python 3.8 or greater
+* Python 3.9 or greater
* `NumPy `_
* `SciPy `_
@@ -321,6 +321,11 @@ In alphabetic order:
dtcwt
-----
+
+.. warning::
+
+ ``dtcwt`` is not yet supported with Numpy 2.
+
`dtcwt `_ is a library used to implement the DT-CWT operators.
Install it via ``pip`` with:
@@ -330,6 +335,7 @@ Install it via ``pip`` with:
>> pip install dtcwt
+
Devito
------
`Devito `_ is a library used to solve PDEs via
@@ -529,4 +535,14 @@ CuPy
for GPU-accelerated computations. Since many different versions of CuPy exist (based on the
CUDA drivers of the GPU), users must install CuPy prior to installing
PyLops. To do so, follow their
-`installation instructions `__.
\ No newline at end of file
+`installation instructions `__.
+
+
+JAX
+---
+`JAX `_ is another library that can be used as a drop-in replacement
+to NumPy and some parts of SciPy. It provides seamless support for multiple accelerators (e.g., GPUs, TPUs),
+Just-In-Time (JIT) compilation via Open XLA, and Automatic Differentiation. Similar to CuPy, since many
+different versions of JAX exist (based on the CUDA drivers of the GPU), users must install JAX prior
+to installing PyLops. To do so, follow their
+`installation instructions `__.
\ No newline at end of file
diff --git a/environment-dev-arm.yml b/environment-dev-arm.yml
index 439a34e0..c711fe76 100755
--- a/environment-dev-arm.yml
+++ b/environment-dev-arm.yml
@@ -8,7 +8,7 @@ dependencies:
- python>=3.6.4
- pip
- numpy>=1.21.0
- - scipy>=1.4.0
+ - scipy>=1.11.0
- pytorch>=1.2.0
- cpuonly
- jax
@@ -36,6 +36,7 @@ dependencies:
- pydata-sphinx-theme
- sphinx-gallery
- nbsphinx
+ - sphinxemoji
- image
- flake8
- mypy
diff --git a/environment-dev.yml b/environment-dev.yml
index f8161474..135319f7 100755
--- a/environment-dev.yml
+++ b/environment-dev.yml
@@ -8,7 +8,7 @@ dependencies:
- python>=3.6.4
- pip
- numpy>=1.21.0
- - scipy>=1.4.0
+ - scipy>=1.11.0
- pytorch>=1.2.0
- cpuonly
- jax
diff --git a/environment.yml b/environment.yml
index 31f5c88a..e09650de 100755
--- a/environment.yml
+++ b/environment.yml
@@ -4,4 +4,4 @@ channels:
dependencies:
- python>=3.6.4
- numpy>=1.21.0
- - scipy>=1.4.0
+ - scipy>=1.14.0
diff --git a/examples/plot_twoway.py b/examples/plot_twoway.py
new file mode 100644
index 00000000..ac8b146a
--- /dev/null
+++ b/examples/plot_twoway.py
@@ -0,0 +1,179 @@
+r"""
+Acoustic Wave Equation modelling
+================================
+
+This example shows how to perform acoustic wave equation modelling
+using the :class:`pylops.waveeqprocessing.AcousticWave2D` operator,
+which brings the power of finite-difference modelling via the Devito
+modelling engine to PyLops.
+"""
+import matplotlib.pyplot as plt
+import numpy as np
+from scipy.ndimage import gaussian_filter
+
+import pylops
+
+plt.close("all")
+np.random.seed(0)
+
+
+###############################################################################
+# To begin with, we will create a simple layered velocity model. We will also
+# define a background velocity model by smoothing the original velocity model
+# which will be responsible of the kinematic of the wavefield modelled via
+# Born modelling, and the perturbation velocity model which will lead to
+# scattering effects and therefore guide the dynamic of the modelled wavefield.
+
+# Velocity Model
+nx, nz = 61, 40
+dx, dz = 4, 4
+x, z = np.arange(nx) * dx, np.arange(nz) * dz
+vel = 1000 * np.ones((nx, nz))
+vel[:, 15:] = 1200
+vel[:, 35:] = 1600
+
+# Smooth velocity model
+v0 = gaussian_filter(vel, sigma=10)
+
+# Born perturbation from m - m0
+dv = vel ** (-2) - v0 ** (-2)
+
+# Receivers
+nr = 101
+rx = np.linspace(0, x[-1], nr)
+rz = 20 * np.ones(nr)
+recs = np.vstack((rx, rz))
+dr = recs[0, 1] - recs[0, 0]
+
+# Sources
+ns = 3
+sx = np.linspace(0, x[-1], ns)
+sz = 10 * np.ones(ns)
+sources = np.vstack((sx, sz))
+
+plt.figure(figsize=(10, 5))
+im = plt.imshow(vel.T, cmap="summer", extent=(x[0], x[-1], z[-1], z[0]))
+plt.scatter(recs[0], recs[1], marker="v", s=150, c="b", edgecolors="k")
+plt.scatter(sources[0], sources[1], marker="*", s=150, c="r", edgecolors="k")
+cb = plt.colorbar(im)
+cb.set_label("[m/s]")
+plt.axis("tight")
+plt.xlabel("x [m]"), plt.ylabel("z [m]")
+plt.title("Velocity")
+plt.xlim(x[0], x[-1])
+plt.tight_layout()
+
+plt.figure(figsize=(10, 5))
+im = plt.imshow(dv.T, cmap="seismic", extent=(x[0], x[-1], z[-1], z[0]))
+plt.scatter(recs[0], recs[1], marker="v", s=150, c="b", edgecolors="k")
+plt.scatter(sources[0], sources[1], marker="*", s=150, c="r", edgecolors="k")
+cb = plt.colorbar(im)
+cb.set_label("[m/s]")
+plt.axis("tight")
+plt.xlabel("x [m]"), plt.ylabel("z [m]")
+plt.title("Velocity perturbation")
+plt.xlim(x[0], x[-1])
+plt.tight_layout()
+
+###############################################################################
+# Let us now define the Born modelling operator
+
+Aop = pylops.waveeqprocessing.AcousticWave2D(
+ (nx, nz),
+ (0, 0),
+ (dx, dz),
+ v0,
+ sources[0],
+ sources[1],
+ recs[0],
+ recs[1],
+ 0.0,
+ 0.5 * 1e3,
+ "Ricker",
+ space_order=4,
+ nbl=100,
+ f0=15,
+ dtype="float32",
+)
+
+###############################################################################
+# And we use it to model our data
+
+dobs = Aop @ dv
+
+fig, axs = plt.subplots(1, 3, sharey=True, figsize=(10, 6))
+fig.suptitle("FD modelling with Ricker", y=0.99)
+
+for isrc in range(ns):
+ axs[isrc].imshow(
+ dobs[isrc].reshape(Aop.geometry.nrec, Aop.geometry.nt).T,
+ cmap="gray",
+ vmin=-1e-7,
+ vmax=1e-7,
+ extent=(
+ recs[0, 0],
+ recs[0, -1],
+ Aop.geometry.time_axis.time_values[-1] * 1e-3,
+ 0,
+ ),
+ )
+ axs[isrc].axis("tight")
+ axs[isrc].set_xlabel("rec [m]")
+axs[0].set_ylabel("t [s]")
+fig.tight_layout()
+
+###############################################################################
+# Finally, we are going to show how despite the
+# :class:`pylops.waveeqprocessing.AcousticWave2D` operator allows a user to
+# specify a limited number of source wavelets (this is directly borrowed from
+# Devito), a simple modification can be applied to pass any user defined wavelet.
+# We are going to do that with a Ormsby wavelet
+
+# Extract Ricker wavelet
+wav = Aop.geometry.src.data[:, 0]
+wavc = np.argmax(wav)
+
+# Define Ormsby wavelet
+wavest = pylops.utils.wavelets.ormsby(
+ Aop.geometry.time_axis.time_values[:wavc] * 1e-3, f=[3, 20, 30, 45]
+)[0]
+
+# Update wavelet in operator and model new data
+Aop.updatesrc(wavest)
+
+dobs1 = Aop @ dv
+
+fig, axs = plt.subplots(1, 3, sharey=True, figsize=(10, 6))
+fig.suptitle("FD modelling with Ormsby", y=0.99)
+
+for isrc in range(ns):
+ axs[isrc].imshow(
+ dobs1[isrc].reshape(Aop.geometry.nrec, Aop.geometry.nt).T,
+ cmap="gray",
+ vmin=-1e-7,
+ vmax=1e-7,
+ extent=(
+ recs[0, 0],
+ recs[0, -1],
+ Aop.geometry.time_axis.time_values[-1] * 1e-3,
+ 0,
+ ),
+ )
+ axs[isrc].axis("tight")
+ axs[isrc].set_xlabel("rec [m]")
+axs[0].set_ylabel("t [s]")
+fig.tight_layout()
+
+fig, axs = plt.subplots(1, 2, figsize=(10, 3))
+axs[0].plot(wav[: 2 * wavc], "k")
+axs[0].plot(wavest, "r")
+axs[1].plot(
+ dobs[isrc].reshape(Aop.geometry.nrec, Aop.geometry.nt)[nr // 2], "k", label="Ricker"
+)
+axs[1].plot(
+ dobs1[isrc].reshape(Aop.geometry.nrec, Aop.geometry.nt)[nr // 2],
+ "r",
+ label="Ormsby",
+)
+axs[1].legend()
+fig.tight_layout()
diff --git a/examples/plot_wavelet.py b/examples/plot_wavelet.py
index d080b025..4c410112 100644
--- a/examples/plot_wavelet.py
+++ b/examples/plot_wavelet.py
@@ -1,8 +1,9 @@
"""
Wavelet transform
=================
-This example shows how to use the :py:class:`pylops.DWT` and
-:py:class:`pylops.DWT2D` operators to perform 1- and 2-dimensional DWT.
+This example shows how to use the :py:class:`pylops.DWT`,
+:py:class:`pylops.DWT2D`, and :py:class:`pylops.DWTND` operators
+to perform 1-, 2-, and N-dimensional DWT.
"""
import matplotlib.pyplot as plt
import numpy as np
@@ -67,3 +68,46 @@
axs[1, 1].set_title("DWT2 coefficients (zeroed)")
axs[1, 1].axis("tight")
plt.tight_layout()
+
+###############################################################################
+# Let us now try the same with a 3D volumetric model, where we use the
+# N-dimensional DWT. This time, we only retain 10 percent of the coefficients
+# of the DWT.
+
+nx = 128
+ny = 256
+nz = 128
+
+x = np.arange(nx)
+y = np.arange(ny)
+z = np.arange(nz)
+
+xx, yy, zz = np.meshgrid(x, y, z, indexing="ij")
+# Generate a 3D model with two block anomalies
+m = np.ones_like(xx, dtype=float)
+block1 = (xx > 10) & (xx < 60) & (yy > 100) & (yy < 150) & (zz > 20) & (zz < 70)
+block2 = (xx > 70) & (xx < 80) & (yy > 100) & (yy < 200) & (zz > 10) & (zz < 50)
+m[block1] = 1.2
+m[block2] = 0.8
+Wop = pylops.signalprocessing.DWTND((nx, ny, nz), wavelet="haar", level=3)
+y = Wop * m
+
+ratio = 0.1
+yf = y.copy()
+yf.flat[int(ratio * y.size) :] = 0
+iminv = Wop.H * yf
+
+fig, axs = plt.subplots(2, 2, figsize=(6, 6))
+axs[0, 0].imshow(m[:, :, 30], cmap="gray")
+axs[0, 0].set_title("Model (Slice at z=30)")
+axs[0, 0].axis("tight")
+axs[0, 1].imshow(y[:, :, 90], cmap="gray_r")
+axs[0, 1].set_title("DWTNT coefficients")
+axs[0, 1].axis("tight")
+axs[1, 0].imshow(iminv[:, :, 30], cmap="gray")
+axs[1, 0].set_title("Reconstructed model (Slice at z=30)")
+axs[1, 0].axis("tight")
+axs[1, 1].imshow(yf[:, :, 90], cmap="gray_r")
+axs[1, 1].set_title("DWTNT coefficients (zeroed)")
+axs[1, 1].axis("tight")
+plt.tight_layout()
diff --git a/pylops/__init__.py b/pylops/__init__.py
index 55d4ce3d..7672fda4 100755
--- a/pylops/__init__.py
+++ b/pylops/__init__.py
@@ -48,6 +48,7 @@
from .config import *
from .linearoperator import *
from .torchoperator import *
+from .jaxoperator import *
from .basicoperators import *
from . import (
avo,
diff --git a/pylops/avo/poststack.py b/pylops/avo/poststack.py
index 7e514707..8a9001ac 100644
--- a/pylops/avo/poststack.py
+++ b/pylops/avo/poststack.py
@@ -27,6 +27,7 @@
get_csc_matrix,
get_lstsq,
get_module_name,
+ inplace_set,
)
from pylops.utils.signalprocessing import convmtx, nonstationary_convmtx
from pylops.utils.typing import NDArray, ShapeLike
@@ -93,12 +94,13 @@ def _PoststackLinearModelling(
D = ncp.diag(0.5 * ncp.ones(nt0 - 1, dtype=dtype), k=1) - ncp.diag(
0.5 * ncp.ones(nt0 - 1, dtype=dtype), -1
)
- D[0] = D[-1] = 0
+ D = inplace_set(ncp.array(0.0), D, 0)
+ D = inplace_set(ncp.array(0.0), D, -1)
else:
D = ncp.diag(ncp.ones(nt0 - 1, dtype=dtype), k=1) - ncp.diag(
ncp.ones(nt0, dtype=dtype), k=0
)
- D[-1] = 0
+ D = inplace_set(ncp.array(0.0), D, -1)
# Create wavelet operator
if len(wav.shape) == 1:
diff --git a/pylops/avo/prestack.py b/pylops/avo/prestack.py
index 8630bc9a..4cf6c4eb 100644
--- a/pylops/avo/prestack.py
+++ b/pylops/avo/prestack.py
@@ -31,6 +31,7 @@
get_block_diag,
get_lstsq,
get_module_name,
+ inplace_set,
)
from pylops.utils.signalprocessing import convmtx
from pylops.utils.typing import NDArray, ShapeLike
@@ -182,12 +183,13 @@ def PrestackLinearModelling(
D = ncp.diag(0.5 * ncp.ones(nt0 - 1, dtype=dtype), k=1) - ncp.diag(
0.5 * ncp.ones(nt0 - 1, dtype=dtype), k=-1
)
- D[0] = D[-1] = 0
+ D = inplace_set(ncp.array(0.0), D, 0)
+ D = inplace_set(ncp.array(0.0), D, -1)
else:
D = ncp.diag(ncp.ones(nt0 - 1, dtype=dtype), k=1) - ncp.diag(
ncp.ones(nt0, dtype=dtype), k=0
)
- D[-1] = 0
+ D = inplace_set(ncp.array(0.0), D, -1)
D = get_block_diag(theta)(*([D] * nG))
# Create wavelet operator
@@ -339,7 +341,8 @@ def PrestackWaveletModelling(
D = ncp.diag(0.5 * np.ones(nt0 - 1, dtype=dtype), k=1) - ncp.diag(
0.5 * np.ones(nt0 - 1, dtype=dtype), k=-1
)
- D[0] = D[-1] = 0
+ D = inplace_set(ncp.array(0.0), D, 0)
+ D = inplace_set(ncp.array(0.0), D, -1)
D = get_block_diag(theta)(*([D] * nG))
# Create infinite-reflectivity data
diff --git a/pylops/basicoperators/blockdiag.py b/pylops/basicoperators/blockdiag.py
index e13ed026..166ae137 100644
--- a/pylops/basicoperators/blockdiag.py
+++ b/pylops/basicoperators/blockdiag.py
@@ -21,7 +21,7 @@
from pylops import LinearOperator
from pylops.basicoperators import MatrixMult
-from pylops.utils.backend import get_array_module
+from pylops.utils.backend import get_array_module, inplace_set
from pylops.utils.typing import DTypeLike, NDArray
@@ -175,18 +175,22 @@ def _matvec_serial(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(self.nops, dtype=self.dtype)
for iop, oper in enumerate(self.ops):
- y[self.nnops[iop] : self.nnops[iop + 1]] = oper.matvec(
- x[self.mmops[iop] : self.mmops[iop + 1]]
- ).squeeze()
+ y = inplace_set(
+ oper.matvec(x[self.mmops[iop] : self.mmops[iop + 1]]).squeeze(),
+ y,
+ slice(self.nnops[iop], self.nnops[iop + 1]),
+ )
return y
def _rmatvec_serial(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(self.mops, dtype=self.dtype)
for iop, oper in enumerate(self.ops):
- y[self.mmops[iop] : self.mmops[iop + 1]] = oper.rmatvec(
- x[self.nnops[iop] : self.nnops[iop + 1]]
- ).squeeze()
+ y = inplace_set(
+ oper.rmatvec(x[self.nnops[iop] : self.nnops[iop + 1]]).squeeze(),
+ y,
+ slice(self.mmops[iop], self.mmops[iop + 1]),
+ )
return y
def _matvec_multiproc(self, x: NDArray) -> NDArray:
diff --git a/pylops/basicoperators/firstderivative.py b/pylops/basicoperators/firstderivative.py
index 58edf17f..f8bd208e 100644
--- a/pylops/basicoperators/firstderivative.py
+++ b/pylops/basicoperators/firstderivative.py
@@ -7,7 +7,7 @@
from pylops import LinearOperator
from pylops.utils._internal import _value_or_sized_to_tuple
-from pylops.utils.backend import get_array_module
+from pylops.utils.backend import get_array_module, inplace_add, inplace_set
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray
@@ -100,6 +100,16 @@ def __init__(
self.kind = kind
self.edge = edge
self.order = order
+ self.slice = {
+ i: {
+ j: tuple([slice(None, None)] * (len(dims) - 1) + [slice(i, j)])
+ for j in (None, -1, -2, -3, -4)
+ }
+ for i in (None, 1, 2, 3, 4)
+ }
+ self.sample = {
+ i: tuple([slice(None, None)] * (len(dims) - 1) + [i]) for i in range(-3, 4)
+ }
self._register_multiplications(self.kind, self.order)
def _register_multiplications(
@@ -140,15 +150,20 @@ def _rmatvec(self, x: NDArray) -> NDArray:
def _matvec_forward(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., :-1] = (x[..., 1:] - x[..., :-1]) / self.sampling
+ # y[..., :-1] = (x[..., 1:] - x[..., :-1]) / self.sampling
+ y = inplace_set(
+ (x[..., 1:] - x[..., :-1]) / self.sampling, y, self.slice[None][-1]
+ )
return y
@reshaped(swapaxis=True)
def _rmatvec_forward(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., :-1] -= x[..., :-1]
- y[..., 1:] += x[..., :-1]
+ # y[..., :-1] -= x[..., :-1]
+ y = inplace_add(-x[..., :-1], y, self.slice[None][-1])
+ # y[..., 1:] += x[..., :-1]
+ y = inplace_add(x[..., :-1], y, self.slice[1][None])
y /= self.sampling
return y
@@ -156,10 +171,13 @@ def _rmatvec_forward(self, x: NDArray) -> NDArray:
def _matvec_centered3(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., 1:-1] = 0.5 * (x[..., 2:] - x[..., :-2])
+ # y[..., 1:-1] = 0.5 * (x[..., 2:] - x[..., :-2])
+ y = inplace_set(0.5 * (x[..., 2:] - x[..., :-2]), y, self.slice[1][-1])
if self.edge:
- y[..., 0] = x[..., 1] - x[..., 0]
- y[..., -1] = x[..., -1] - x[..., -2]
+ # y[..., 0] = x[..., 1] - x[..., 0]
+ y = inplace_set(x[..., 1] - x[..., 0], y, self.sample[0])
+ # y[..., -1] = x[..., -1] - x[..., -2]
+ y = inplace_set(x[..., -1] - x[..., -2], y, self.sample[-1])
y /= self.sampling
return y
@@ -167,13 +185,19 @@ def _matvec_centered3(self, x: NDArray) -> NDArray:
def _rmatvec_centered3(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., :-2] -= 0.5 * x[..., 1:-1]
- y[..., 2:] += 0.5 * x[..., 1:-1]
+ # y[..., :-2] -= 0.5 * x[..., 1:-1]
+ y = inplace_add(-0.5 * x[..., 1:-1], y, self.slice[None][-2])
+ # y[..., 2:] += 0.5 * x[..., 1:-1]
+ y = inplace_add(0.5 * x[..., 1:-1], y, self.slice[2][None])
if self.edge:
- y[..., 0] -= x[..., 0]
- y[..., 1] += x[..., 0]
- y[..., -2] -= x[..., -1]
- y[..., -1] += x[..., -1]
+ # y[..., 0] -= x[..., 0]
+ y = inplace_add(-x[..., 0], y, self.sample[0])
+ # y[..., 1] += x[..., 0]
+ y = inplace_add(x[..., 0], y, self.sample[1])
+ # y[..., -2] -= x[..., -1]
+ y = inplace_add(-x[..., -1], y, self.sample[-2])
+ # y[..., -1] += x[..., -1]
+ y = inplace_add(x[..., -1], y, self.sample[-1])
y /= self.sampling
return y
@@ -181,17 +205,31 @@ def _rmatvec_centered3(self, x: NDArray) -> NDArray:
def _matvec_centered5(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., 2:-2] = (
- x[..., :-4] / 12.0
- - 2 * x[..., 1:-3] / 3.0
- + 2 * x[..., 3:-1] / 3.0
- - x[..., 4:] / 12.0
+ # y[..., 2:-2] = (
+ # x[..., :-4] / 12.0
+ # - 2 * x[..., 1:-3] / 3.0
+ # + 2 * x[..., 3:-1] / 3.0
+ # - x[..., 4:] / 12.0
+ # )
+ y = inplace_set(
+ (
+ x[..., :-4] / 12.0
+ - 2 * x[..., 1:-3] / 3.0
+ + 2 * x[..., 3:-1] / 3.0
+ - x[..., 4:] / 12.0
+ ),
+ y,
+ self.slice[2][-2],
)
if self.edge:
- y[..., 0] = x[..., 1] - x[..., 0]
- y[..., 1] = 0.5 * (x[..., 2] - x[..., 0])
- y[..., -2] = 0.5 * (x[..., -1] - x[..., -3])
- y[..., -1] = x[..., -1] - x[..., -2]
+ # y[..., 0] = x[..., 1] - x[..., 0]
+ y = inplace_set(x[..., 1] - x[..., 0], y, self.sample[0])
+ # y[..., 1] = 0.5 * (x[..., 2] - x[..., 0])
+ y = inplace_set(0.5 * (x[..., 2] - x[..., 0]), y, self.sample[1])
+ # y[..., -2] = 0.5 * (x[..., -1] - x[..., -3])
+ y = inplace_set(0.5 * (x[..., -1] - x[..., -3]), y, self.sample[-2])
+ # y[..., -1] = x[..., -1] - x[..., -2]
+ y = inplace_set(x[..., -1] - x[..., -2], y, self.sample[-1])
y /= self.sampling
return y
@@ -199,17 +237,27 @@ def _matvec_centered5(self, x: NDArray) -> NDArray:
def _rmatvec_centered5(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., :-4] += x[..., 2:-2] / 12.0
- y[..., 1:-3] -= 2.0 * x[..., 2:-2] / 3.0
- y[..., 3:-1] += 2.0 * x[..., 2:-2] / 3.0
- y[..., 4:] -= x[..., 2:-2] / 12.0
+ # y[..., :-4] += x[..., 2:-2] / 12.0
+ y = inplace_add(x[..., 2:-2] / 12.0, y, self.slice[None][-4])
+ # y[..., 1:-3] -= 2.0 * x[..., 2:-2] / 3.0
+ y = inplace_add(-2.0 * x[..., 2:-2] / 3.0, y, self.slice[1][-3])
+ # y[..., 3:-1] += 2.0 * x[..., 2:-2] / 3.0
+ y = inplace_add(2.0 * x[..., 2:-2] / 3.0, y, self.slice[3][-1])
+ # y[..., 4:] -= x[..., 2:-2] / 12.0
+ y = inplace_add(-x[..., 2:-2] / 12.0, y, self.slice[4][None])
if self.edge:
- y[..., 0] -= x[..., 0] + 0.5 * x[..., 1]
- y[..., 1] += x[..., 0]
- y[..., 2] += 0.5 * x[..., 1]
- y[..., -3] -= 0.5 * x[..., -2]
- y[..., -2] -= x[..., -1]
- y[..., -1] += 0.5 * x[..., -2] + x[..., -1]
+ # y[..., 0] -= x[..., 0] + 0.5 * x[..., 1]
+ y = inplace_add(-(x[..., 0] + 0.5 * x[..., 1]), y, self.sample[0])
+ # y[..., 1] += x[..., 0]
+ y = inplace_add(x[..., 0], y, self.sample[1])
+ # y[..., 2] += 0.5 * x[..., 1]
+ y = inplace_add(0.5 * x[..., 1], y, self.sample[2])
+ # y[..., -3] -= 0.5 * x[..., -2]
+ y = inplace_add(-0.5 * x[..., -2], y, self.sample[-3])
+ # y[..., -2] -= x[..., -1]
+ y = inplace_add(-x[..., -1], y, self.sample[-2])
+ # y[..., -1] += 0.5 * x[..., -2] + x[..., -1]
+ y = inplace_add(0.5 * x[..., -2] + x[..., -1], y, self.sample[-1])
y /= self.sampling
return y
@@ -217,14 +265,19 @@ def _rmatvec_centered5(self, x: NDArray) -> NDArray:
def _matvec_backward(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., 1:] = (x[..., 1:] - x[..., :-1]) / self.sampling
+ # y[..., 1:] = (x[..., 1:] - x[..., :-1]) / self.sampling
+ y = inplace_set(
+ (x[..., 1:] - x[..., :-1]) / self.sampling, y, self.slice[1][None]
+ )
return y
@reshaped(swapaxis=True)
def _rmatvec_backward(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., :-1] -= x[..., 1:]
- y[..., 1:] += x[..., 1:]
+ # y[..., :-1] -= x[..., 1:]
+ y = inplace_add(-x[..., 1:], y, self.slice[None][-1])
+ # y[..., 1:] += x[..., 1:]
+ y = inplace_add(x[..., 1:], y, self.slice[1][None])
y /= self.sampling
return y
diff --git a/pylops/basicoperators/hstack.py b/pylops/basicoperators/hstack.py
index 5cfbbec0..b71e8723 100644
--- a/pylops/basicoperators/hstack.py
+++ b/pylops/basicoperators/hstack.py
@@ -21,7 +21,7 @@
from pylops import LinearOperator
from pylops.basicoperators import MatrixMult
-from pylops.utils.backend import get_array_module
+from pylops.utils.backend import get_array_module, inplace_add, inplace_set
from pylops.utils.typing import NDArray
@@ -165,14 +165,22 @@ def _matvec_serial(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(self.nops, dtype=self.dtype)
for iop, oper in enumerate(self.ops):
- y += oper.matvec(x[self.mmops[iop] : self.mmops[iop + 1]]).squeeze()
+ y = inplace_add(
+ oper.matvec(x[self.mmops[iop] : self.mmops[iop + 1]]).squeeze(),
+ y,
+ slice(None, None),
+ )
return y
def _rmatvec_serial(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(self.mops, dtype=self.dtype)
for iop, oper in enumerate(self.ops):
- y[self.mmops[iop] : self.mmops[iop + 1]] = oper.rmatvec(x).squeeze()
+ y = inplace_set(
+ oper.rmatvec(x).squeeze(),
+ y,
+ slice(self.mmops[iop], self.mmops[iop + 1]),
+ )
return y
def _matvec_multiproc(self, x: NDArray) -> NDArray:
diff --git a/pylops/basicoperators/identity.py b/pylops/basicoperators/identity.py
index c2d05a30..50b76831 100644
--- a/pylops/basicoperators/identity.py
+++ b/pylops/basicoperators/identity.py
@@ -6,7 +6,7 @@
import numpy as np
from pylops import LinearOperator
-from pylops.utils.backend import get_array_module
+from pylops.utils.backend import get_array_module, inplace_set
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray
@@ -181,7 +181,7 @@ def _matvec(self, x: NDArray) -> NDArray:
y = x[self.sliceN]
else:
y = ncp.zeros(self.dimsd, dtype=self.dtype)
- y[self.sliceM] = x
+ y = inplace_set(x, y, self.sliceM)
return y
@reshaped
@@ -193,7 +193,7 @@ def _rmatvec(self, x: NDArray) -> NDArray:
y = x
elif self.mode == "model":
y = ncp.zeros(self.dims, dtype=self.dtype)
- y[self.sliceN] = x
+ y = inplace_set(x, y, self.sliceN)
else:
y = x[self.sliceM]
return y
diff --git a/pylops/basicoperators/pad.py b/pylops/basicoperators/pad.py
index 45b63af8..d98a894c 100644
--- a/pylops/basicoperators/pad.py
+++ b/pylops/basicoperators/pad.py
@@ -6,6 +6,7 @@
from pylops import LinearOperator
from pylops.utils._internal import _value_or_sized_to_tuple
+from pylops.utils.backend import get_array_module
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray
@@ -85,10 +86,12 @@ def __init__(
@reshaped
def _matvec(self, x: NDArray) -> NDArray:
- return np.pad(x, self.pad, mode="constant")
+ ncp = get_array_module(x)
+ return ncp.pad(x, self.pad, mode="constant")
@reshaped
def _rmatvec(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
for ax, (before, _) in enumerate(self.pad):
- x = np.take(x, np.arange(before, before + self.dims[ax]), axis=ax)
+ x = ncp.take(x, ncp.arange(before, before + self.dims[ax]), axis=ax)
return x
diff --git a/pylops/basicoperators/restriction.py b/pylops/basicoperators/restriction.py
index c2e51a31..1a745b30 100644
--- a/pylops/basicoperators/restriction.py
+++ b/pylops/basicoperators/restriction.py
@@ -1,16 +1,22 @@
__all__ = ["Restriction"]
import logging
-
from typing import Sequence, Union
import numpy as np
import numpy.ma as np_ma
-from numpy.core.multiarray import normalize_axis_index
+
+# need to check numpy version since normalize_axis_index will be
+# soon moved from numpy.core.multiarray to from numpy.lib.array_utils
+np_version = np.__version__.split(".")
+if int(np_version[0]) < 2:
+ from numpy.core.multiarray import normalize_axis_index
+else:
+ from numpy.lib.array_utils import normalize_axis_index
from pylops import LinearOperator
from pylops.utils._internal import _value_or_sized_to_tuple
-from pylops.utils.backend import get_array_module, to_cupy_conditional
+from pylops.utils.backend import get_array_module, inplace_set, to_cupy_conditional
from pylops.utils.typing import DTypeLike, InputDimsLike, IntNDArray, NDArray
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING)
@@ -20,13 +26,13 @@ def _compute_iavamask(dims, axis, iava, ncp):
"""Compute restriction mask when using cupy arrays"""
otherdims = np.array(dims)
otherdims = np.delete(otherdims, axis)
- iavamask = ncp.zeros(int(dims[axis]), dtype=int)
+ iavamask = np.zeros(int(dims[axis]), dtype=int)
iavamask[iava] = 1
- iavamask = ncp.moveaxis(
- ncp.broadcast_to(iavamask, list(otherdims) + [dims[axis]]), -1, axis
+ iavamask = np.moveaxis(
+ np.broadcast_to(iavamask, list(otherdims) + [dims[axis]]), -1, axis
)
- iavamask = ncp.where(iavamask.ravel() == 1)[0]
- return iavamask
+ iavamask = np.where(iavamask.ravel() == 1)[0]
+ return ncp.asarray(iavamask)
class Restriction(LinearOperator):
@@ -128,8 +134,13 @@ def __init__(
)
forceflat = None
- super().__init__(dtype=np.dtype(dtype), dims=dims, dimsd=dimsd,
- forceflat=forceflat, name=name)
+ super().__init__(
+ dtype=np.dtype(dtype),
+ dims=dims,
+ dimsd=dimsd,
+ forceflat=forceflat,
+ name=name,
+ )
iavareshape = np.ones(len(self.dims), dtype=int)
iavareshape[axis] = len(iava)
@@ -168,7 +179,7 @@ def _rmatvec(self, x: NDArray) -> NDArray:
self.iava = to_cupy_conditional(x, self.iava)
self.iavamask = _compute_iavamask(self.dims, self.axis, self.iava, ncp)
y = ncp.zeros(int(self.shape[-1]), dtype=self.dtype)
- y[self.iavamask] = x.ravel()
+ y = inplace_set(x.ravel(), y, self.iavamask)
y = y.ravel()
return y
diff --git a/pylops/basicoperators/roll.py b/pylops/basicoperators/roll.py
index 8fc27e4d..29e6f613 100644
--- a/pylops/basicoperators/roll.py
+++ b/pylops/basicoperators/roll.py
@@ -6,6 +6,7 @@
from pylops import LinearOperator
from pylops.utils._internal import _value_or_sized_to_tuple
+from pylops.utils.backend import get_array_module
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray
@@ -64,8 +65,10 @@ def __init__(
@reshaped(swapaxis=True)
def _matvec(self, x: NDArray) -> NDArray:
- return np.roll(x, shift=self.shift, axis=-1)
+ ncp = get_array_module(x)
+ return ncp.roll(x, shift=self.shift, axis=-1)
@reshaped(swapaxis=True)
def _rmatvec(self, x: NDArray) -> NDArray:
- return np.roll(x, shift=-self.shift, axis=-1)
+ ncp = get_array_module(x)
+ return ncp.roll(x, shift=-self.shift, axis=-1)
diff --git a/pylops/basicoperators/secondderivative.py b/pylops/basicoperators/secondderivative.py
index 744d067a..8433987d 100644
--- a/pylops/basicoperators/secondderivative.py
+++ b/pylops/basicoperators/secondderivative.py
@@ -7,7 +7,7 @@
from pylops import LinearOperator
from pylops.utils._internal import _value_or_sized_to_tuple
-from pylops.utils.backend import get_array_module
+from pylops.utils.backend import get_array_module, inplace_add, inplace_set
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray
@@ -90,6 +90,16 @@ def __init__(
self.sampling = sampling
self.kind = kind
self.edge = edge
+ self.slice = {
+ i: {
+ j: tuple([slice(None, None)] * (len(dims) - 1) + [slice(i, j)])
+ for j in (None, -1, -2, -3, -4)
+ }
+ for i in (None, 1, 2, 3, 4)
+ }
+ self.sample = {
+ i: tuple([slice(None, None)] * (len(dims) - 1) + [i]) for i in range(-3, 4)
+ }
self._register_multiplications(self.kind)
def _register_multiplications(
@@ -123,7 +133,10 @@ def _rmatvec(self, x: NDArray) -> NDArray:
def _matvec_forward(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., :-2] = x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2]
+ # y[..., :-2] = x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2]
+ y = inplace_set(
+ x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2], y, self.slice[None][-2]
+ )
y /= self.sampling**2
return y
@@ -131,9 +144,12 @@ def _matvec_forward(self, x: NDArray) -> NDArray:
def _rmatvec_forward(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., :-2] += x[..., :-2]
- y[..., 1:-1] -= 2 * x[..., :-2]
- y[..., 2:] += x[..., :-2]
+ # y[..., :-2] += x[..., :-2]
+ y = inplace_add(x[..., :-2], y, self.slice[None][-2])
+ # y[..., 1:-1] -= 2 * x[..., :-2]
+ y = inplace_add(-2 * x[..., :-2], y, self.slice[1][-1])
+ # y[..., 2:] += x[..., :-2]
+ y = inplace_add(x[..., :-2], y, self.slice[2][None])
y /= self.sampling**2
return y
@@ -141,10 +157,17 @@ def _rmatvec_forward(self, x: NDArray) -> NDArray:
def _matvec_centered(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., 1:-1] = x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2]
+ # y[..., 1:-1] = x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2]
+ y = inplace_set(
+ x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2], y, self.slice[1][-1]
+ )
if self.edge:
- y[..., 0] = x[..., 0] - 2 * x[..., 1] + x[..., 2]
- y[..., -1] = x[..., -3] - 2 * x[..., -2] + x[..., -1]
+ # y[..., 0] = x[..., 0] - 2 * x[..., 1] + x[..., 2]
+ y = inplace_set(x[..., 0] - 2 * x[..., 1] + x[..., 2], y, self.sample[0])
+ # y[..., -1] = x[..., -3] - 2 * x[..., -2] + x[..., -1]
+ y = inplace_set(
+ x[..., -3] - 2 * x[..., -2] + x[..., -1], y, self.sample[-1]
+ )
y /= self.sampling**2
return y
@@ -152,16 +175,25 @@ def _matvec_centered(self, x: NDArray) -> NDArray:
def _rmatvec_centered(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., :-2] += x[..., 1:-1]
- y[..., 1:-1] -= 2 * x[..., 1:-1]
- y[..., 2:] += x[..., 1:-1]
+ # y[..., :-2] += x[..., 1:-1]
+ y = inplace_add(x[..., 1:-1], y, self.slice[None][-2])
+ # y[..., 1:-1] -= 2 * x[..., 1:-1]
+ y = inplace_add(-2 * x[..., 1:-1], y, self.slice[1][-1])
+ # y[..., 2:] += x[..., 1:-1]
+ y = inplace_add(x[..., 1:-1], y, self.slice[2][None])
if self.edge:
- y[..., 0] += x[..., 0]
- y[..., 1] -= 2 * x[..., 0]
- y[..., 2] += x[..., 0]
- y[..., -3] += x[..., -1]
- y[..., -2] -= 2 * x[..., -1]
- y[..., -1] += x[..., -1]
+ # y[..., 0] += x[..., 0]
+ y = inplace_add(x[..., 0], y, self.sample[0])
+ # y[..., 1] -= 2 * x[..., 0]
+ y = inplace_add(-2 * x[..., 0], y, self.sample[1])
+ # y[..., 2] += x[..., 0]
+ y = inplace_add(x[..., 0], y, self.sample[2])
+ # y[..., -3] += x[..., -1]
+ y = inplace_add(x[..., -1], y, self.sample[-3])
+ # y[..., -2] -= 2 * x[..., -1]
+ y = inplace_add(-2 * x[..., -1], y, self.sample[-2])
+ # y[..., -1] += x[..., -1]
+ y = inplace_add(x[..., -1], y, self.sample[-1])
y /= self.sampling**2
return y
@@ -169,7 +201,10 @@ def _rmatvec_centered(self, x: NDArray) -> NDArray:
def _matvec_backward(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., 2:] = x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2]
+ # y[..., 2:] = x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2]
+ y = inplace_set(
+ x[..., 2:] - 2 * x[..., 1:-1] + x[..., :-2], y, self.slice[2][None]
+ )
y /= self.sampling**2
return y
@@ -177,8 +212,11 @@ def _matvec_backward(self, x: NDArray) -> NDArray:
def _rmatvec_backward(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(x.shape, self.dtype)
- y[..., :-2] += x[..., 2:]
- y[..., 1:-1] -= 2 * x[..., 2:]
- y[..., 2:] += x[..., 2:]
+ # y[..., :-2] += x[..., 2:]
+ y = inplace_add(x[..., 2:], y, self.slice[None][-2])
+ # y[..., 1:-1] -= 2 * x[..., 2:]
+ y = inplace_add(-2 * x[..., 2:], y, self.slice[1][-1])
+ # y[..., 2:] += x[..., 2:]
+ y = inplace_add(x[..., 2:], y, self.slice[2][None])
y /= self.sampling**2
return y
diff --git a/pylops/basicoperators/symmetrize.py b/pylops/basicoperators/symmetrize.py
index 47814154..41ca122b 100644
--- a/pylops/basicoperators/symmetrize.py
+++ b/pylops/basicoperators/symmetrize.py
@@ -6,7 +6,7 @@
from pylops import LinearOperator
from pylops.utils._internal import _value_or_sized_to_tuple
-from pylops.utils.backend import get_array_module
+from pylops.utils.backend import get_array_module, inplace_add, inplace_set
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray
@@ -80,6 +80,13 @@ def __init__(
self.nsym = dims[self.axis]
dimsd = list(dims)
dimsd[self.axis] = 2 * dims[self.axis] - 1
+ self.slice1 = tuple([slice(None, None)] * (len(dims) - 1) + [slice(1, None)])
+ self.slicensym_1 = tuple(
+ [slice(None, None)] * (len(dims) - 1) + [slice(self.nsym - 1, None)]
+ )
+ self.slice_nsym_1 = tuple(
+ [slice(None, None)] * (len(dims) - 1) + [slice(None, self.nsym - 1)]
+ )
super().__init__(dtype=np.dtype(dtype), dims=dims, dimsd=dimsd, name=name)
@@ -88,12 +95,12 @@ def _matvec(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(self.dimsd, dtype=self.dtype)
y = y.swapaxes(self.axis, -1)
- y[..., self.nsym - 1 :] = x
- y[..., : self.nsym - 1] = x[..., -1:0:-1]
+ y = inplace_set(x, y, self.slicensym_1)
+ y = inplace_set(x[..., -1:0:-1], y, self.slice_nsym_1)
return y
@reshaped(swapaxis=True)
def _rmatvec(self, x: NDArray) -> NDArray:
y = x[..., self.nsym - 1 :].copy()
- y[..., 1:] += x[..., self.nsym - 2 :: -1]
+ y = inplace_add(x[..., self.nsym - 2 :: -1], y, self.slice1)
return y
diff --git a/pylops/basicoperators/vstack.py b/pylops/basicoperators/vstack.py
index 812b1a7e..0d66642e 100644
--- a/pylops/basicoperators/vstack.py
+++ b/pylops/basicoperators/vstack.py
@@ -12,16 +12,16 @@
from scipy.sparse.linalg.interface import LinearOperator as spLinearOperator
from scipy.sparse.linalg.interface import _get_dtype
else:
- from scipy.sparse.linalg._interface import _get_dtype
from scipy.sparse.linalg._interface import (
LinearOperator as spLinearOperator,
)
+ from scipy.sparse.linalg._interface import _get_dtype
from typing import Callable, Optional, Sequence
from pylops import LinearOperator
from pylops.basicoperators import MatrixMult
-from pylops.utils.backend import get_array_module
+from pylops.utils.backend import get_array_module, inplace_add, inplace_set
from pylops.utils.typing import DTypeLike, NDArray
@@ -165,14 +165,20 @@ def _matvec_serial(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(self.nops, dtype=self.dtype)
for iop, oper in enumerate(self.ops):
- y[self.nnops[iop] : self.nnops[iop + 1]] = oper.matvec(x).squeeze()
+ y = inplace_set(
+ oper.matvec(x).squeeze(), y, slice(self.nnops[iop], self.nnops[iop + 1])
+ )
return y
def _rmatvec_serial(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
y = ncp.zeros(self.mops, dtype=self.dtype)
for iop, oper in enumerate(self.ops):
- y += oper.rmatvec(x[self.nnops[iop] : self.nnops[iop + 1]]).squeeze()
+ y = inplace_add(
+ oper.rmatvec(x[self.nnops[iop] : self.nnops[iop + 1]]).squeeze(),
+ y,
+ slice(None, None),
+ )
return y
def _matvec_multiproc(self, x: NDArray) -> NDArray:
diff --git a/pylops/jaxoperator.py b/pylops/jaxoperator.py
new file mode 100644
index 00000000..5d5c40ed
--- /dev/null
+++ b/pylops/jaxoperator.py
@@ -0,0 +1,104 @@
+__all__ = [
+ "JaxOperator",
+]
+
+from typing import Any, NewType
+
+from pylops import LinearOperator
+from pylops.utils import deps
+
+if deps.jax_enabled:
+ import jax
+
+ jaxarrayin_type = jax.typing.ArrayLike
+ jaxarrayout_type = jax.Array
+else:
+ jax_message = (
+ "JAX package not installed. In order to be able to use"
+ 'the jaxoperator module run "pip install jax" or'
+ '"conda install -c conda-forge jax".'
+ )
+ jaxarrayin_type = Any
+ jaxarrayout_type = Any
+
+JaxTypeIn = NewType("JaxTypeIn", jaxarrayin_type)
+JaxTypeOut = NewType("JaxTypeOut", jaxarrayout_type)
+
+
+class JaxOperator(LinearOperator):
+ """Enable JAX backend for PyLops operator.
+
+ This class can be used to wrap a pylops operator to enable the JAX
+ backend. Doing so, users can run all of the methods of a pylops
+ operator with JAX arrays. Moreover, the forward and adjoint
+ are internally just-in-time compiled, and other JAX functionalities
+ such as automatic differentiation and automatic vectorization
+ are enabled.
+
+ Parameters
+ ----------
+ Op : :obj:`pylops.LinearOperator`
+ PyLops operator
+
+ """
+
+ def __init__(self, Op: LinearOperator) -> None:
+ if not deps.jax_enabled:
+ raise NotImplementedError(jax_message)
+ super().__init__(
+ dtype=Op.dtype,
+ dims=Op.dims,
+ dimsd=Op.dimsd,
+ clinear=Op.clinear,
+ explicit=False,
+ forceflat=Op.forceflat,
+ name=Op.name,
+ )
+ self._matvec = jax.jit(Op._matvec)
+ self._rmatvec = jax.jit(Op._rmatvec)
+
+ def __call__(self, x, *args, **kwargs):
+ return self._matvec(x)
+
+ def _rmatvecad(self, x: JaxTypeIn, y: JaxTypeIn) -> JaxTypeOut:
+ _, f_vjp = jax.vjp(self._matvec, x)
+ xadj = jax.jit(f_vjp)(y)[0]
+ return xadj
+
+ def rmatvecad(self, x: JaxTypeIn, y: JaxTypeIn) -> JaxTypeOut:
+ """Vector-Jacobian product
+
+ JIT-compiled Vector-Jacobian product
+
+ Parameters
+ ----------
+ x : :obj:`jax.Array`
+ Input array for forward
+ y : :obj:`jax.Array`
+ Input array for adjoint
+
+ Returns
+ -------
+ xadj : :obj:`jax.typing.ArrayLike`
+ Output array
+
+ """
+ M, N = self.shape
+
+ if x.shape != (M,) and x.shape != (M, 1):
+ raise ValueError(
+ f"Dimension mismatch. Got {x.shape}, but expected ({M},) or ({M}, 1)."
+ )
+
+ y = self._rmatvecad(x, y)
+
+ if x.ndim == 1:
+ y = y.reshape(N)
+ elif x.ndim == 2:
+ y = y.reshape(N, 1)
+ else:
+ raise ValueError(
+ f"Invalid shape returned by user-defined rmatvecad(). "
+ f"Expected 2-d ndarray or matrix, not {x.ndim}-d ndarray"
+ )
+ return y
diff --git a/pylops/linearoperator.py b/pylops/linearoperator.py
index 0a719cf3..661178f5 100644
--- a/pylops/linearoperator.py
+++ b/pylops/linearoperator.py
@@ -442,10 +442,11 @@ def _matmat(self, X: NDArray) -> NDArray:
Modified version of scipy _matmat to avoid having trailing dimension
in col when provided to matvec
"""
+ ncp = get_array_module(X)
if sp.sparse.issparse(X):
- y = np.vstack([self.matvec(col.toarray().reshape(-1)) for col in X.T]).T
+ y = ncp.vstack([self.matvec(col.toarray().reshape(-1)) for col in X.T]).T
else:
- y = np.vstack([self.matvec(col.reshape(-1)) for col in X.T]).T
+ y = ncp.vstack([self.matvec(col.reshape(-1)) for col in X.T]).T
return y
def _rmatmat(self, X: NDArray) -> NDArray:
@@ -454,10 +455,11 @@ def _rmatmat(self, X: NDArray) -> NDArray:
Modified version of scipy _rmatmat to avoid having trailing dimension
in col when provided to rmatvec
"""
+ ncp = get_array_module(X)
if sp.sparse.issparse(X):
- y = np.vstack([self.rmatvec(col.toarray().reshape(-1)) for col in X.T]).T
+ y = ncp.vstack([self.rmatvec(col.toarray().reshape(-1)) for col in X.T]).T
else:
- y = np.vstack([self.rmatvec(col.reshape(-1)) for col in X.T]).T
+ y = ncp.vstack([self.rmatvec(col.reshape(-1)) for col in X.T]).T
return y
def _adjoint(self) -> LinearOperator:
@@ -508,7 +510,9 @@ def matvec(self, x: NDArray) -> NDArray:
M, N = self.shape
if x.shape != (N,) and x.shape != (N, 1):
- raise ValueError("dimension mismatch")
+ raise ValueError(
+ f"Dimension mismatch. Got {x.shape}, but expected ({N},) or ({N}, 1)."
+ )
y = self._matvec(x)
@@ -517,7 +521,7 @@ def matvec(self, x: NDArray) -> NDArray:
elif x.ndim == 2:
y = y.reshape(M, 1)
else:
- raise ValueError("invalid shape returned by user-defined matvec()")
+ raise ValueError("Invalid shape returned by user-defined matvec()")
return y
@count(forward=False)
@@ -542,7 +546,9 @@ def rmatvec(self, x: NDArray) -> NDArray:
M, N = self.shape
if x.shape != (M,) and x.shape != (M, 1):
- raise ValueError("dimension mismatch")
+ raise ValueError(
+ f"Dimension mismatch. Got {x.shape}, but expected ({M},) or ({M}, 1)."
+ )
y = self._rmatvec(x)
@@ -551,7 +557,7 @@ def rmatvec(self, x: NDArray) -> NDArray:
elif x.ndim == 2:
y = y.reshape(N, 1)
else:
- raise ValueError("invalid shape returned by user-defined rmatvec()")
+ raise ValueError("Invalid shape returned by user-defined rmatvec()")
return y
@count(forward=True, matmat=True)
@@ -574,9 +580,9 @@ def matmat(self, X: NDArray) -> NDArray:
"""
if X.ndim != 2:
- raise ValueError("expected 2-d ndarray or matrix, " "not %d-d" % X.ndim)
+ raise ValueError(f"Expected 2-d ndarray or matrix, not {X.ndim}-d ndarray")
if X.shape[0] != self.shape[1]:
- raise ValueError("dimension mismatch: %r, %r" % (self.shape, X.shape))
+ raise ValueError(f"Dimension mismatch: {self.shape}, {X.shape}")
Y = self._matmat(X)
return Y
@@ -600,9 +606,9 @@ def rmatmat(self, X: NDArray) -> NDArray:
"""
if X.ndim != 2:
- raise ValueError("expected 2-d ndarray or matrix, " "not %d-d" % X.ndim)
+ raise ValueError(f"Expected 2-d ndarray or matrix, not {X.ndim}-d ndarray")
if X.shape[0] != self.shape[0]:
- raise ValueError("dimension mismatch: %r, %r" % (self.shape, X.shape))
+ raise ValueError(f"Dimension mismatch: {self.shape}, {X.shape}")
Y = self._rmatmat(X)
return Y
@@ -791,7 +797,7 @@ def todense(
Parameters
----------
backend : :obj:`str`, optional
- Backend used to densify matrix (``numpy`` or ``cupy``). Note that
+ Backend used to densify matrix (``numpy`` or ``cupy`` or ``jax``). Note that
this must be consistent with how the operator has been created.
Returns
@@ -816,7 +822,7 @@ def todense(
if Op.shape[1] == shapemin:
matrix = Op.matmat(identity)
else:
- matrix = np.conj(Op.rmatmat(identity)).T
+ matrix = ncp.conj(Op.rmatmat(identity)).T
return matrix
def tosparse(self) -> NDArray:
@@ -1242,23 +1248,14 @@ def _get_dtype(
) -> DTypeLike:
if dtypes is None:
dtypes = []
- opdtypes = []
for obj in operators:
if obj is not None and hasattr(obj, "dtype"):
- opdtypes.append(obj.dtype)
- return np.find_common_type(opdtypes, dtypes)
+ dtypes.append(obj.dtype)
+ return np.result_type(*dtypes)
class _ScaledLinearOperator(LinearOperator):
- """
- Sum Linear Operator
-
- Modified version of scipy _ScaledLinearOperator which uses a modified
- _get_dtype where the scalar and operator types are passed separately to
- np.find_common_type. Passing them together does lead to problems when using
- np.float32 operators which are cast to np.float64
-
- """
+ """Scaled Linear Operator"""
def __init__(
self,
@@ -1269,7 +1266,15 @@ def __init__(
raise ValueError("LinearOperator expected as A")
if not np.isscalar(alpha):
raise ValueError("scalar expected as alpha")
- dtype = _get_dtype([A], [type(alpha)])
+ if isinstance(alpha, complex) and not np.iscomplexobj(
+ np.ones(1, dtype=A.dtype)
+ ):
+ # if the scalar is of complex type but not the operator, find out type
+ dtype = _get_dtype([A], [type(alpha)])
+ else:
+ # if both the scalar and operator are of real or complex type, use type
+ # of the operator
+ dtype = A.dtype
super(_ScaledLinearOperator, self).__init__(dtype=dtype, shape=A.shape)
self.args = (A, alpha)
@@ -1465,7 +1470,7 @@ def __init__(self, A: LinearOperator, p: int) -> None:
if not isintlike(p) or p < 0:
raise ValueError("non-negative integer expected as p")
- super(_PowerLinearOperator, self).__init__(dtype=_get_dtype([A]), shape=A.shape)
+ super(_PowerLinearOperator, self).__init__(dtype=A.dtype, shape=A.shape)
self.args = (A, p)
def _power(self, fun: Callable, x: NDArray) -> NDArray:
diff --git a/pylops/optimization/cls_leastsquares.py b/pylops/optimization/cls_leastsquares.py
index a9e20fec..64526f81 100644
--- a/pylops/optimization/cls_leastsquares.py
+++ b/pylops/optimization/cls_leastsquares.py
@@ -219,7 +219,7 @@ def run(
and cupy `data`, respectively)
.. note::
- When user does not supply ``atol``, it is set to "legacy".
+ When user supplies ``tol`` this is set to ``atol``.
Returns
-------
@@ -238,8 +238,9 @@ def run(
if x is not None:
self.y_normal = self.y_normal - self.Op_normal.matvec(x)
if engine == "scipy" and self.ncp == np:
- if "atol" not in kwargs_solver:
- kwargs_solver["atol"] = "legacy"
+ if "tol" in kwargs_solver:
+ kwargs_solver["atol"] = kwargs_solver["tol"]
+ kwargs_solver.pop("tol")
xinv, istop = sp_cg(self.Op_normal, self.y_normal, **kwargs_solver)
elif engine == "pylops" or self.ncp != np:
if show:
@@ -593,7 +594,7 @@ def run(
xinv, istop, itn, r1norm, r2norm = cgls(
self.RegOp,
self.datatot,
- self.ncp.zeros(self.RegOp.dims, dtype=self.RegOp.dtype),
+ self.ncp.zeros(self.RegOp.shape[1], dtype=self.RegOp.dtype),
**kwargs_solver,
)[0:5]
else:
diff --git a/pylops/signalprocessing/__init__.py b/pylops/signalprocessing/__init__.py
index b3fa0881..b4f8fd7b 100755
--- a/pylops/signalprocessing/__init__.py
+++ b/pylops/signalprocessing/__init__.py
@@ -23,6 +23,7 @@
Shift Fractional Shift operator.
DWT One dimensional Wavelet operator.
DWT2D Two dimensional Wavelet operator.
+ DWTND N-dimensional Wavelet operator.
DCT Discrete Cosine Transform.
DTCWT Dual-Tree Complex Wavelet Transform.
Radon2D Two dimensional Radon transform.
@@ -62,6 +63,7 @@
from .fredholm1 import *
from .dwt import *
from .dwt2d import *
+from .dwtnd import *
from .seislet import *
from .dct import *
from .dtcwt import *
@@ -95,6 +97,7 @@
"Fredholm1",
"DWT",
"DWT2D",
+ "DWTND",
"Seislet",
"DCT",
"DTCWT",
diff --git a/pylops/signalprocessing/convolve1d.py b/pylops/signalprocessing/convolve1d.py
index fd82eb91..bc154a94 100644
--- a/pylops/signalprocessing/convolve1d.py
+++ b/pylops/signalprocessing/convolve1d.py
@@ -48,10 +48,10 @@ def _choose_convfunc(
def _pad_along_axis(array: np.ndarray, pad_size: tuple, axis: int = 0) -> np.ndarray:
-
+ ncp = get_array_module(array)
npad = [(0, 0)] * array.ndim
npad[axis] = pad_size
- return np.pad(array, pad_width=npad)
+ return ncp.pad(array, pad_width=npad)
class _Convolve1Dshort(LinearOperator):
@@ -67,6 +67,7 @@ def __init__(
dtype: DTypeLike = "float64",
name: str = "C",
) -> None:
+ ncp = get_array_module(h)
dims = _value_or_sized_to_tuple(dims)
super().__init__(dtype=np.dtype(dtype), dims=dims, dimsd=dims, name=name)
self.axis = axis
@@ -83,7 +84,7 @@ def __init__(
(max(self.offset, 0), -min(self.offset, 0)),
axis=-1 if h.ndim == 1 else axis,
)
- self.hstar = np.flip(self.h, axis=-1)
+ self.hstar = ncp.flip(self.h, axis=-1)
# add dimensions to filter to match dimensions of model and data
if self.h.ndim == 1:
@@ -127,6 +128,7 @@ def __init__(
dtype: DTypeLike = "float64",
name: str = "C",
) -> None:
+ ncp = get_array_module(h)
dims = _value_or_sized_to_tuple(dims)
dimsd = h.shape
super().__init__(dtype=np.dtype(dtype), dims=dims, dimsd=dimsd, name=name)
@@ -140,13 +142,13 @@ def __init__(
self.offset = 2 * (self.dims[self.axis] // 2 - int(offset))
if self.dims[self.axis] % 2 == 0:
self.offset -= 1
- self.hstar = np.flip(self.h, axis=-1)
+ self.hstar = ncp.flip(self.h, axis=-1)
- self.pad = np.zeros((len(dims), 2), dtype=int)
+ self.pad = ncp.zeros((len(dims), 2), dtype=int)
self.pad[self.axis, 0] = max(self.offset, 0)
self.pad[self.axis, 1] = -min(self.offset, 0)
- self.padd = np.zeros((len(dims), 2), dtype=int)
+ self.padd = ncp.zeros((len(dims), 2), dtype=int)
self.padd[self.axis, 1] = max(self.offset, 0)
self.padd[self.axis, 0] = -min(self.offset, 0)
@@ -162,12 +164,13 @@ def __init__(
@reshaped
def _matvec(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
if type(self.h) is not type(x):
self.h = to_cupy_conditional(x, self.h)
self.convfunc, self.method = _choose_convfunc(
self.h, self.method, self.dims, self.axis
)
- x = np.pad(x, self.pad)
+ x = ncp.pad(x, self.pad)
y = self.convfunc(self.h, x, mode="same")
return y
@@ -179,7 +182,7 @@ def _rmatvec(self, x: NDArray) -> NDArray:
self.convfunc, self.method = _choose_convfunc(
self.hstar, self.method, self.dims, self.axis
)
- x = np.pad(x, self.padd)
+ x = ncp.pad(x, self.padd)
y = self.convfunc(self.hstar, x)
if self.dims[self.axis] % 2 == 0:
y = ncp.take(
diff --git a/pylops/signalprocessing/dwtnd.py b/pylops/signalprocessing/dwtnd.py
new file mode 100644
index 00000000..af43bb0d
--- /dev/null
+++ b/pylops/signalprocessing/dwtnd.py
@@ -0,0 +1,138 @@
+__all__ = ["DWTND"]
+
+import logging
+from math import ceil, log
+
+import numpy as np
+
+from pylops import LinearOperator
+from pylops.basicoperators import Pad
+from pylops.utils import deps
+from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray
+
+from .dwt import _adjointwavelet, _checkwavelet
+
+pywt_message = deps.pywt_import("the dwtnd module")
+
+if pywt_message is None:
+ import pywt
+
+logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING)
+
+
+class DWTND(LinearOperator):
+ """N-dimensional Wavelet operator.
+
+ Apply ND-Wavelet transform along N ``axes`` of a
+ multi-dimensional array of size ``dims``.
+
+ Note that the Wavelet operator is an overload of the ``pywt``
+ implementation of the wavelet transform. Refer to
+ https://pywavelets.readthedocs.io for a detailed description of the
+ input parameters.
+
+ Defaults to a 3D wavelet transform along the last three dimensions
+ of the input array.
+
+ Parameters
+ ----------
+ dims : :obj:`tuple`
+ Number of samples for each dimension
+ axes : :obj:`int`, optional
+ Axis along which DWTND is applied
+ wavelet : :obj:`str`, optional
+ Name of wavelet type. Use :func:`pywt.wavelist(kind='discrete')` for
+ a list of available wavelets.
+ level : :obj:`int`, optional
+ Number of scaling levels (must be >=0).
+ dtype : :obj:`str`, optional
+ Type of elements in input array.
+ name : :obj:`str`, optional
+ Name of operator (to be used by :func:`pylops.utils.describe.describe`)
+
+ Attributes
+ ----------
+ shape : :obj:`tuple`
+ Operator shape
+ explicit : :obj:`bool`
+ Operator contains a matrix that can be solved explicitly
+ (``True``) or not (``False``)
+
+ Raises
+ ------
+ ModuleNotFoundError
+ If ``pywt`` is not installed
+ ValueError
+ If ``wavelet`` does not belong to ``pywt.families``
+
+ Notes
+ -----
+ The Wavelet operator applies the N-dimensional multilevel Discrete
+ Wavelet Transform (DWTN) in forward mode and the N-dimensional multilevel
+ Inverse Discrete Wavelet Transform (IDWTN) in adjoint mode.
+
+ """
+
+ def __init__(
+ self,
+ dims: InputDimsLike,
+ axes: InputDimsLike = (-3, -2, -1),
+ wavelet: str = "haar",
+ level: int = 1,
+ dtype: DTypeLike = "float64",
+ name: str = "D",
+ ) -> None:
+ if pywt_message is not None:
+ raise ModuleNotFoundError(pywt_message)
+ _checkwavelet(wavelet)
+
+ # define padding for length to be power of 2
+ ndimpow2 = [max(2 ** ceil(log(dims[ax], 2)), 2**level) for ax in axes]
+ pad = [(0, 0)] * len(dims)
+ for i, ax in enumerate(axes):
+ pad[ax] = (0, ndimpow2[i] - dims[ax])
+ self.pad = Pad(dims, pad)
+ self.axes = axes
+ dimsd = list(dims)
+ for i, ax in enumerate(axes):
+ dimsd[ax] = ndimpow2[i]
+ super().__init__(dtype=np.dtype(dtype), dims=dims, dimsd=dimsd, name=name)
+
+ # apply transform once again to find out slices
+ _, self.sl = pywt.coeffs_to_array(
+ pywt.wavedecn(
+ np.ones(self.dimsd),
+ wavelet=wavelet,
+ level=level,
+ mode="periodization",
+ axes=self.axes,
+ ),
+ axes=self.axes,
+ )
+ self.wavelet = wavelet
+ self.waveletadj = _adjointwavelet(wavelet)
+ self.level = level
+
+ def _matvec(self, x: NDArray) -> NDArray:
+ x = self.pad.matvec(x)
+ x = np.reshape(x, self.dimsd)
+ y = pywt.coeffs_to_array(
+ pywt.wavedecn(
+ x,
+ wavelet=self.wavelet,
+ level=self.level,
+ mode="periodization",
+ axes=self.axes,
+ ),
+ axes=(self.axes),
+ )[0]
+ return y.ravel()
+
+ def _rmatvec(self, x: NDArray) -> NDArray:
+ x = np.reshape(x, self.dimsd)
+ x = pywt.array_to_coeffs(x, self.sl, output_format="wavedecn")
+ y = pywt.waverecn(
+ x, wavelet=self.waveletadj, mode="periodization", axes=self.axes
+ )
+ y = self.pad.rmatvec(y.ravel())
+ return y
diff --git a/pylops/signalprocessing/fft.py b/pylops/signalprocessing/fft.py
index 64444bcd..6af81a30 100644
--- a/pylops/signalprocessing/fft.py
+++ b/pylops/signalprocessing/fft.py
@@ -11,6 +11,7 @@
from pylops import LinearOperator
from pylops.signalprocessing._baseffts import _BaseFFT, _FFTNorms
from pylops.utils import deps
+from pylops.utils.backend import get_array_module, inplace_divide, inplace_multiply
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray
@@ -60,53 +61,61 @@ def __init__(
self._scale = self.nfft
elif self.norm is _FFTNorms.ONE_OVER_N:
self._scale = 1.0 / self.nfft
+ self.slice = tuple(
+ [slice(None, None)] * (len(self.dims) - 1)
+ + [slice(1, 1 + (self.nfft - 1) // 2)]
+ )
@reshaped
def _matvec(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
if self.ifftshift_before:
- x = np.fft.ifftshift(x, axes=self.axis)
+ x = ncp.fft.ifftshift(x, axes=self.axis)
if not self.clinear:
- x = np.real(x)
+ x = ncp.real(x)
if self.real:
- y = np.fft.rfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
+ y = ncp.fft.rfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
# Apply scaling to obtain a correct adjoint for this operator
- y = np.swapaxes(y, -1, self.axis)
- y[..., 1 : 1 + (self.nfft - 1) // 2] *= np.sqrt(2)
- y = np.swapaxes(y, self.axis, -1)
+ y = ncp.swapaxes(y, -1, self.axis)
+ # y[..., 1 : 1 + (self.nfft - 1) // 2] *= ncp.sqrt(2)
+ y = inplace_multiply(ncp.sqrt(2), y, self.slice)
+ y = ncp.swapaxes(y, self.axis, -1)
else:
- y = np.fft.fft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
+ y = ncp.fft.fft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
if self.norm is _FFTNorms.ONE_OVER_N:
y *= self._scale
if self.fftshift_after:
- y = np.fft.fftshift(y, axes=self.axis)
+ y = ncp.fft.fftshift(y, axes=self.axis)
y = y.astype(self.cdtype)
return y
@reshaped
def _rmatvec(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
if self.fftshift_after:
- x = np.fft.ifftshift(x, axes=self.axis)
+ x = ncp.fft.ifftshift(x, axes=self.axis)
if self.real:
# Apply scaling to obtain a correct adjoint for this operator
x = x.copy()
- x = np.swapaxes(x, -1, self.axis)
- x[..., 1 : 1 + (self.nfft - 1) // 2] /= np.sqrt(2)
- x = np.swapaxes(x, self.axis, -1)
- y = np.fft.irfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
+ x = ncp.swapaxes(x, -1, self.axis)
+ # x[..., 1 : 1 + (self.nfft - 1) // 2] /= ncp.sqrt(2)
+ x = inplace_divide(ncp.sqrt(2), x, self.slice)
+ x = ncp.swapaxes(x, self.axis, -1)
+ y = ncp.fft.irfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
else:
- y = np.fft.ifft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
+ y = ncp.fft.ifft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
if self.norm is _FFTNorms.NONE:
y *= self._scale
if self.nfft > self.dims[self.axis]:
- y = np.take(y, range(0, self.dims[self.axis]), axis=self.axis)
+ y = ncp.take(y, range(0, self.dims[self.axis]), axis=self.axis)
elif self.nfft < self.dims[self.axis]:
- y = np.pad(y, self.ifftpad)
+ y = ncp.pad(y, self.ifftpad)
if not self.clinear:
- y = np.real(y)
+ y = ncp.real(y)
if self.ifftshift_before:
- y = np.fft.fftshift(y, axes=self.axis)
+ y = ncp.fft.fftshift(y, axes=self.axis)
y = y.astype(self.rdtype)
return y
@@ -453,7 +462,7 @@ def FFT(
Nyquist to the frequency bin before zero.
engine : :obj:`str`, optional
Engine used for fft computation (``numpy``, ``fftw``, or ``scipy``). Choose
- ``numpy`` when working with cupy arrays.
+ ``numpy`` when working with cupy and jax arrays.
.. note:: Since version 1.17.0, accepts "scipy".
diff --git a/pylops/signalprocessing/fft2d.py b/pylops/signalprocessing/fft2d.py
index f54e2972..2f4b5f15 100644
--- a/pylops/signalprocessing/fft2d.py
+++ b/pylops/signalprocessing/fft2d.py
@@ -9,6 +9,7 @@
from pylops import LinearOperator
from pylops.signalprocessing._baseffts import _BaseFFTND, _FFTNorms
+from pylops.utils.backend import get_array_module
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike
@@ -67,51 +68,53 @@ def __init__(
@reshaped
def _matvec(self, x):
+ ncp = get_array_module(x)
if self.ifftshift_before.any():
- x = np.fft.ifftshift(x, axes=self.axes[self.ifftshift_before])
+ x = ncp.fft.ifftshift(x, axes=self.axes[self.ifftshift_before])
if not self.clinear:
- x = np.real(x)
+ x = ncp.real(x)
if self.real:
- y = np.fft.rfft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
+ y = ncp.fft.rfft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
# Apply scaling to obtain a correct adjoint for this operator
- y = np.swapaxes(y, -1, self.axes[-1])
- y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= np.sqrt(2)
- y = np.swapaxes(y, self.axes[-1], -1)
+ y = ncp.swapaxes(y, -1, self.axes[-1])
+ y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= ncp.sqrt(2)
+ y = ncp.swapaxes(y, self.axes[-1], -1)
else:
- y = np.fft.fft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
+ y = ncp.fft.fft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
if self.norm is _FFTNorms.ONE_OVER_N:
y *= self._scale
y = y.astype(self.cdtype)
if self.fftshift_after.any():
- y = np.fft.fftshift(y, axes=self.axes[self.fftshift_after])
+ y = ncp.fft.fftshift(y, axes=self.axes[self.fftshift_after])
return y
@reshaped
def _rmatvec(self, x):
+ ncp = get_array_module(x)
if self.fftshift_after.any():
- x = np.fft.ifftshift(x, axes=self.axes[self.fftshift_after])
+ x = ncp.fft.ifftshift(x, axes=self.axes[self.fftshift_after])
if self.real:
# Apply scaling to obtain a correct adjoint for this operator
x = x.copy()
- x = np.swapaxes(x, -1, self.axes[-1])
- x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= np.sqrt(2)
- x = np.swapaxes(x, self.axes[-1], -1)
- y = np.fft.irfft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
+ x = ncp.swapaxes(x, -1, self.axes[-1])
+ x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= ncp.sqrt(2)
+ x = ncp.swapaxes(x, self.axes[-1], -1)
+ y = ncp.fft.irfft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
else:
- y = np.fft.ifft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
+ y = ncp.fft.ifft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
if self.norm is _FFTNorms.NONE:
y *= self._scale
if self.nffts[0] > self.dims[self.axes[0]]:
- y = np.take(y, range(self.dims[self.axes[0]]), axis=self.axes[0])
+ y = ncp.take(y, ncp.arange(self.dims[self.axes[0]]), axis=self.axes[0])
if self.nffts[1] > self.dims[self.axes[1]]:
- y = np.take(y, range(self.dims[self.axes[1]]), axis=self.axes[1])
+ y = ncp.take(y, ncp.arange(self.dims[self.axes[1]]), axis=self.axes[1])
if self.doifftpad:
- y = np.pad(y, self.ifftpad)
+ y = ncp.pad(y, self.ifftpad)
if not self.clinear:
- y = np.real(y)
+ y = ncp.real(y)
y = y.astype(self.rdtype)
if self.ifftshift_before.any():
- y = np.fft.fftshift(y, axes=self.axes[self.ifftshift_before])
+ y = ncp.fft.fftshift(y, axes=self.axes[self.ifftshift_before])
return y
def __truediv__(self, y):
@@ -310,7 +313,8 @@ def FFT2D(
engine : :obj:`str`, optional
.. versionadded:: 1.17.0
- Engine used for fft computation (``numpy`` or ``scipy``).
+ Engine used for fft computation (``numpy`` or ``scipy``). Choose
+ ``numpy`` when working with cupy and jax arrays.
dtype : :obj:`str`, optional
Type of elements in input array. Note that the ``dtype`` of the operator
is the corresponding complex type even when a real type is provided.
diff --git a/pylops/signalprocessing/fftnd.py b/pylops/signalprocessing/fftnd.py
index a33f4918..cf2de78f 100644
--- a/pylops/signalprocessing/fftnd.py
+++ b/pylops/signalprocessing/fftnd.py
@@ -7,8 +7,9 @@
import numpy as np
import numpy.typing as npt
+from pylops import LinearOperator
from pylops.signalprocessing._baseffts import _BaseFFTND, _FFTNorms
-from pylops.utils.backend import get_sp_fft
+from pylops.utils.backend import get_array_module, get_sp_fft
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray
@@ -29,6 +30,7 @@ def __init__(
ifftshift_before: bool = False,
fftshift_after: bool = False,
dtype: DTypeLike = "complex128",
+ **kwargs_fft,
) -> None:
super().__init__(
dims=dims,
@@ -46,6 +48,7 @@ def __init__(
f"numpy backend always returns complex128 dtype. To respect the passed dtype, data will be cast to {self.cdtype}."
)
+ self._kwargs_fft = kwargs_fft
self._norm_kwargs = {"norm": None} # equivalent to "backward" in Numpy/Scipy
if self.norm is _FFTNorms.ORTHO:
self._norm_kwargs["norm"] = "ortho"
@@ -56,50 +59,52 @@ def __init__(
@reshaped
def _matvec(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
if self.ifftshift_before.any():
- x = np.fft.ifftshift(x, axes=self.axes[self.ifftshift_before])
+ x = ncp.fft.ifftshift(x, axes=self.axes[self.ifftshift_before])
if not self.clinear:
- x = np.real(x)
+ x = ncp.real(x)
if self.real:
- y = np.fft.rfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
+ y = ncp.fft.rfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
# Apply scaling to obtain a correct adjoint for this operator
- y = np.swapaxes(y, -1, self.axes[-1])
- y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= np.sqrt(2)
- y = np.swapaxes(y, self.axes[-1], -1)
+ y = ncp.swapaxes(y, -1, self.axes[-1])
+ y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= ncp.sqrt(2)
+ y = ncp.swapaxes(y, self.axes[-1], -1)
else:
- y = np.fft.fftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
+ y = ncp.fft.fftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
if self.norm is _FFTNorms.ONE_OVER_N:
y *= self._scale
y = y.astype(self.cdtype)
if self.fftshift_after.any():
- y = np.fft.fftshift(y, axes=self.axes[self.fftshift_after])
+ y = ncp.fft.fftshift(y, axes=self.axes[self.fftshift_after])
return y
@reshaped
def _rmatvec(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
if self.fftshift_after.any():
- x = np.fft.ifftshift(x, axes=self.axes[self.fftshift_after])
+ x = ncp.fft.ifftshift(x, axes=self.axes[self.fftshift_after])
if self.real:
# Apply scaling to obtain a correct adjoint for this operator
x = x.copy()
- x = np.swapaxes(x, -1, self.axes[-1])
- x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= np.sqrt(2)
- x = np.swapaxes(x, self.axes[-1], -1)
- y = np.fft.irfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
+ x = ncp.swapaxes(x, -1, self.axes[-1])
+ x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= ncp.sqrt(2)
+ x = ncp.swapaxes(x, self.axes[-1], -1)
+ y = ncp.fft.irfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
else:
- y = np.fft.ifftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
+ y = ncp.fft.ifftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
if self.norm is _FFTNorms.NONE:
y *= self._scale
for ax, nfft in zip(self.axes, self.nffts):
if nfft > self.dims[ax]:
- y = np.take(y, range(self.dims[ax]), axis=ax)
+ y = ncp.take(y, np.arange(self.dims[ax]), axis=ax)
if self.doifftpad:
- y = np.pad(y, self.ifftpad)
+ y = ncp.pad(y, self.ifftpad)
if not self.clinear:
- y = np.real(y)
+ y = ncp.real(y)
y = y.astype(self.rdtype)
if self.ifftshift_before.any():
- y = np.fft.fftshift(y, axes=self.axes[self.ifftshift_before])
+ y = ncp.fft.fftshift(y, axes=self.axes[self.ifftshift_before])
return y
def __truediv__(self, y: npt.ArrayLike) -> npt.ArrayLike:
@@ -122,6 +127,7 @@ def __init__(
ifftshift_before: bool = False,
fftshift_after: bool = False,
dtype: DTypeLike = "complex128",
+ **kwargs_fft,
) -> None:
super().__init__(
dims=dims,
@@ -134,7 +140,7 @@ def __init__(
fftshift_after=fftshift_after,
dtype=dtype,
)
-
+ self._kwargs_fft = kwargs_fft
self._norm_kwargs = {"norm": None} # equivalent to "backward" in Numpy/Scipy
if self.norm is _FFTNorms.ORTHO:
self._norm_kwargs["norm"] = "ortho"
@@ -209,7 +215,8 @@ def FFTND(
engine: str = "scipy",
dtype: DTypeLike = "complex128",
name: str = "F",
-):
+ **kwargs_fft,
+) -> LinearOperator:
r"""N-dimensional Fast-Fourier Transform.
Apply N-dimensional Fast-Fourier Transform (FFT) to any n ``axes``
@@ -297,7 +304,8 @@ def FFTND(
engine : :obj:`str`, optional
.. versionadded:: 1.17.0
- Engine used for fft computation (``numpy`` or ``scipy``).
+ Engine used for fft computation (``numpy`` or ``scipy``). Choose
+ ``numpy`` when working with cupy and jax arrays.
dtype : :obj:`str`, optional
Type of elements in input array. Note that the ``dtype`` of the operator
is the corresponding complex type even when a real type is provided.
@@ -311,6 +319,10 @@ def FFTND(
.. versionadded:: 2.0.0
Name of operator (to be used by :func:`pylops.utils.describe.describe`)
+ **kwargs_fft
+ .. versionadded:: 2.3.0
+
+ Arbitrary keyword arguments to be passed to the selected fft method
Attributes
----------
@@ -396,6 +408,7 @@ def FFTND(
ifftshift_before=ifftshift_before,
fftshift_after=fftshift_after,
dtype=dtype,
+ **kwargs_fft,
)
elif engine == "scipy":
f = _FFTND_scipy(
@@ -408,6 +421,7 @@ def FFTND(
ifftshift_before=ifftshift_before,
fftshift_after=fftshift_after,
dtype=dtype,
+ **kwargs_fft,
)
else:
raise NotImplementedError("engine must be numpy or scipy")
diff --git a/pylops/signalprocessing/fredholm1.py b/pylops/signalprocessing/fredholm1.py
index 8a64e12b..feb6c645 100644
--- a/pylops/signalprocessing/fredholm1.py
+++ b/pylops/signalprocessing/fredholm1.py
@@ -3,7 +3,7 @@
import numpy as np
from pylops import LinearOperator
-from pylops.utils.backend import get_array_module
+from pylops.utils.backend import get_array_module, inplace_set
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, NDArray
@@ -61,7 +61,7 @@ class Fredholm1(LinearOperator):
d(k, x, z) = \int{G(k, x, y) m(k, y, z) \,\mathrm{d}y}
\quad \forall k=1,\ldots,n_{slice}
- on the other hand its adjoin is expressed as
+ on the other hand its adjoint is expressed as
.. math::
@@ -118,7 +118,7 @@ def _matvec(self, x: NDArray) -> NDArray:
else:
y = ncp.squeeze(ncp.zeros((self.nsl, self.nx, self.nz), dtype=self.dtype))
for isl in range(self.nsl):
- y[isl] = ncp.dot(self.G[isl], x[isl])
+ y = inplace_set(ncp.dot(self.G[isl], x[isl]), y, isl)
return y
@reshaped
@@ -131,7 +131,6 @@ def _rmatvec(self, x: NDArray) -> NDArray:
if hasattr(self, "GT"):
y = ncp.matmul(self.GT, x)
else:
- # y = ncp.matmul(self.G.transpose((0, 2, 1)).conj(), x)
y = (
ncp.matmul(x.transpose(0, 2, 1).conj(), self.G)
.transpose(0, 2, 1)
@@ -141,9 +140,10 @@ def _rmatvec(self, x: NDArray) -> NDArray:
y = ncp.squeeze(ncp.zeros((self.nsl, self.ny, self.nz), dtype=self.dtype))
if hasattr(self, "GT"):
for isl in range(self.nsl):
- y[isl] = ncp.dot(self.GT[isl], x[isl])
+ y = inplace_set(ncp.dot(self.GT[isl], x[isl]), y, isl)
else:
for isl in range(self.nsl):
- # y[isl] = ncp.dot(self.G[isl].conj().T, x[isl])
- y[isl] = ncp.dot(x[isl].T.conj(), self.G[isl]).T.conj()
+ y = inplace_set(
+ ncp.dot(x[isl].T.conj(), self.G[isl]).T.conj(), y, isl
+ )
return y.ravel()
diff --git a/pylops/signalprocessing/nonstatconvolve1d.py b/pylops/signalprocessing/nonstatconvolve1d.py
index 45daeed5..669898ef 100644
--- a/pylops/signalprocessing/nonstatconvolve1d.py
+++ b/pylops/signalprocessing/nonstatconvolve1d.py
@@ -9,7 +9,7 @@
from pylops import LinearOperator
from pylops.utils._internal import _value_or_sized_to_tuple
-from pylops.utils.backend import get_array_module
+from pylops.utils.backend import get_array_module, inplace_add, inplace_set
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray
@@ -147,7 +147,8 @@ def _interpolate_h(hs, ix, oh, dh, nh):
@reshaped(swapaxis=True)
def _matvec(self, x: NDArray) -> NDArray:
- y = np.zeros_like(x)
+ ncp = get_array_module(x)
+ y = ncp.zeros_like(x)
for ix in range(self.dims[self.axis]):
h = self._interpolate_h(self.hs, ix, self.oh, self.dh, self.nh)
xextremes = (
@@ -158,14 +159,20 @@ def _matvec(self, x: NDArray) -> NDArray:
max(0, -ix + self.hsize // 2),
min(self.hsize, self.hsize // 2 + (self.dims[self.axis] - ix)),
)
- y[..., xextremes[0] : xextremes[1]] += (
- x[..., ix : ix + 1] * h[hextremes[0] : hextremes[1]]
+ # y[..., xextremes[0] : xextremes[1]] += (
+ # x[..., ix : ix + 1] * h[hextremes[0] : hextremes[1]]
+ # )
+ sl = tuple(
+ [slice(None, None)] * (len(self.dimsd) - 1)
+ + [slice(xextremes[0], xextremes[1])]
)
+ y = inplace_add(x[..., ix : ix + 1] * h[hextremes[0] : hextremes[1]], y, sl)
return y
@reshaped(swapaxis=True)
def _rmatvec(self, x: NDArray) -> NDArray:
- y = np.zeros_like(x)
+ ncp = get_array_module(x)
+ y = ncp.zeros_like(x)
for ix in range(self.dims[self.axis]):
h = self._interpolate_h(self.hs, ix, self.oh, self.dh, self.nh)
xextremes = (
@@ -176,17 +183,29 @@ def _rmatvec(self, x: NDArray) -> NDArray:
max(0, -ix + self.hsize // 2),
min(self.hsize, self.hsize // 2 + (self.dims[self.axis] - ix)),
)
- y[..., ix] = np.sum(
- h[hextremes[0] : hextremes[1]] * x[..., xextremes[0] : xextremes[1]],
- axis=-1,
+ # y[..., ix] = ncp.sum(
+ # h[hextremes[0] : hextremes[1]] * x[..., xextremes[0] : xextremes[1]],
+ # axis=-1,
+ # )
+ sl = tuple([slice(None, None)] * (len(self.dimsd) - 1) + [ix])
+ y = inplace_set(
+ ncp.sum(
+ h[hextremes[0] : hextremes[1]]
+ * x[..., xextremes[0] : xextremes[1]],
+ axis=-1,
+ ),
+ y,
+ sl,
)
+
return y
def todense(self):
+ ncp = get_array_module(self.hsinterp[0])
hs = self.hsinterp
- H = np.array(
+ H = ncp.array(
[
- np.roll(np.pad(h, (0, self.dims[self.axis])), ix)
+ ncp.roll(ncp.pad(h, (0, self.dims[self.axis])), ix)
for ix, h in enumerate(hs)
]
)
@@ -317,18 +336,27 @@ def _interpolate_hadj(htmp, hs, hextremes, ix, oh, dh, nh):
"""find closest filters and spread weighted psf"""
ih_closest = int(np.floor((ix - oh) / dh))
if ih_closest < 0:
- hs[0, hextremes[0] : hextremes[1]] += htmp
+ # hs[0, hextremes[0] : hextremes[1]] += htmp
+ sl = tuple([0] + [slice(hextremes[0], hextremes[1])])
+ hs = inplace_add(htmp, hs, sl)
elif ih_closest >= nh - 1:
- hs[nh - 1, hextremes[0] : hextremes[1]] += htmp
+ # hs[nh - 1, hextremes[0] : hextremes[1]] += htmp
+ sl = tuple([nh - 1] + [slice(hextremes[0], hextremes[1])])
+ hs = inplace_add(htmp, hs, sl)
else:
dh_closest = (ix - oh) / dh - ih_closest
- hs[ih_closest, hextremes[0] : hextremes[1]] += (1 - dh_closest) * htmp
- hs[ih_closest + 1, hextremes[0] : hextremes[1]] += dh_closest * htmp
+ # hs[ih_closest, hextremes[0] : hextremes[1]] += (1 - dh_closest) * htmp
+ sl = tuple([ih_closest] + [slice(hextremes[0], hextremes[1])])
+ hs = inplace_add((1 - dh_closest) * htmp, hs, sl)
+ # hs[ih_closest + 1, hextremes[0] : hextremes[1]] += dh_closest * htmp
+ sl = tuple([ih_closest + 1] + [slice(hextremes[0], hextremes[1])])
+ hs = inplace_add(dh_closest * htmp, hs, sl)
return hs
@reshaped
def _matvec(self, x: NDArray) -> NDArray:
- y = np.zeros(self.dimsd, dtype=self.dtype)
+ ncp = get_array_module(x)
+ y = ncp.zeros(self.dimsd, dtype=self.dtype)
for ix in range(self.dimsd[0]):
h = self._interpolate_h(x, ix, self.oh, self.dh, self.nh)
xextremes = (
@@ -339,14 +367,23 @@ def _matvec(self, x: NDArray) -> NDArray:
max(0, -ix + self.hsize // 2),
min(self.hsize, self.hsize // 2 + (self.dimsd[0] - ix)),
)
- y[..., xextremes[0] : xextremes[1]] += (
- self.inp[..., ix : ix + 1] * h[hextremes[0] : hextremes[1]]
+ # y[..., xextremes[0] : xextremes[1]] += (
+ # self.inp[..., ix : ix + 1] * h[hextremes[0] : hextremes[1]]
+ # )
+ sl = tuple(
+ [slice(None, None)] * (len(self.dimsd) - 1)
+ + [slice(xextremes[0], xextremes[1])]
+ )
+ y = inplace_add(
+ self.inp[..., ix : ix + 1] * h[hextremes[0] : hextremes[1]], y, sl
)
+
return y
@reshaped
def _rmatvec(self, x: NDArray) -> NDArray:
- hs = np.zeros(self.dims, dtype=self.dtype)
+ ncp = get_array_module(x)
+ hs = ncp.zeros(self.dims, dtype=self.dtype)
for ix in range(self.dimsd[0]):
xextremes = (
max(0, ix - self.hsize // 2),
diff --git a/pylops/signalprocessing/patch2d.py b/pylops/signalprocessing/patch2d.py
index d95ddeb5..86e496ec 100644
--- a/pylops/signalprocessing/patch2d.py
+++ b/pylops/signalprocessing/patch2d.py
@@ -9,8 +9,14 @@
import numpy as np
from pylops import LinearOperator
-from pylops.basicoperators import BlockDiag, Diagonal, HStack, Restriction
from pylops.signalprocessing.sliding2d import _slidingsteps
+from pylops.utils._internal import _value_or_sized_to_tuple
+from pylops.utils.backend import (
+ get_array_module,
+ get_sliding_window_view,
+ to_cupy_conditional,
+)
+from pylops.utils.decorators import reshaped
from pylops.utils.tapers import taper2d
from pylops.utils.typing import InputDimsLike, NDArray
@@ -22,6 +28,7 @@ def patch2d_design(
nwin: Tuple[int, int],
nover: Tuple[int, int],
nop: Tuple[int, int],
+ verb: bool = True,
) -> Tuple[
Tuple[int, int],
Tuple[int, int],
@@ -45,6 +52,9 @@ def patch2d_design(
Number of samples of overlapping part of window.
nop : :obj:`tuple`
Size of model in the transformed domain.
+ verb : :obj:`bool`, optional
+ Verbosity flag. If ``verb==True``, print the data
+ and model windows start-end indices
Returns
-------
@@ -73,35 +83,26 @@ def patch2d_design(
mwins_inends = ((mwin0_ins, mwin0_ends), (mwin1_ins, mwin1_ends))
# print information about patching
- logging.warning("%d-%d windows required...", nwins0, nwins1)
- logging.warning(
- "data wins - start:%s, end:%s / start:%s, end:%s",
- dwin0_ins,
- dwin0_ends,
- dwin1_ins,
- dwin1_ends,
- )
- logging.warning(
- "model wins - start:%s, end:%s / start:%s, end:%s",
- mwin0_ins,
- mwin0_ends,
- mwin1_ins,
- mwin1_ends,
- )
+ if verb:
+ logging.warning("%d-%d windows required...", nwins0, nwins1)
+ logging.warning(
+ "data wins - start:%s, end:%s / start:%s, end:%s",
+ dwin0_ins,
+ dwin0_ends,
+ dwin1_ins,
+ dwin1_ends,
+ )
+ logging.warning(
+ "model wins - start:%s, end:%s / start:%s, end:%s",
+ mwin0_ins,
+ mwin0_ends,
+ mwin1_ins,
+ mwin1_ends,
+ )
return nwins, dims, mwins_inends, dwins_inends
-def Patch2D(
- Op: LinearOperator,
- dims: InputDimsLike,
- dimsd: InputDimsLike,
- nwin: Tuple[int, int],
- nover: Tuple[int, int],
- nop: Tuple[int, int],
- tapertype: str = "hanning",
- scalings: Optional[Sequence[float]] = None,
- name: str = "P",
-) -> LinearOperator:
+class Patch2D(LinearOperator):
"""2D Patch transform operator.
Apply a transform operator ``Op`` repeatedly to patches of the model
@@ -145,6 +146,11 @@ def Patch2D(
Size of model in the transformed domain
tapertype : :obj:`str`, optional
Type of taper (``hanning``, ``cosine``, ``cosinesquare`` or ``None``)
+ savetaper : :obj:`bool`, optional
+ .. versionadded:: 2.3.0
+
+ Save all tapers and apply them in one go (``True``) or save unique tapers and apply them one by one (``False``).
+ The first option is more computationally efficient, whilst the second is more memory efficient.
scalings : :obj:`tuple` or :obj:`list`, optional
Set of scalings to apply to each patch. If ``None``, no scale will be
applied
@@ -172,104 +178,265 @@ def Patch2D(
Patch3D: 3D Patching transform operator.
"""
- # data windows
- dwin0_ins, dwin0_ends = _slidingsteps(dimsd[0], nwin[0], nover[0])
- dwin1_ins, dwin1_ends = _slidingsteps(dimsd[1], nwin[1], nover[1])
- nwins0 = len(dwin0_ins)
- nwins1 = len(dwin1_ins)
- nwins = nwins0 * nwins1
-
- # check patching
- if nwins0 * nop[0] != dims[0] or nwins1 * nop[1] != dims[1]:
- raise ValueError(
- f"Model shape (dims={dims}) is not consistent with chosen "
- f"number of windows. Run patch2d_design to identify the "
- f"correct number of windows for the current "
- "model size..."
- )
- # create tapers
- if tapertype is not None:
- tap = taper2d(nwin[1], nwin[0], nover, tapertype=tapertype).astype(Op.dtype)
- taps = {itap: tap for itap in range(nwins)}
- # topmost tapers
- taptop = tap.copy()
- taptop[: nover[0]] = tap[nwin[0] // 2]
- for itap in range(0, nwins1):
- taps[itap] = taptop
- # bottommost tapers
- tapbottom = tap.copy()
- tapbottom[-nover[0] :] = tap[nwin[0] // 2]
- for itap in range(nwins - nwins1, nwins):
- taps[itap] = tapbottom
- # leftmost tapers
- tapleft = tap.copy()
- tapleft[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis]
- for itap in range(0, nwins, nwins1):
- taps[itap] = tapleft
- # rightmost tapers
- tapright = tap.copy()
- tapright[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis]
- for itap in range(nwins1 - 1, nwins, nwins1):
- taps[itap] = tapright
- # lefttopcorner taper
- taplefttop = tap.copy()
- taplefttop[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis]
- taplefttop[: nover[0]] = taplefttop[nwin[0] // 2]
- taps[0] = taplefttop
- # righttopcorner taper
- taprighttop = tap.copy()
- taprighttop[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis]
- taprighttop[: nover[0]] = taprighttop[nwin[0] // 2]
- taps[nwins1 - 1] = taprighttop
- # leftbottomcorner taper
- tapleftbottom = tap.copy()
- tapleftbottom[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis]
- tapleftbottom[-nover[0] :] = tapleftbottom[nwin[0] // 2]
- taps[nwins - nwins1] = tapleftbottom
- # rightbottomcorner taper
- taprightbottom = tap.copy()
- taprightbottom[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis]
- taprightbottom[-nover[0] :] = taprightbottom[nwin[0] // 2]
- taps[nwins - 1] = taprightbottom
-
- # define scalings
- if scalings is None:
- scalings = [1.0] * nwins
-
- # transform to apply
- if tapertype is None:
- OOp = BlockDiag([scalings[itap] * Op for itap in range(nwins)])
- else:
- OOp = BlockDiag(
- [
- scalings[itap] * Diagonal(taps[itap].ravel(), dtype=Op.dtype) * Op
- for itap in range(nwins)
- ]
+ def __init__(
+ self,
+ Op: LinearOperator,
+ dims: InputDimsLike,
+ dimsd: InputDimsLike,
+ nwin: Tuple[int, int],
+ nover: Tuple[int, int],
+ nop: Tuple[int, int],
+ tapertype: str = "hanning",
+ savetaper: bool = True,
+ scalings: Optional[Sequence[float]] = None,
+ name: str = "P",
+ ) -> None:
+
+ dims: Tuple[int, ...] = _value_or_sized_to_tuple(dims)
+ dimsd: Tuple[int, ...] = _value_or_sized_to_tuple(dimsd)
+
+ # data windows
+ dwin0_ins, dwin0_ends = _slidingsteps(dimsd[0], nwin[0], nover[0])
+ dwin1_ins, dwin1_ends = _slidingsteps(dimsd[1], nwin[1], nover[1])
+ self.dwins_inends = ((dwin0_ins, dwin0_ends), (dwin1_ins, dwin1_ends))
+ nwins0 = len(dwin0_ins)
+ nwins1 = len(dwin1_ins)
+ nwins = nwins0 * nwins1
+ self.nwin = nwin
+ self.nover = nover
+
+ # check patching
+ if nwins0 * nop[0] != dims[0] or nwins1 * nop[1] != dims[1]:
+ raise ValueError(
+ f"Model shape (dims={dims}) is not consistent with chosen "
+ f"number of windows. Run patch2d_design to identify the "
+ f"correct number of windows for the current "
+ "model size..."
+ )
+
+ # create tapers
+ self.tapertype = tapertype
+ self.savetaper = savetaper
+ if self.tapertype is not None:
+ tap = taper2d(nwin[1], nwin[0], nover, tapertype=tapertype).astype(Op.dtype)
+ # topmost tapers
+ taptop = tap.copy()
+ taptop[: nover[0]] = tap[nwin[0] // 2]
+ # bottommost tapers
+ tapbottom = tap.copy()
+ tapbottom[-nover[0] :] = tap[nwin[0] // 2]
+ # leftmost tapers
+ tapleft = tap.copy()
+ tapleft[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis]
+ # rightmost tapers
+ tapright = tap.copy()
+ tapright[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis]
+ # lefttopcorner taper
+ taplefttop = tap.copy()
+ taplefttop[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis]
+ taplefttop[: nover[0]] = taplefttop[nwin[0] // 2]
+ # righttopcorner taper
+ taprighttop = tap.copy()
+ taprighttop[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis]
+ taprighttop[: nover[0]] = taprighttop[nwin[0] // 2]
+ # leftbottomcorner taper
+ tapleftbottom = tap.copy()
+ tapleftbottom[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis]
+ tapleftbottom[-nover[0] :] = tapleftbottom[nwin[0] // 2]
+ # rightbottomcorner taper
+ taprightbottom = tap.copy()
+ taprightbottom[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis]
+ taprightbottom[-nover[0] :] = taprightbottom[nwin[0] // 2]
+
+ if self.savetaper:
+ taps = [
+ tap,
+ ] * nwins
+ for itap in range(0, nwins1):
+ taps[itap] = taptop
+ for itap in range(nwins - nwins1, nwins):
+ taps[itap] = tapbottom
+ for itap in range(0, nwins, nwins1):
+ taps[itap] = tapleft
+ for itap in range(nwins1 - 1, nwins, nwins1):
+ taps[itap] = tapright
+ taps[0] = taplefttop
+ taps[nwins1 - 1] = taprighttop
+ taps[nwins - nwins1] = tapleftbottom
+ taps[nwins - 1] = taprightbottom
+ self.taps = np.vstack(taps).reshape(nwins0, nwins1, nwin[0], nwin[1])
+ else:
+ taps = [
+ taplefttop,
+ taptop,
+ taprighttop,
+ tapleft,
+ tap,
+ tapright,
+ tapleftbottom,
+ tapbottom,
+ taprightbottom,
+ ]
+ self.taps = np.vstack(taps).reshape(3, 3, nwin[0], nwin[1])
+
+ # define scalings
+ self.scalings = [1.0] * nwins if scalings is None else scalings
+
+ # check if operator is applied to all windows simultaneously
+ self.simOp = False
+ if Op.shape[1] == np.prod(dims):
+ self.simOp = True
+ self.Op = Op
+
+ super().__init__(
+ dtype=Op.dtype,
+ dims=(nwins0, nwins1, int(dims[0] // nwins0), int(dims[1] // nwins1)),
+ dimsd=dimsd,
+ clinear=False,
+ name=name,
)
- hstack = HStack(
- [
- Restriction(
- (nwin[0], dimsd[1]), range(win_in, win_end), axis=1, dtype=Op.dtype
- ).H
- for win_in, win_end in zip(dwin1_ins, dwin1_ends)
- ]
- )
- combining1 = BlockDiag([hstack] * nwins0)
+ self._register_multiplications(self.savetaper)
- combining0 = HStack(
- [
- Restriction(dimsd, range(win_in, win_end), axis=0, dtype=Op.dtype).H
- for win_in, win_end in zip(dwin0_ins, dwin0_ends)
+ def _apply_taper(self, ywins, iwin0, iwin1):
+ if iwin0 == 0 and iwin1 == 0:
+ ywins[0, 0] = self.taps[0, 0] * ywins[0, 0]
+ elif iwin0 == 0 and iwin1 == self.dims[1] - 1:
+ ywins[0, -1] = self.taps[0, -1] * ywins[0, -1]
+ elif iwin0 == 0:
+ ywins[0, iwin1] = self.taps[0, 1] * ywins[0, iwin1]
+ elif iwin0 == self.dims[0] - 1 and iwin1 == 0:
+ ywins[-1, 0] = self.taps[-1, 0] * ywins[-1, 0]
+ elif iwin0 == self.dims[0] - 1 and iwin1 == self.dims[1] - 1:
+ ywins[-1, -1] = self.taps[-1, -1] * ywins[-1, -1]
+ elif iwin0 == self.dims[0] - 1:
+ ywins[-1, iwin1] = self.taps[-1, 1] * ywins[-1, iwin1]
+ elif iwin1 == 0:
+ ywins[iwin0, 0] = self.taps[1, 0] * ywins[iwin0, 0]
+ elif iwin1 == self.dims[1] - 1:
+ ywins[iwin0, -1] = self.taps[1, -1] * ywins[iwin0, -1]
+ else:
+ ywins[iwin0, iwin1] = self.taps[1, 1] * ywins[iwin0, iwin1]
+ return ywins
+
+ @reshaped
+ def _matvec_savetaper(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
+ if self.tapertype is not None:
+ self.taps = to_cupy_conditional(x, self.taps)
+ y = ncp.zeros(self.dimsd, dtype=self.dtype)
+ if self.simOp:
+ x = self.Op @ x
+ for iwin0 in range(self.dims[0]):
+ for iwin1 in range(self.dims[1]):
+ if self.simOp:
+ xx = x[iwin0, iwin1].reshape(self.nwin)
+ else:
+ xx = self.Op.matvec(x[iwin0, iwin1].ravel()).reshape(self.nwin)
+ if self.tapertype is not None:
+ xxwin = self.taps[iwin0, iwin1] * xx
+ else:
+ xxwin = xx
+
+ y[
+ self.dwins_inends[0][0][iwin0] : self.dwins_inends[0][1][iwin0],
+ self.dwins_inends[1][0][iwin1] : self.dwins_inends[1][1][iwin1],
+ ] += xxwin
+ return y
+
+ @reshaped
+ def _rmatvec_savetaper(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
+ ncp_sliding_window_view = get_sliding_window_view(x)
+ if self.tapertype is not None:
+ self.taps = to_cupy_conditional(x, self.taps)
+ ywins = ncp_sliding_window_view(x, self.nwin)[
+ :: self.nwin[0] - self.nover[0], :: self.nwin[1] - self.nover[1]
]
- )
- Pop = LinearOperator(combining0 * combining1 * OOp)
- Pop.dims, Pop.dimsd = (
- nwins0,
- nwins1,
- int(dims[0] // nwins0),
- int(dims[1] // nwins1),
- ), dimsd
- Pop.name = name
- return Pop
+ if self.tapertype is not None:
+ ywins = ywins * self.taps
+ if self.simOp:
+ y = self.Op.H @ ywins
+ else:
+ y = ncp.zeros(self.dims, dtype=self.dtype)
+ for iwin0 in range(self.dims[0]):
+ for iwin1 in range(self.dims[1]):
+ y[iwin0, iwin1] = self.Op.rmatvec(
+ ywins[iwin0, iwin1].ravel()
+ ).reshape(self.dims[2], self.dims[3])
+ return y
+
+ @reshaped
+ def _matvec_nosavetaper(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
+ if self.tapertype is not None:
+ self.taps = to_cupy_conditional(x, self.taps)
+ y = ncp.zeros(self.dimsd, dtype=self.dtype)
+ if self.simOp:
+ x = self.Op @ x
+ for iwin0 in range(self.dims[0]):
+ for iwin1 in range(self.dims[1]):
+ if self.simOp:
+ xxwin = x[iwin0, iwin1].reshape(self.nwin)
+ else:
+ xxwin = self.Op.matvec(x[iwin0, iwin1].ravel()).reshape(self.nwin)
+ if self.tapertype is not None:
+ if iwin0 == 0 and iwin1 == 0:
+ xxwin = self.taps[0, 0] * xxwin
+ elif iwin0 == 0 and iwin1 == self.dims[1] - 1:
+ xxwin = self.taps[0, -1] * xxwin
+ elif iwin0 == 0:
+ xxwin = self.taps[0, 1] * xxwin
+ elif iwin0 == self.dims[0] - 1 and iwin1 == 0:
+ xxwin = self.taps[-1, 0] * xxwin
+ elif iwin0 == self.dims[0] - 1 and iwin1 == self.dims[1] - 1:
+ xxwin = self.taps[-1, -1] * xxwin
+ elif iwin0 == self.dims[0] - 1:
+ xxwin = self.taps[-1, 1] * xxwin
+ elif iwin1 == 0:
+ xxwin = self.taps[1, 0] * xxwin
+ elif iwin1 == self.dims[1] - 1:
+ xxwin = self.taps[1, -1] * xxwin
+ else:
+ xxwin = self.taps[1, 1] * xxwin
+
+ y[
+ self.dwins_inends[0][0][iwin0] : self.dwins_inends[0][1][iwin0],
+ self.dwins_inends[1][0][iwin1] : self.dwins_inends[1][1][iwin1],
+ ] += xxwin
+ return y
+
+ @reshaped
+ def _rmatvec_nosavetaper(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
+ ncp_sliding_window_view = get_sliding_window_view(x)
+ if self.tapertype is not None:
+ self.taps = to_cupy_conditional(x, self.taps)
+ ywins = ncp_sliding_window_view(x, self.nwin)[
+ :: self.nwin[0] - self.nover[0], :: self.nwin[1] - self.nover[1]
+ ].copy()
+ if self.simOp:
+ if self.tapertype is not None:
+ for iwin0 in range(self.dims[0]):
+ for iwin1 in range(self.dims[1]):
+ ywins = self._apply_taper(ywins, iwin0, iwin1)
+ y = self.Op.H @ ywins
+ else:
+ y = ncp.zeros(self.dims, dtype=self.dtype)
+ for iwin0 in range(self.dims[0]):
+ for iwin1 in range(self.dims[1]):
+ if self.tapertype is not None:
+ ywins = self._apply_taper(ywins, iwin0, iwin1)
+ y[iwin0, iwin1] = self.Op.rmatvec(
+ ywins[iwin0, iwin1].ravel()
+ ).reshape(self.dims[2], self.dims[3])
+ return y
+
+ def _register_multiplications(self, savetaper: bool) -> None:
+ if savetaper:
+ self._matvec = self._matvec_savetaper
+ self._rmatvec = self._rmatvec_savetaper
+ else:
+ self._matvec = self._matvec_nosavetaper
+ self._rmatvec = self._rmatvec_nosavetaper
diff --git a/pylops/signalprocessing/patch3d.py b/pylops/signalprocessing/patch3d.py
index 011963cc..ce3889d4 100644
--- a/pylops/signalprocessing/patch3d.py
+++ b/pylops/signalprocessing/patch3d.py
@@ -9,8 +9,14 @@
import numpy as np
from pylops import LinearOperator
-from pylops.basicoperators import BlockDiag, Diagonal, HStack, Restriction
from pylops.signalprocessing.sliding2d import _slidingsteps
+from pylops.utils._internal import _value_or_sized_to_tuple
+from pylops.utils.backend import (
+ get_array_module,
+ get_sliding_window_view,
+ to_cupy_conditional,
+)
+from pylops.utils.decorators import reshaped
from pylops.utils.tapers import tapernd
from pylops.utils.typing import InputDimsLike, NDArray
@@ -22,6 +28,7 @@ def patch3d_design(
nwin: Tuple[int, int, int],
nover: Tuple[int, int, int],
nop: Tuple[int, int, int],
+ verb: bool = True,
) -> Tuple[
Tuple[int, int, int],
Tuple[int, int, int],
@@ -45,6 +52,9 @@ def patch3d_design(
Number of samples of overlapping part of window.
nop : :obj:`tuple`
Size of model in the transformed domain.
+ verb : :obj:`bool`, optional
+ Verbosity flag. If ``verb==True``, print the data
+ and model windows start-end indices
Returns
-------
@@ -84,39 +94,30 @@ def patch3d_design(
)
# print information about patching
- logging.warning("%d-%d-%d windows required...", nwins0, nwins1, nwins2)
- logging.warning(
- "data wins - start:%s, end:%s / start:%s, end:%s / start:%s, end:%s",
- dwin0_ins,
- dwin0_ends,
- dwin1_ins,
- dwin1_ends,
- dwin2_ins,
- dwin2_ends,
- )
- logging.warning(
- "model wins - start:%s, end:%s / start:%s, end:%s / start:%s, end:%s",
- mwin0_ins,
- mwin0_ends,
- mwin1_ins,
- mwin1_ends,
- mwin2_ins,
- mwin2_ends,
- )
+ if verb:
+ logging.warning("%d-%d-%d windows required...", nwins0, nwins1, nwins2)
+ logging.warning(
+ "data wins - start:%s, end:%s / start:%s, end:%s / start:%s, end:%s",
+ dwin0_ins,
+ dwin0_ends,
+ dwin1_ins,
+ dwin1_ends,
+ dwin2_ins,
+ dwin2_ends,
+ )
+ logging.warning(
+ "model wins - start:%s, end:%s / start:%s, end:%s / start:%s, end:%s",
+ mwin0_ins,
+ mwin0_ends,
+ mwin1_ins,
+ mwin1_ends,
+ mwin2_ins,
+ mwin2_ends,
+ )
return nwins, dims, mwins_inends, dwins_inends
-def Patch3D(
- Op,
- dims: InputDimsLike,
- dimsd: InputDimsLike,
- nwin: Tuple[int, int, int],
- nover: Tuple[int, int, int],
- nop: Tuple[int, int, int],
- tapertype: str = "hanning",
- scalings: Optional[Sequence[float]] = None,
- name: str = "P",
-) -> LinearOperator:
+class Patch3D(LinearOperator):
"""3D Patch transform operator.
Apply a transform operator ``Op`` repeatedly to patches of the model
@@ -160,6 +161,11 @@ def Patch3D(
Size of model in the transformed domain
tapertype : :obj:`str`, optional
Type of taper (``hanning``, ``cosine``, ``cosinesquare`` or ``None``)
+ savetaper : :obj:`bool`, optional
+ .. versionadded:: 2.3.0
+
+ Save all tapers and apply them in one go (``True``) or save unique tapers and apply them one by one (``False``).
+ The first option is more computationally efficient, whilst the second is more memory efficient.
scalings : :obj:`tuple` or :obj:`list`, optional
Set of scalings to apply to each patch. If ``None``, no scale will be
applied
@@ -185,272 +191,578 @@ def Patch3D(
Patch2D: 2D Patching transform operator.
"""
- # data windows
- dwin0_ins, dwin0_ends = _slidingsteps(dimsd[0], nwin[0], nover[0])
- dwin1_ins, dwin1_ends = _slidingsteps(dimsd[1], nwin[1], nover[1])
- dwin2_ins, dwin2_ends = _slidingsteps(dimsd[2], nwin[2], nover[2])
- nwins0 = len(dwin0_ins)
- nwins1 = len(dwin1_ins)
- nwins2 = len(dwin2_ins)
- nwins = nwins0 * nwins1 * nwins2
-
- # check patching
- if (
- nwins0 * nop[0] != dims[0]
- or nwins1 * nop[1] != dims[1]
- or nwins2 * nop[2] != dims[2]
- ):
- raise ValueError(
- f"Model shape (dims={dims}) is not consistent with chosen "
- f"number of windows. Run patch3d_design to identify the "
- f"correct number of windows for the current "
- "model size..."
+
+ def __init__(
+ self,
+ Op: LinearOperator,
+ dims: InputDimsLike,
+ dimsd: InputDimsLike,
+ nwin: Tuple[int, int, int],
+ nover: Tuple[int, int, int],
+ nop: Tuple[int, int, int],
+ tapertype: str = "hanning",
+ savetaper: bool = True,
+ scalings: Optional[Sequence[float]] = None,
+ name: str = "P",
+ ) -> None:
+
+ dims: Tuple[int, ...] = _value_or_sized_to_tuple(dims)
+ dimsd: Tuple[int, ...] = _value_or_sized_to_tuple(dimsd)
+
+ # data windows
+ dwin0_ins, dwin0_ends = _slidingsteps(dimsd[0], nwin[0], nover[0])
+ dwin1_ins, dwin1_ends = _slidingsteps(dimsd[1], nwin[1], nover[1])
+ dwin2_ins, dwin2_ends = _slidingsteps(dimsd[2], nwin[2], nover[2])
+ self.dwins_inends = (
+ (dwin0_ins, dwin0_ends),
+ (dwin1_ins, dwin1_ends),
+ (dwin2_ins, dwin2_ends),
)
+ nwins0 = len(dwin0_ins)
+ nwins1 = len(dwin1_ins)
+ nwins2 = len(dwin2_ins)
+ nwins = nwins0 * nwins1 * nwins2
+ self.nwin = nwin
+ self.nover = nover
- # create tapers
- if tapertype is not None:
- tap = tapernd(nwin, nover, tapertype=tapertype).astype(Op.dtype)
- taps = {itap: tap for itap in range(nwins)}
- # 1, sides
- # topmost tapers
- taptop = tap.copy()
- taptop[: nover[0]] = tap[nwin[0] // 2]
- for itap in range(0, nwins1 * nwins2):
- taps[itap] = taptop
- # bottommost tapers
- tapbottom = tap.copy()
- tapbottom[-nover[0] :] = tap[nwin[0] // 2]
- for itap in range(nwins - nwins1 * nwins2, nwins):
- taps[itap] = tapbottom
- # frontmost tapers
- tapfront = tap.copy()
- tapfront[:, :, : nover[2]] = tap[:, :, nwin[2] // 2][:, :, np.newaxis]
- for itap in range(0, nwins, nwins2):
- taps[itap] = tapfront
- # backmost tapers
- tapback = tap.copy()
- tapback[:, :, -nover[2] :] = tap[:, :, nwin[2] // 2][:, :, np.newaxis]
- for itap in range(nwins2 - 1, nwins, nwins2):
- taps[itap] = tapback
- # leftmost tapers
- tapleft = tap.copy()
- tapleft[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis, :]
- for itap in range(0, nwins, nwins1 * nwins2):
- for i in range(nwins2):
- taps[itap + i] = tapleft
- # rightmost tapers
- tapright = tap.copy()
- tapright[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis, :]
- for itap in range(nwins2 * (nwins1 - 1), nwins, nwins2 * nwins1):
- for i in range(nwins2):
- taps[itap + i] = tapright
- # 2. pillars
- # topleftmost tapers
- taplefttop = tap.copy()
- taplefttop[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis, :]
- taplefttop[: nover[0]] = taplefttop[nwin[0] // 2]
- for itap in range(nwins2):
- taps[itap] = taplefttop
- # toprightmost tapers
- taprighttop = tap.copy()
- taprighttop[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis, :]
- taprighttop[: nover[0]] = taprighttop[nwin[0] // 2]
- for itap in range(nwins2):
- taps[nwins2 * (nwins1 - 1) + itap] = taprighttop
- # topfrontmost tapers
- tapfronttop = tap.copy()
- tapfronttop[:, :, : nover[2]] = tap[:, :, nwin[2] // 2][:, :, np.newaxis]
- tapfronttop[: nover[0]] = tapfronttop[nwin[0] // 2]
- for itap in range(0, nwins1 * nwins2, nwins2):
- taps[itap] = tapfronttop
- # topbackmost tapers
- tapbacktop = tap.copy()
- tapbacktop[:, :, -nover[2] :] = tap[:, :, nwin[2] // 2][:, :, np.newaxis]
- tapbacktop[: nover[0]] = tapbacktop[nwin[0] // 2]
- for itap in range(nwins2 - 1, nwins1 * nwins2, nwins2):
- taps[itap] = tapbacktop
- # bottomleftmost tapers
- tapleftbottom = tap.copy()
- tapleftbottom[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis, :]
- tapleftbottom[-nover[0] :] = tapleftbottom[nwin[0] // 2]
- for itap in range(nwins2):
- taps[(nwins0 - 1) * nwins1 * nwins2 + itap] = tapleftbottom
- # bottomrightmost tapers
- taprightbottom = tap.copy()
- taprightbottom[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis, :]
- taprightbottom[-nover[0] :] = taprightbottom[nwin[0] // 2]
- for itap in range(nwins2):
- taps[
- (nwins0 - 1) * nwins1 * nwins2 + (nwins1 - 1) * nwins2 + itap
- ] = taprightbottom
- # bottomfrontmost tapers
- tapfrontbottom = tap.copy()
- tapfrontbottom[:, :, : nover[2]] = tap[:, :, nwin[2] // 2][:, :, np.newaxis]
- tapfrontbottom[-nover[0] :] = tapfrontbottom[nwin[0] // 2]
- for itap in range(0, nwins1 * nwins2, nwins2):
- taps[(nwins0 - 1) * nwins1 * nwins2 + itap] = tapfrontbottom
- # bottombackmost tapers
- tapbackbottom = tap.copy()
- tapbackbottom[:, :, -nover[2] :] = tap[:, :, nwin[2] // 2][:, :, np.newaxis]
- tapbackbottom[-nover[0] :] = tapbackbottom[nwin[0] // 2]
- for itap in range(0, nwins1 * nwins2, nwins2):
- taps[(nwins0 - 1) * nwins1 * nwins2 + nwins2 + itap - 1] = tapbackbottom
- # leftfrontmost tapers
- tapleftfront = tap.copy()
- tapleftfront[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis, :]
- tapleftfront[:, :, : nover[2]] = tapleftfront[:, :, nwin[2] // 2][
- :, :, np.newaxis
- ]
- for itap in range(0, nwins, nwins1 * nwins2):
- taps[itap] = tapleftfront
- # rightfrontmost tapers
- taprightfront = tap.copy()
- taprightfront[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis, :]
- taprightfront[:, :, : nover[2]] = taprightfront[:, :, nwin[2] // 2][
- :, :, np.newaxis
- ]
- for itap in range(0, nwins, nwins1 * nwins2):
- taps[(nwins1 - 1) * nwins2 + itap] = taprightfront
- # leftbackmost tapers
- tapleftback = tap.copy()
- tapleftback[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis, :]
- tapleftback[:, :, -nover[2] :] = tapleftback[:, :, nwin[2] // 2][
- :, :, np.newaxis
- ]
- for itap in range(0, nwins, nwins1 * nwins2):
- taps[nwins2 + itap - 1] = tapleftback
- # rightbackmost tapers
- taprightback = tap.copy()
- taprightback[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis, :]
- taprightback[:, :, -nover[2] :] = taprightback[:, :, nwin[2] // 2][
- :, :, np.newaxis
- ]
- for itap in range(0, nwins, nwins1 * nwins2):
- taps[(nwins1 - 1) * nwins2 + nwins2 + itap - 1] = taprightback
- # 3. corners
- # lefttopfrontcorner taper
- taplefttop = tap.copy()
- taplefttop[: nover[0]] = tap[nwin[0] // 2]
- taplefttop[:, : nover[1]] = taplefttop[:, nwin[1] // 2][:, np.newaxis, :]
- taplefttop[:, :, : nover[2]] = taplefttop[:, :, nwin[2] // 2][:, :, np.newaxis]
- taps[0] = taplefttop
- # lefttopbackcorner taper
- taplefttop = tap.copy()
- taplefttop[: nover[0]] = tap[nwin[0] // 2]
- taplefttop[:, : nover[1]] = taplefttop[:, nwin[1] // 2][:, np.newaxis, :]
- taplefttop[:, :, -nover[2] :] = taplefttop[:, :, nwin[2] // 2][:, :, np.newaxis]
- taps[nwins2 - 1] = taplefttop
- # righttopfrontcorner taper
- taprighttop = tap.copy()
- taprighttop[: nover[0]] = tap[nwin[0] // 2]
- taprighttop[:, -nover[1] :] = taprighttop[:, nwin[1] // 2][:, np.newaxis, :]
- taprighttop[:, :, : nover[2]] = taprighttop[:, :, nwin[2] // 2][
- :, :, np.newaxis
- ]
- taps[(nwins1 - 1) * nwins2] = taprighttop
- # righttopbackcorner taper
- taprighttop = tap.copy()
- taprighttop[: nover[0]] = tap[nwin[0] // 2]
- taprighttop[:, -nover[1] :] = taprighttop[:, nwin[1] // 2][:, np.newaxis, :]
- taprighttop[:, :, -nover[2] :] = taprighttop[:, :, nwin[2] // 2][
- :, :, np.newaxis
- ]
- taps[(nwins1 - 1) * nwins2 + nwins2 - 1] = taprighttop
- # leftbottomfrontcorner taper
- tapleftbottom = tap.copy()
- tapleftbottom[-nover[0] :] = tap[nwin[0] // 2]
- tapleftbottom[:, : nover[1]] = tapleftbottom[:, nwin[1] // 2][:, np.newaxis, :]
- tapleftbottom[:, :, : nover[2]] = tapleftbottom[:, :, nwin[2] // 2][
- :, :, np.newaxis
- ]
- taps[(nwins0 - 1) * nwins1 * nwins2] = tapleftbottom
- # leftbottombackcorner taper
- tapleftbottom = tap.copy()
- tapleftbottom[-nover[0] :] = tap[nwin[0] // 2]
- tapleftbottom[:, : nover[1]] = tapleftbottom[:, nwin[1] // 2][:, np.newaxis, :]
- tapleftbottom[:, :, -nover[2] :] = tapleftbottom[:, :, nwin[2] // 2][
- :, :, np.newaxis
- ]
- taps[(nwins0 - 1) * nwins1 * nwins2 + nwins2 - 1] = tapleftbottom
- # rightbottomfrontcorner taper
- taprightbottom = tap.copy()
- taprightbottom[-nover[0] :] = tap[nwin[0] // 2]
- taprightbottom[:, -nover[1] :] = taprightbottom[:, nwin[1] // 2][
- :, np.newaxis, :
- ]
- taprightbottom[:, :, : nover[2]] = taprightbottom[:, :, nwin[2] // 2][
- :, :, np.newaxis
- ]
- taps[(nwins0 - 1) * nwins1 * nwins2 + (nwins1 - 1) * nwins2] = taprightbottom
- # rightbottombackcorner taper
- taprightbottom = tap.copy()
- taprightbottom[-nover[0] :] = tap[nwin[0] // 2]
- taprightbottom[:, -nover[1] :] = taprightbottom[:, nwin[1] // 2][
- :, np.newaxis, :
- ]
- taprightbottom[:, :, -nover[2] :] = taprightbottom[:, :, nwin[2] // 2][
- :, :, np.newaxis
- ]
- taps[
- (nwins0 - 1) * nwins1 * nwins2 + (nwins1 - 1) * nwins2 + nwins2 - 1
- ] = taprightbottom
-
- # define scalings
- if scalings is None:
- scalings = [1.0] * nwins
-
- # transform to apply
- if tapertype is None:
- OOp = BlockDiag([scalings[itap] * Op for itap in range(nwins)])
- else:
- OOp = BlockDiag(
- [
- scalings[itap] * Diagonal(taps[itap].ravel(), dtype=Op.dtype) * Op
- for itap in range(nwins)
+ # check patching
+ if (
+ nwins0 * nop[0] != dims[0]
+ or nwins1 * nop[1] != dims[1]
+ or nwins2 * nop[2] != dims[2]
+ ):
+ raise ValueError(
+ f"Model shape (dims={dims}) is not consistent with chosen "
+ f"number of windows. Run patch3d_design to identify the "
+ f"correct number of windows for the current "
+ "model size..."
+ )
+
+ # create tapers
+ self.tapertype = tapertype
+ self.savetaper = savetaper
+ if tapertype is not None:
+ tap = tapernd(nwin, nover, tapertype=tapertype).astype(Op.dtype)
+ # 1, sides
+ # topmost tapers
+ taptop = tap.copy()
+ taptop[: nover[0]] = tap[nwin[0] // 2]
+ # bottommost tapers
+ tapbottom = tap.copy()
+ tapbottom[-nover[0] :] = tap[nwin[0] // 2]
+ # frontmost tapers
+ tapfront = tap.copy()
+ tapfront[:, :, : nover[2]] = tap[:, :, nwin[2] // 2][:, :, np.newaxis]
+ # backmost tapers
+ tapback = tap.copy()
+ tapback[:, :, -nover[2] :] = tap[:, :, nwin[2] // 2][:, :, np.newaxis]
+ # leftmost tapers
+ tapleft = tap.copy()
+ tapleft[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis, :]
+ # rightmost tapers
+ tapright = tap.copy()
+ tapright[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis, :]
+
+ # 2. pillars
+ # topleftmost tapers
+ taplefttop = tap.copy()
+ taplefttop[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis, :]
+ taplefttop[: nover[0]] = taplefttop[nwin[0] // 2]
+ # toprightmost tapers
+ taprighttop = tap.copy()
+ taprighttop[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis, :]
+ taprighttop[: nover[0]] = taprighttop[nwin[0] // 2]
+ # topfrontmost tapers
+ tapfronttop = tap.copy()
+ tapfronttop[:, :, : nover[2]] = tap[:, :, nwin[2] // 2][:, :, np.newaxis]
+ tapfronttop[: nover[0]] = tapfronttop[nwin[0] // 2]
+ # topbackmost tapers
+ tapbacktop = tap.copy()
+ tapbacktop[:, :, -nover[2] :] = tap[:, :, nwin[2] // 2][:, :, np.newaxis]
+ tapbacktop[: nover[0]] = tapbacktop[nwin[0] // 2]
+ # bottomleftmost tapers
+ tapleftbottom = tap.copy()
+ tapleftbottom[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis, :]
+ tapleftbottom[-nover[0] :] = tapleftbottom[nwin[0] // 2]
+ # bottomrightmost tapers
+ taprightbottom = tap.copy()
+ taprightbottom[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis, :]
+ taprightbottom[-nover[0] :] = taprightbottom[nwin[0] // 2]
+ # bottomfrontmost tapers
+ tapfrontbottom = tap.copy()
+ tapfrontbottom[:, :, : nover[2]] = tap[:, :, nwin[2] // 2][:, :, np.newaxis]
+ tapfrontbottom[-nover[0] :] = tapfrontbottom[nwin[0] // 2]
+ # bottombackmost tapers
+ tapbackbottom = tap.copy()
+ tapbackbottom[:, :, -nover[2] :] = tap[:, :, nwin[2] // 2][:, :, np.newaxis]
+ tapbackbottom[-nover[0] :] = tapbackbottom[nwin[0] // 2]
+ # leftfrontmost tapers
+ tapleftfront = tap.copy()
+ tapleftfront[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis, :]
+ tapleftfront[:, :, : nover[2]] = tapleftfront[:, :, nwin[2] // 2][
+ :, :, np.newaxis
+ ]
+ # rightfrontmost tapers
+ taprightfront = tap.copy()
+ taprightfront[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis, :]
+ taprightfront[:, :, : nover[2]] = taprightfront[:, :, nwin[2] // 2][
+ :, :, np.newaxis
]
+ # leftbackmost tapers
+ tapleftback = tap.copy()
+ tapleftback[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis, :]
+ tapleftback[:, :, -nover[2] :] = tapleftback[:, :, nwin[2] // 2][
+ :, :, np.newaxis
+ ]
+ # rightbackmost tapers
+ taprightback = tap.copy()
+ taprightback[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis, :]
+ taprightback[:, :, -nover[2] :] = taprightback[:, :, nwin[2] // 2][
+ :, :, np.newaxis
+ ]
+
+ # 3. corners
+ # lefttopfrontcorner taper
+ taplefttopfront = tap.copy()
+ taplefttopfront[: nover[0]] = tap[nwin[0] // 2]
+ taplefttopfront[:, : nover[1]] = taplefttopfront[:, nwin[1] // 2][
+ :, np.newaxis, :
+ ]
+ taplefttopfront[:, :, : nover[2]] = taplefttopfront[:, :, nwin[2] // 2][
+ :, :, np.newaxis
+ ]
+ # lefttopbackcorner taper
+ taplefttopback = tap.copy()
+ taplefttopback[: nover[0]] = tap[nwin[0] // 2]
+ taplefttopback[:, : nover[1]] = taplefttopback[:, nwin[1] // 2][
+ :, np.newaxis, :
+ ]
+ taplefttopback[:, :, -nover[2] :] = taplefttopback[:, :, nwin[2] // 2][
+ :, :, np.newaxis
+ ]
+ # righttopfrontcorner taper
+ taprighttopfront = tap.copy()
+ taprighttopfront[: nover[0]] = tap[nwin[0] // 2]
+ taprighttopfront[:, -nover[1] :] = taprighttopfront[:, nwin[1] // 2][
+ :, np.newaxis, :
+ ]
+ taprighttopfront[:, :, : nover[2]] = taprighttopfront[:, :, nwin[2] // 2][
+ :, :, np.newaxis
+ ]
+ # righttopbackcorner taper
+ taprighttopback = tap.copy()
+ taprighttopback[: nover[0]] = tap[nwin[0] // 2]
+ taprighttopback[:, -nover[1] :] = taprighttopback[:, nwin[1] // 2][
+ :, np.newaxis, :
+ ]
+ taprighttopback[:, :, -nover[2] :] = taprighttopback[:, :, nwin[2] // 2][
+ :, :, np.newaxis
+ ]
+ # leftbottomfrontcorner taper
+ tapleftbottomfront = tap.copy()
+ tapleftbottomfront[-nover[0] :] = tap[nwin[0] // 2]
+ tapleftbottomfront[:, : nover[1]] = tapleftbottomfront[:, nwin[1] // 2][
+ :, np.newaxis, :
+ ]
+ tapleftbottomfront[:, :, : nover[2]] = tapleftbottomfront[
+ :, :, nwin[2] // 2
+ ][:, :, np.newaxis]
+ # leftbottombackcorner taper
+ tapleftbottomback = tap.copy()
+ tapleftbottomback[-nover[0] :] = tap[nwin[0] // 2]
+ tapleftbottomback[:, : nover[1]] = tapleftbottomback[:, nwin[1] // 2][
+ :, np.newaxis, :
+ ]
+ tapleftbottomback[:, :, -nover[2] :] = tapleftbottomback[
+ :, :, nwin[2] // 2
+ ][:, :, np.newaxis]
+ # rightbottomfrontcorner taper
+ taprightbottomfront = tap.copy()
+ taprightbottomfront[-nover[0] :] = tap[nwin[0] // 2]
+ taprightbottomfront[:, -nover[1] :] = taprightbottomfront[:, nwin[1] // 2][
+ :, np.newaxis, :
+ ]
+ taprightbottomfront[:, :, : nover[2]] = taprightbottomfront[
+ :, :, nwin[2] // 2
+ ][:, :, np.newaxis]
+ # rightbottombackcorner taper
+ taprightbottomback = tap.copy()
+ taprightbottomback[-nover[0] :] = tap[nwin[0] // 2]
+ taprightbottomback[:, -nover[1] :] = taprightbottomback[:, nwin[1] // 2][
+ :, np.newaxis, :
+ ]
+ taprightbottomback[:, :, -nover[2] :] = taprightbottomback[
+ :, :, nwin[2] // 2
+ ][:, :, np.newaxis]
+ if self.savetaper:
+ taps = [
+ tap,
+ ] * nwins
+ for itap in range(0, nwins1 * nwins2):
+ taps[itap] = taptop
+ for itap in range(nwins - nwins1 * nwins2, nwins):
+ taps[itap] = tapbottom
+ for itap in range(0, nwins, nwins2):
+ taps[itap] = tapfront
+ for itap in range(nwins2 - 1, nwins, nwins2):
+ taps[itap] = tapback
+ for itap in range(0, nwins, nwins1 * nwins2):
+ for i in range(nwins2):
+ taps[itap + i] = tapleft
+ for itap in range(nwins2 * (nwins1 - 1), nwins, nwins2 * nwins1):
+ for i in range(nwins2):
+ taps[itap + i] = tapright
+ for itap in range(nwins2):
+ taps[itap] = taplefttop
+ for itap in range(nwins2):
+ taps[nwins2 * (nwins1 - 1) + itap] = taprighttop
+ for itap in range(0, nwins1 * nwins2, nwins2):
+ taps[itap] = tapfronttop
+ for itap in range(nwins2 - 1, nwins1 * nwins2, nwins2):
+ taps[itap] = tapbacktop
+ for itap in range(nwins2):
+ taps[(nwins0 - 1) * nwins1 * nwins2 + itap] = tapleftbottom
+ for itap in range(nwins2):
+ taps[
+ (nwins0 - 1) * nwins1 * nwins2 + (nwins1 - 1) * nwins2 + itap
+ ] = taprightbottom
+ for itap in range(0, nwins1 * nwins2, nwins2):
+ taps[(nwins0 - 1) * nwins1 * nwins2 + itap] = tapfrontbottom
+ for itap in range(0, nwins1 * nwins2, nwins2):
+ taps[
+ (nwins0 - 1) * nwins1 * nwins2 + nwins2 + itap - 1
+ ] = tapbackbottom
+ for itap in range(0, nwins, nwins1 * nwins2):
+ taps[itap] = tapleftfront
+ for itap in range(0, nwins, nwins1 * nwins2):
+ taps[(nwins1 - 1) * nwins2 + itap] = taprightfront
+ for itap in range(0, nwins, nwins1 * nwins2):
+ taps[nwins2 + itap - 1] = tapleftback
+ for itap in range(0, nwins, nwins1 * nwins2):
+ taps[(nwins1 - 1) * nwins2 + nwins2 + itap - 1] = taprightback
+ taps[0] = taplefttopfront
+ taps[nwins2 - 1] = taplefttopback
+ taps[(nwins1 - 1) * nwins2] = taprighttopfront
+ taps[(nwins1 - 1) * nwins2 + nwins2 - 1] = taprighttopback
+ taps[(nwins0 - 1) * nwins1 * nwins2] = tapleftbottomfront
+ taps[(nwins0 - 1) * nwins1 * nwins2 + nwins2 - 1] = tapleftbottomback
+ taps[
+ (nwins0 - 1) * nwins1 * nwins2 + (nwins1 - 1) * nwins2
+ ] = taprightbottomfront
+ taps[
+ (nwins0 - 1) * nwins1 * nwins2 + (nwins1 - 1) * nwins2 + nwins2 - 1
+ ] = taprightbottomback
+ self.taps = np.vstack(taps).reshape(
+ nwins0, nwins1, nwins2, nwin[0], nwin[1], nwin[2]
+ )
+ else:
+ self.taps = np.zeros(
+ (3, 3, 3, nwin[0], nwin[1], nwin[2]), dtype=Op.dtype
+ )
+ self.taps[0, 0, 0] = taplefttopfront
+ self.taps[0, 0, 1] = taplefttop
+ self.taps[0, 0, -1] = taplefttopback
+ self.taps[0, 1, 0] = tapfronttop
+ self.taps[0, 1, 1] = taptop
+ self.taps[0, 1, -1] = tapbacktop
+ self.taps[0, -1, 0] = taprighttopfront
+ self.taps[0, -1, 1] = taprighttop
+ self.taps[0, -1, -1] = taprighttopback
+
+ self.taps[1, 0, 0] = tapleftfront
+ self.taps[1, 0, 1] = tapleft
+ self.taps[1, 0, -1] = tapleftback
+ self.taps[1, 1, 0] = tapfront
+ self.taps[1, 1, 1] = tap
+ self.taps[1, 1, -1] = tapback
+ self.taps[1, -1, 0] = taprightfront
+ self.taps[1, -1, 1] = tapright
+ self.taps[1, -1, -1] = taprightback
+
+ self.taps[-1, 0, 0] = tapleftbottomfront
+ self.taps[-1, 0, 1] = tapleftbottom
+ self.taps[-1, 0, -1] = tapleftbottomback
+ self.taps[-1, 1, 0] = tapfrontbottom
+ self.taps[-1, 1, 1] = tapbottom
+ self.taps[-1, 1, -1] = tapbackbottom
+ self.taps[-1, -1, 0] = taprightbottomfront
+ self.taps[-1, -1, 1] = taprightbottom
+ self.taps[-1, -1, -1] = taprightbottomback
+
+ # define scalings
+ self.scalings = [1.0] * nwins if scalings is None else scalings
+
+ # check if operator is applied to all windows simultaneously
+ self.simOp = False
+ if Op.shape[1] == np.prod(dims):
+ self.simOp = True
+ self.Op = Op
+
+ super().__init__(
+ dtype=Op.dtype,
+ dims=(
+ nwins0,
+ nwins1,
+ nwins2,
+ int(dims[0] // nwins0),
+ int(dims[1] // nwins1),
+ int(dims[2] // nwins2),
+ ),
+ dimsd=dimsd,
+ clinear=False,
+ name=name,
)
- hstack2 = HStack(
- [
- Restriction(
- (nwin[0], nwin[1], dimsd[2]),
- range(win_in, win_end),
- axis=2,
- dtype=Op.dtype,
- ).H
- for win_in, win_end in zip(dwin2_ins, dwin2_ends)
- ]
- )
- combining2 = BlockDiag([hstack2] * (nwins1 * nwins0))
-
- hstack1 = HStack(
- [
- Restriction(
- (nwin[0], dimsd[1], dimsd[2]),
- range(win_in, win_end),
- axis=1,
- dtype=Op.dtype,
- ).H
- for win_in, win_end in zip(dwin1_ins, dwin1_ends)
- ]
- )
- combining1 = BlockDiag([hstack1] * nwins0)
+ self._register_multiplications(self.savetaper)
+
+ def _apply_taper(self, ywins, iwin0, iwin1, iwin2):
+ if iwin0 == 0 and iwin1 == 0 and iwin2 == 0:
+ ywins[0, 0, 0] = self.taps[0, 0, 0] * ywins[0, 0, 0]
+ elif iwin0 == 0 and iwin1 == 0 and iwin2 == self.dims[2] - 1:
+ ywins[0, 0, -1] = self.taps[0, 0, -1] * ywins[0, 0, -1]
+ elif iwin0 == 0 and iwin1 == self.dims[1] - 1 and iwin2 == self.dims[2] - 1:
+ ywins[0, -1, -1] = self.taps[0, -1, -1] * ywins[0, -1, -1]
+ elif iwin0 == 0 and iwin1 == self.dims[1] - 1 and iwin2 == 0:
+ ywins[0, -1, 0] = self.taps[0, -1, 0] * ywins[0, -1, 0]
+ elif iwin0 == 0 and iwin1 == 0:
+ ywins[0, 0, iwin2] = self.taps[0, 0, 1] * ywins[0, 0, iwin2]
+ elif iwin0 == 0 and iwin1 == self.dims[1] - 1:
+ ywins[0, -1, iwin2] = self.taps[0, -1, 1] * ywins[0, -1, iwin2]
+ elif iwin0 == 0 and iwin2 == 0:
+ ywins[0, iwin1, 0] = self.taps[0, 1, 0] * ywins[0, iwin1, 0]
+ elif iwin0 == 0 and iwin2 == self.dims[2] - 1:
+ ywins[0, iwin1, -1] = self.taps[0, 1, -1] * ywins[0, iwin1, -1]
+ elif iwin0 == 0:
+ ywins[0, iwin1, iwin2] = self.taps[0, 1, 1] * ywins[0, iwin1, iwin2]
+
+ elif iwin0 == self.dims[0] - 1 and iwin1 == 0 and iwin2 == 0:
+ ywins[-1, 0, 0] = self.taps[-1, 0, 0] * ywins[-1, 0, 0]
+ elif iwin0 == self.dims[0] - 1 and iwin1 == 0 and iwin2 == self.dims[2] - 1:
+ ywins[-1, 0, -1] = self.taps[-1, 0, -1] * ywins[-1, 0, -1]
+ elif (
+ iwin0 == self.dims[0] - 1
+ and iwin1 == self.dims[1] - 1
+ and iwin2 == self.dims[2] - 1
+ ):
+ ywins[-1, -1, -1] = self.taps[-1, -1, -1] * ywins[-1, -1, -1]
+ elif iwin0 == self.dims[0] - 1 and iwin1 == self.dims[1] - 1 and iwin2 == 0:
+ ywins[-1, -1, 0] = self.taps[-1, -1, 0] * ywins[-1, -1, 0]
+ elif iwin0 == self.dims[0] - 1 and iwin1 == 0:
+ ywins[-1, 0, iwin2] = self.taps[-1, 0, 1] * ywins[-1, 0, iwin2]
+ elif iwin0 == self.dims[0] - 1 and iwin1 == self.dims[1] - 1:
+ ywins[-1, -1, iwin2] = self.taps[-1, -1, 1] * ywins[-1, -1, iwin2]
+ elif iwin0 == self.dims[0] - 1 and iwin2 == 0:
+ ywins[-1, iwin1, 0] = self.taps[-1, 1, 0] * ywins[-1, iwin1, 0]
+ elif iwin0 == self.dims[0] - 1 and iwin2 == self.dims[2] - 1:
+ ywins[-1, iwin1, -1] = self.taps[-1, 1, -1] * ywins[-1, iwin1, -1]
+ elif iwin0 == self.dims[0] - 1:
+ ywins[-1, iwin1, iwin2] = self.taps[-1, 1, 1] * ywins[-1, iwin1, iwin2]
- combining0 = HStack(
- [
- Restriction(dimsd, range(win_in, win_end), axis=0, dtype=Op.dtype).H
- for win_in, win_end in zip(dwin0_ins, dwin0_ends)
+ elif iwin1 == 0 and iwin2 == 0:
+ ywins[iwin0, 0, 0] = self.taps[1, 0, 0] * ywins[iwin0, 0, 0]
+ elif iwin1 == 0 and iwin2 == self.dims[2] - 1:
+ ywins[iwin0, 0, -1] = self.taps[1, 0, -1] * ywins[iwin0, 0, -1]
+ elif iwin1 == self.dims[1] - 1 and iwin2 == self.dims[2] - 1:
+ ywins[iwin0, -1, -1] = self.taps[1, -1, -1] * ywins[iwin0, -1, -1]
+ elif iwin1 == self.dims[1] - 1 and iwin2 == 0:
+ ywins[iwin0, -1, 0] = self.taps[1, -1, 0] * ywins[iwin0, -1, 0]
+ elif iwin1 == 0:
+ ywins[iwin0, 0, iwin2] = self.taps[1, 0, 1] * ywins[iwin0, 0, iwin2]
+ elif iwin1 == self.dims[1] - 1:
+ ywins[iwin0, -1, iwin2] = self.taps[1, -1, 1] * ywins[iwin0, -1, iwin2]
+ elif iwin2 == 0:
+ ywins[iwin0, iwin1, 0] = self.taps[1, 1, 0] * ywins[iwin0, iwin1, 0]
+ elif iwin2 == self.dims[2] - 1:
+ ywins[iwin0, iwin1, -1] = self.taps[1, 1, -1] * ywins[iwin0, iwin1, -1]
+ else:
+ ywins[iwin0, iwin1, iwin2] = self.taps[1, 1, 1] * ywins[iwin0, iwin1, iwin2]
+ return ywins
+
+ @reshaped
+ def _matvec_savetaper(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
+ if self.tapertype is not None:
+ self.taps = to_cupy_conditional(x, self.taps)
+ y = ncp.zeros(self.dimsd, dtype=self.dtype)
+ if self.simOp:
+ x = self.Op @ x
+ for iwin0 in range(self.dims[0]):
+ for iwin1 in range(self.dims[1]):
+ for iwin2 in range(self.dims[2]):
+ if self.simOp:
+ xx = x[iwin0, iwin1, iwin2].reshape(self.nwin)
+ else:
+ xx = self.Op.matvec(x[iwin0, iwin1, iwin2].ravel()).reshape(
+ self.nwin
+ )
+ if self.tapertype is not None:
+ xxwin = self.taps[iwin0, iwin1, iwin2] * xx
+ else:
+ xxwin = xx
+
+ y[
+ self.dwins_inends[0][0][iwin0] : self.dwins_inends[0][1][iwin0],
+ self.dwins_inends[1][0][iwin1] : self.dwins_inends[1][1][iwin1],
+ self.dwins_inends[2][0][iwin2] : self.dwins_inends[2][1][iwin2],
+ ] += xxwin
+ return y
+
+ @reshaped
+ def _rmatvec_savetaper(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
+ ncp_sliding_window_view = get_sliding_window_view(x)
+ if self.tapertype is not None:
+ self.taps = to_cupy_conditional(x, self.taps)
+ ywins = ncp_sliding_window_view(x, self.nwin)[
+ :: self.nwin[0] - self.nover[0],
+ :: self.nwin[1] - self.nover[1],
+ :: self.nwin[2] - self.nover[2],
]
- )
+ if self.tapertype is not None:
+ ywins = ywins * self.taps
+ if self.simOp:
+ y = self.Op.H @ ywins
+ else:
+ y = ncp.zeros(self.dims, dtype=self.dtype)
+ for iwin0 in range(self.dims[0]):
+ for iwin1 in range(self.dims[1]):
+ for iwin2 in range(self.dims[2]):
+ y[iwin0, iwin1, iwin2] = self.Op.rmatvec(
+ ywins[iwin0, iwin1, iwin2].ravel()
+ ).reshape(self.dims[3], self.dims[4], self.dims[5])
+ return y
+
+ @reshaped
+ def _matvec_nosavetaper(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
+ if self.tapertype is not None:
+ self.taps = to_cupy_conditional(x, self.taps)
+ y = ncp.zeros(self.dimsd, dtype=self.dtype)
+ if self.simOp:
+ x = self.Op @ x
+ for iwin0 in range(self.dims[0]):
+ for iwin1 in range(self.dims[1]):
+ for iwin2 in range(self.dims[2]):
+ if self.simOp:
+ xxwin = x[iwin0, iwin1, iwin2].reshape(self.nwin)
+ else:
+ xxwin = self.Op.matvec(x[iwin0, iwin1, iwin2].ravel()).reshape(
+ self.nwin
+ )
+ if self.tapertype is not None:
+ if iwin0 == 0 and iwin1 == 0 and iwin2 == 0:
+ xxwin = self.taps[0, 0, 0] * xxwin
+ elif iwin0 == 0 and iwin1 == 0 and iwin2 == self.dims[2] - 1:
+ xxwin = self.taps[0, 0, -1] * xxwin
+ elif (
+ iwin0 == 0
+ and iwin1 == self.dims[1] - 1
+ and iwin2 == self.dims[2] - 1
+ ):
+ xxwin = self.taps[0, -1, -1] * xxwin
+ elif iwin0 == 0 and iwin1 == self.dims[1] - 1 and iwin2 == 0:
+ xxwin = self.taps[0, -1, 0] * xxwin
+ elif iwin0 == 0 and iwin1 == 0:
+ xxwin = self.taps[0, 0, 1] * xxwin
+ elif iwin0 == 0 and iwin1 == self.dims[1] - 1:
+ xxwin = self.taps[0, -1, 1] * xxwin
+ elif iwin0 == 0 and iwin2 == 0:
+ xxwin = self.taps[0, 1, 0] * xxwin
+ elif iwin0 == 0 and iwin2 == self.dims[2] - 1:
+ xxwin = self.taps[0, 1, -1] * xxwin
+ elif iwin0 == 0:
+ xxwin = self.taps[0, 1, 1] * xxwin
+
+ elif iwin0 == self.dims[0] - 1 and iwin1 == 0 and iwin2 == 0:
+ xxwin = self.taps[-1, 0, 0] * xxwin
+ elif (
+ iwin0 == self.dims[0] - 1
+ and iwin1 == 0
+ and iwin2 == self.dims[2] - 1
+ ):
+ xxwin = self.taps[-1, 0, -1] * xxwin
+ elif (
+ iwin0 == self.dims[0] - 1
+ and iwin1 == self.dims[1] - 1
+ and iwin2 == self.dims[2] - 1
+ ):
+ xxwin = self.taps[-1, -1, -1] * xxwin
+ elif (
+ iwin0 == self.dims[0] - 1
+ and iwin1 == self.dims[1] - 1
+ and iwin2 == 0
+ ):
+ xxwin = self.taps[-1, -1, 0] * xxwin
+ elif iwin0 == self.dims[0] - 1 and iwin1 == 0:
+ xxwin = self.taps[-1, 0, 1] * xxwin
+ elif iwin0 == self.dims[0] - 1 and iwin1 == self.dims[1] - 1:
+ xxwin = self.taps[-1, -1, 1] * xxwin
+ elif iwin0 == self.dims[0] - 1 and iwin2 == 0:
+ xxwin = self.taps[-1, 1, 0] * xxwin
+ elif iwin0 == self.dims[0] - 1 and iwin2 == self.dims[2] - 1:
+ xxwin = self.taps[-1, 1, -1] * xxwin
+ elif iwin0 == self.dims[0] - 1:
+ xxwin = self.taps[-1, 1, 1] * xxwin
+
+ elif iwin1 == 0 and iwin2 == 0:
+ xxwin = self.taps[1, 0, 0] * xxwin
+ elif iwin1 == 0 and iwin2 == self.dims[2] - 1:
+ xxwin = self.taps[1, 0, -1] * xxwin
+ elif iwin1 == self.dims[1] - 1 and iwin2 == self.dims[2] - 1:
+ xxwin = self.taps[1, -1, -1] * xxwin
+ elif iwin1 == self.dims[1] - 1 and iwin2 == 0:
+ xxwin = self.taps[1, -1, 0] * xxwin
+ elif iwin1 == 0:
+ xxwin = self.taps[1, 0, 1] * xxwin
+ elif iwin1 == self.dims[1] - 1:
+ xxwin = self.taps[1, -1, 1] * xxwin
+ elif iwin2 == 0:
+ xxwin = self.taps[1, 1, 0] * xxwin
+ elif iwin2 == self.dims[2] - 1:
+ xxwin = self.taps[1, 1, -1] * xxwin
+ else:
+ xxwin = self.taps[1, 1, 1] * xxwin
+ y[
+ self.dwins_inends[0][0][iwin0] : self.dwins_inends[0][1][iwin0],
+ self.dwins_inends[1][0][iwin1] : self.dwins_inends[1][1][iwin1],
+ self.dwins_inends[2][0][iwin2] : self.dwins_inends[2][1][iwin2],
+ ] += xxwin
+ return y
+
+ @reshaped
+ def _rmatvec_nosavetaper(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
+ ncp_sliding_window_view = get_sliding_window_view(x)
+ if self.tapertype is not None:
+ self.taps = to_cupy_conditional(x, self.taps)
+ ywins = ncp_sliding_window_view(x, self.nwin)[
+ :: self.nwin[0] - self.nover[0],
+ :: self.nwin[1] - self.nover[1],
+ :: self.nwin[2] - self.nover[2],
+ ].copy()
+ if self.simOp:
+ if self.tapertype is not None:
+ for iwin0 in range(self.dims[0]):
+ for iwin1 in range(self.dims[1]):
+ for iwin2 in range(self.dims[2]):
+ ywins = self._apply_taper(ywins, iwin0, iwin1, iwin2)
+ y = self.Op.H @ ywins
+ else:
+ y = ncp.zeros(self.dims, dtype=self.dtype)
+ for iwin0 in range(self.dims[0]):
+ for iwin1 in range(self.dims[1]):
+ for iwin2 in range(self.dims[2]):
+ if self.tapertype is not None:
+ ywins = self._apply_taper(ywins, iwin0, iwin1, iwin2)
+ y[iwin0, iwin1, iwin2] = self.Op.rmatvec(
+ ywins[iwin0, iwin1, iwin2].ravel()
+ ).reshape(self.dims[3], self.dims[4], self.dims[5])
+ return y
- Pop = LinearOperator(combining0 * combining1 * combining2 * OOp)
- Pop.dims, Pop.dimsd = (
- nwins0,
- nwins1,
- nwins2,
- int(dims[0] // nwins0),
- int(dims[1] // nwins1),
- int(dims[2] // nwins2),
- ), dimsd
- Pop.name = name
- return Pop
+ def _register_multiplications(self, savetaper: bool) -> None:
+ if savetaper:
+ self._matvec = self._matvec_savetaper
+ self._rmatvec = self._rmatvec_savetaper
+ else:
+ self._matvec = self._matvec_nosavetaper
+ self._rmatvec = self._rmatvec_nosavetaper
diff --git a/pylops/signalprocessing/sliding1d.py b/pylops/signalprocessing/sliding1d.py
index 5cb8c213..1726615a 100644
--- a/pylops/signalprocessing/sliding1d.py
+++ b/pylops/signalprocessing/sliding1d.py
@@ -6,10 +6,17 @@
import logging
from typing import Tuple, Union
+import numpy as np
+
from pylops import LinearOperator
-from pylops.basicoperators import BlockDiag, Diagonal, HStack, Restriction
from pylops.signalprocessing.sliding2d import _slidingsteps
from pylops.utils._internal import _value_or_sized_to_tuple
+from pylops.utils.backend import (
+ get_array_module,
+ get_sliding_window_view,
+ to_cupy_conditional,
+)
+from pylops.utils.decorators import reshaped
from pylops.utils.tapers import taper
from pylops.utils.typing import InputDimsLike, NDArray
@@ -21,6 +28,7 @@ def sliding1d_design(
nwin: int,
nover: int,
nop: int,
+ verb: bool = True,
) -> Tuple[int, int, Tuple[NDArray, NDArray], Tuple[NDArray, NDArray]]:
"""Design Sliding1D operator
@@ -39,6 +47,9 @@ def sliding1d_design(
Number of samples of overlapping part of window.
nop : :obj:`tuple`
Size of model in the transformed domain.
+ verb : :obj:`bool`, optional
+ Verbosity flag. If ``verb==True``, print the data
+ and model windows start-end indices
Returns
-------
@@ -63,29 +74,22 @@ def sliding1d_design(
mwins_inends = (mwin_ins, mwin_ends)
# print information about patching
- logging.warning("%d windows required...", nwins)
- logging.warning(
- "data wins - start:%s, end:%s",
- dwin_ins,
- dwin_ends,
- )
- logging.warning(
- "model wins - start:%s, end:%s",
- mwin_ins,
- mwin_ends,
- )
+ if verb:
+ logging.warning("%d windows required...", nwins)
+ logging.warning(
+ "data wins - start:%s, end:%s",
+ dwin_ins,
+ dwin_ends,
+ )
+ logging.warning(
+ "model wins - start:%s, end:%s",
+ mwin_ins,
+ mwin_ends,
+ )
return nwins, dim, mwins_inends, dwins_inends
-def Sliding1D(
- Op: LinearOperator,
- dim: Union[int, InputDimsLike],
- dimd: Union[int, InputDimsLike],
- nwin: int,
- nover: int,
- tapertype: str = "hanning",
- name: str = "S",
-) -> LinearOperator:
+class Sliding1D(LinearOperator):
r"""1D Sliding transform operator.
Apply a transform operator ``Op`` repeatedly to slices of the model
@@ -103,6 +107,12 @@ def Sliding1D(
``nover``, it is recommended to first run ``sliding1d_design`` to obtain
the corresponding ``dims`` and number of windows.
+ .. note:: Two kind of operators ``Op`` can be provided: the first
+ applies a single transformation to each window separately; the second
+ applies the transformation to all of the windows at the same time. This
+ is directly inferred during initialization when the following condition
+ holds ``Op.shape[1] == dim[0]``.
+
.. warning:: Depending on the choice of `nwin` and `nover` as well as the
size of the data, sliding windows may not cover the entire data.
The start and end indices of each window will be displayed and returned
@@ -122,16 +132,16 @@ def Sliding1D(
Number of samples of overlapping part of window
tapertype : :obj:`str`, optional
Type of taper (``hanning``, ``cosine``, ``cosinesquare`` or ``None``)
+ savetaper : :obj:`bool`, optional
+ .. versionadded:: 2.3.0
+
+ Save all tapers and apply them in one go (``True``) or save unique tapers and apply them one by one (``False``).
+ The first option is more computationally efficient, whilst the second is more memory efficient.
name : :obj:`str`, optional
.. versionadded:: 2.0.0
Name of operator (to be used by :func:`pylops.utils.describe.describe`)
- Returns
- -------
- Sop : :obj:`pylops.LinearOperator`
- Sliding operator
-
Raises
------
ValueError
@@ -139,50 +149,167 @@ def Sliding1D(
shape (``dims``).
"""
- dim: Tuple[int, ...] = _value_or_sized_to_tuple(dim)
- dimd: Tuple[int, ...] = _value_or_sized_to_tuple(dimd)
- # data windows
- dwin_ins, dwin_ends = _slidingsteps(dimd[0], nwin, nover)
- nwins = len(dwin_ins)
+ def __init__(
+ self,
+ Op: LinearOperator,
+ dim: Union[int, InputDimsLike],
+ dimd: Union[int, InputDimsLike],
+ nwin: int,
+ nover: int,
+ tapertype: str = "hanning",
+ savetaper: bool = True,
+ name: str = "S",
+ ) -> None:
- # check windows
- if nwins * Op.shape[1] != dim[0]:
- raise ValueError(
- f"Model shape (dim={dim}) is not consistent with chosen "
- f"number of windows. Run sliding1d_design to identify the "
- f"correct number of windows for the current "
- "model size..."
- )
+ dim: Tuple[int, ...] = _value_or_sized_to_tuple(dim)
+ dimd: Tuple[int, ...] = _value_or_sized_to_tuple(dimd)
+
+ # data windows
+ dwin_ins, dwin_ends = _slidingsteps(dimd[0], nwin, nover)
+ self.dwin_inends = (dwin_ins, dwin_ends)
+ nwins = len(dwin_ins)
+ self.nwin = nwin
+ self.nover = nover
+
+ # check windows
+ if nwins * Op.shape[1] != dim[0] and Op.shape[1] != dim[0]:
+ raise ValueError(
+ f"Model shape (dim={dim}) is not consistent with chosen "
+ f"number of windows. Run sliding1d_design to identify the "
+ f"correct number of windows for the current "
+ "model size..."
+ )
- # create tapers
- if tapertype is not None:
- tap = taper(nwin, nover, tapertype=tapertype).astype(Op.dtype)
- tapin = tap.copy()
- tapin[:nover] = 1
- tapend = tap.copy()
- tapend[-nover:] = 1
- taps = {}
- taps[0] = tapin
- for i in range(1, nwins - 1):
- taps[i] = tap
- taps[nwins - 1] = tapend
-
- # transform to apply
- if tapertype is None:
- OOp = BlockDiag([Op for _ in range(nwins)])
- else:
- OOp = BlockDiag(
- [Diagonal(taps[itap].ravel(), dtype=Op.dtype) * Op for itap in range(nwins)]
+ # create tapers
+ self.tapertype = tapertype
+ self.savetaper = savetaper
+ if self.tapertype is not None:
+ tap = taper(nwin, nover, tapertype=self.tapertype)
+ tapin = tap.copy()
+ tapin[:nover] = 1
+ tapend = tap.copy()
+ tapend[-nover:] = 1
+ if self.savetaper:
+ self.taps = [
+ tapin,
+ ]
+ for _ in range(1, nwins - 1):
+ self.taps.append(tap)
+ self.taps.append(tapend)
+ self.taps = np.vstack(self.taps)
+ else:
+ self.taps = np.vstack([tapin, tap, tapend])
+
+ # check if operator is applied to all windows simultaneously
+ self.simOp = False
+ if Op.shape[1] == dim[0]:
+ self.simOp = True
+ self.Op = Op
+
+ super().__init__(
+ dtype=Op.dtype,
+ dims=(nwins, int(dim[0] // nwins)),
+ dimsd=dimd,
+ clinear=False,
+ name=name,
)
- combining = HStack(
- [
- Restriction(dimd, range(win_in, win_end), dtype=Op.dtype).H
- for win_in, win_end in zip(dwin_ins, dwin_ends)
- ]
- )
- Sop = LinearOperator(combining * OOp)
- Sop.dims, Sop.dimsd = (nwins, int(dim[0] // nwins)), dimd
- Sop.name = name
- return Sop
+ self._register_multiplications(self.savetaper)
+
+ def _apply_taper(self, ywins, iwin0):
+ if iwin0 == 0:
+ ywins[0] = ywins[0] * self.taps[0]
+ elif iwin0 == self.dims[0] - 1:
+ ywins[-1] = ywins[-1] * self.taps[-1]
+ else:
+ ywins[iwin0] = ywins[iwin0] * self.taps[1]
+ return ywins
+
+ @reshaped
+ def _matvec_savetaper(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
+ if self.tapertype is not None:
+ self.taps = to_cupy_conditional(x, self.taps)
+ y = ncp.zeros(self.dimsd, dtype=self.dtype)
+ if self.simOp:
+ x = self.Op @ x
+ if self.tapertype is not None:
+ x = self.taps * x
+ for iwin0 in range(self.dims[0]):
+ if self.simOp:
+ xxwin = x[iwin0]
+ else:
+ xxwin = self.Op.matvec(x[iwin0])
+ if self.tapertype is not None:
+ xxwin = self.taps[iwin0] * xxwin
+ y[self.dwin_inends[0][iwin0] : self.dwin_inends[1][iwin0]] += xxwin
+ return y
+
+ @reshaped
+ def _rmatvec_savetaper(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
+ ncp_sliding_window_view = get_sliding_window_view(x)
+ if self.tapertype is not None:
+ self.taps = to_cupy_conditional(x, self.taps)
+ ywins = ncp_sliding_window_view(x, self.nwin)[:: self.nwin - self.nover]
+ if self.tapertype is not None:
+ ywins = ywins * self.taps
+ if self.simOp:
+ y = self.Op.H @ ywins
+ else:
+ y = ncp.zeros(self.dims, dtype=self.dtype)
+ for iwin0 in range(self.dims[0]):
+ y[iwin0] = self.Op.rmatvec(ywins[iwin0])
+ return y
+
+ @reshaped
+ def _matvec_nosavetaper(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
+ if self.tapertype is not None:
+ self.taps = to_cupy_conditional(x, self.taps)
+ y = ncp.zeros(self.dimsd, dtype=self.dtype)
+ if self.simOp:
+ x = self.Op @ x
+ for iwin0 in range(self.dims[0]):
+ if self.simOp:
+ xxwin = x[iwin0]
+ else:
+ xxwin = self.Op.matvec(x[iwin0])
+ if self.tapertype is not None:
+ if iwin0 == 0:
+ xxwin = self.taps[0] * xxwin
+ elif iwin0 == self.dims[0] - 1:
+ xxwin = self.taps[-1] * xxwin
+ else:
+ xxwin = self.taps[1] * xxwin
+ y[self.dwin_inends[0][iwin0] : self.dwin_inends[1][iwin0]] += xxwin
+ return y
+
+ @reshaped
+ def _rmatvec_nosavetaper(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
+ ncp_sliding_window_view = get_sliding_window_view(x)
+ if self.tapertype is not None:
+ self.taps = to_cupy_conditional(x, self.taps)
+ ywins = ncp_sliding_window_view(x, self.nwin)[:: self.nwin - self.nover].copy()
+ if self.simOp:
+ if self.tapertype is not None:
+ for iwin0 in range(self.dims[0]):
+ ywins = self._apply_taper(ywins, iwin0)
+ y = self.Op.H @ ywins
+ else:
+ y = ncp.zeros(self.dims, dtype=self.dtype)
+ for iwin0 in range(self.dims[0]):
+ if self.tapertype is not None:
+ ywins = self._apply_taper(ywins, iwin0)
+ y[iwin0] = self.Op.rmatvec(ywins[iwin0])
+ return y
+
+ def _register_multiplications(self, savetaper: bool) -> None:
+ if savetaper:
+ self._matvec = self._matvec_savetaper
+ self._rmatvec = self._rmatvec_savetaper
+ else:
+ self._matvec = self._matvec_nosavetaper
+ self._rmatvec = self._rmatvec_nosavetaper
diff --git a/pylops/signalprocessing/sliding2d.py b/pylops/signalprocessing/sliding2d.py
index c97802c6..f6cbce9c 100644
--- a/pylops/signalprocessing/sliding2d.py
+++ b/pylops/signalprocessing/sliding2d.py
@@ -9,7 +9,13 @@
import numpy as np
from pylops import LinearOperator
-from pylops.basicoperators import BlockDiag, Diagonal, HStack, Restriction
+from pylops.utils._internal import _value_or_sized_to_tuple
+from pylops.utils.backend import (
+ get_array_module,
+ get_sliding_window_view,
+ to_cupy_conditional,
+)
+from pylops.utils.decorators import reshaped
from pylops.utils.tapers import taper2d
from pylops.utils.typing import InputDimsLike, NDArray
@@ -54,6 +60,7 @@ def sliding2d_design(
nwin: int,
nover: int,
nop: Tuple[int, int],
+ verb: bool = True,
) -> Tuple[int, Tuple[int, int], Tuple[NDArray, NDArray], Tuple[NDArray, NDArray]]:
"""Design Sliding2D operator
@@ -72,6 +79,9 @@ def sliding2d_design(
Number of samples of overlapping part of window.
nop : :obj:`tuple`
Size of model in the transformed domain.
+ verb : :obj:`bool`, optional
+ Verbosity flag. If ``verb==True``, print the data
+ and model windows start-end indices
Returns
-------
@@ -96,29 +106,22 @@ def sliding2d_design(
mwins_inends = (mwin_ins, mwin_ends)
# print information about patching
- logging.warning("%d windows required...", nwins)
- logging.warning(
- "data wins - start:%s, end:%s",
- dwin_ins,
- dwin_ends,
- )
- logging.warning(
- "model wins - start:%s, end:%s",
- mwin_ins,
- mwin_ends,
- )
+ if verb:
+ logging.warning("%d windows required...", nwins)
+ logging.warning(
+ "data wins - start:%s, end:%s",
+ dwin_ins,
+ dwin_ends,
+ )
+ logging.warning(
+ "model wins - start:%s, end:%s",
+ mwin_ins,
+ mwin_ends,
+ )
return nwins, dims, mwins_inends, dwins_inends
-def Sliding2D(
- Op: LinearOperator,
- dims: InputDimsLike,
- dimsd: InputDimsLike,
- nwin: int,
- nover: int,
- tapertype: str = "hanning",
- name: str = "S",
-) -> LinearOperator:
+class Sliding2D(LinearOperator):
"""2D Sliding transform operator.
Apply a transform operator ``Op`` repeatedly to slices of the model
@@ -139,6 +142,12 @@ def Sliding2D(
``nover``, it is recommended to first run ``sliding2d_design`` to obtain
the corresponding ``dims`` and number of windows.
+ .. note:: Two kind of operators ``Op`` can be provided: the first
+ applies a single transformation to each window separately; the second
+ applies the transformation to all of the windows at the same time. This
+ is directly inferred during initialization when the following condition
+ holds ``Op.shape[1] == np.prod(dims)``.
+
.. warning:: Depending on the choice of `nwin` and `nover` as well as the
size of the data, sliding windows may not cover the entire data.
The start and end indices of each window will be displayed and returned
@@ -159,6 +168,11 @@ def Sliding2D(
Number of samples of overlapping part of window
tapertype : :obj:`str`, optional
Type of taper (``hanning``, ``cosine``, ``cosinesquare`` or ``None``)
+ savetaper : :obj:`bool`, optional
+ .. versionadded:: 2.3.0
+
+ Save all tapers and apply them in one go (``True``) or save unique tapers and apply them one by one (``False``).
+ The first option is more computationally efficient, whilst the second is more memory efficient.
name : :obj:`str`, optional
.. versionadded:: 2.0.0
@@ -176,47 +190,181 @@ def Sliding2D(
shape (``dims``).
"""
- # data windows
- dwin_ins, dwin_ends = _slidingsteps(dimsd[0], nwin, nover)
- nwins = len(dwin_ins)
- # check patching
- if nwins * Op.shape[1] // dims[1] != dims[0]:
- raise ValueError(
- f"Model shape (dims={dims}) is not consistent with chosen "
- f"number of windows. Run sliding2d_design to identify the "
- f"correct number of windows for the current "
- "model size..."
+ def __init__(
+ self,
+ Op: LinearOperator,
+ dims: InputDimsLike,
+ dimsd: InputDimsLike,
+ nwin: int,
+ nover: int,
+ tapertype: str = "hanning",
+ savetaper: bool = True,
+ name: str = "S",
+ ) -> None:
+
+ dims: Tuple[int, ...] = _value_or_sized_to_tuple(dims)
+ dimsd: Tuple[int, ...] = _value_or_sized_to_tuple(dimsd)
+
+ # data windows
+ dwin_ins, dwin_ends = _slidingsteps(dimsd[0], nwin, nover)
+ self.dwin_inends = (dwin_ins, dwin_ends)
+ nwins = len(dwin_ins)
+ self.nwin = nwin
+ self.nover = nover
+
+ # check patching
+ if nwins * Op.shape[1] // dims[1] != dims[0] and Op.shape[1] != np.prod(dims):
+ raise ValueError(
+ f"Model shape (dims={dims}) is not consistent with chosen "
+ f"number of windows. Run sliding2d_design to identify the "
+ f"correct number of windows for the current "
+ "model size..."
+ )
+
+ # create tapers
+ self.tapertype = tapertype
+ self.savetaper = savetaper
+ if self.tapertype is not None:
+ tap = taper2d(dimsd[1], nwin, nover, tapertype=self.tapertype)
+ tapin = tap.copy()
+ tapin[:nover] = 1
+ tapend = tap.copy()
+ tapend[-nover:] = 1
+ if self.savetaper:
+ self.taps = [
+ tapin[np.newaxis, :],
+ ]
+ for _ in range(1, nwins - 1):
+ self.taps.append(tap[np.newaxis, :])
+ self.taps.append(tapend[np.newaxis, :])
+ self.taps = np.concatenate(self.taps, axis=0)
+ else:
+ self.taps = np.vstack(
+ [tapin[np.newaxis, :], tap[np.newaxis, :], tapend[np.newaxis, :]]
+ )
+
+ # check if operator is applied to all windows simultaneously
+ self.simOp = False
+ if Op.shape[1] == np.prod(dims):
+ self.simOp = True
+ self.Op = Op
+
+ super().__init__(
+ dtype=Op.dtype,
+ dims=(nwins, int(dims[0] // nwins), dims[1]),
+ dimsd=dimsd,
+ clinear=False,
+ name=name,
)
- # create tapers
- if tapertype is not None:
- tap = taper2d(dimsd[1], nwin, nover, tapertype=tapertype).astype(Op.dtype)
- tapin = tap.copy()
- tapin[:nover] = 1
- tapend = tap.copy()
- tapend[-nover:] = 1
- taps = {}
- taps[0] = tapin
- for i in range(1, nwins - 1):
- taps[i] = tap
- taps[nwins - 1] = tapend
-
- # transform to apply
- if tapertype is None:
- OOp = BlockDiag([Op for _ in range(nwins)])
- else:
- OOp = BlockDiag(
- [Diagonal(taps[itap].ravel(), dtype=Op.dtype) * Op for itap in range(nwins)]
+ self._register_multiplications(self.savetaper)
+
+ def _apply_taper(self, ywins, iwin0):
+ if iwin0 == 0:
+ ywins[0] = ywins[0] * self.taps[0]
+ elif iwin0 == self.dims[0] - 1:
+ ywins[-1] = ywins[-1] * self.taps[-1]
+ else:
+ ywins[iwin0] = ywins[iwin0] * self.taps[1]
+ return ywins
+
+ @reshaped
+ def _matvec_savetaper(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
+ if self.tapertype is not None:
+ self.taps = to_cupy_conditional(x, self.taps)
+ y = ncp.zeros(self.dimsd, dtype=self.dtype)
+ if self.simOp:
+ x = self.Op @ x
+ for iwin0 in range(self.dims[0]):
+ if self.simOp:
+ xx = x[iwin0].reshape(self.nwin, self.dimsd[-1])
+ else:
+ xx = self.Op.matvec(x[iwin0].ravel()).reshape(self.nwin, self.dimsd[-1])
+ if self.tapertype is not None:
+ xxwin = self.taps[iwin0] * xx
+ else:
+ xxwin = xx
+ y[self.dwin_inends[0][iwin0] : self.dwin_inends[1][iwin0]] += xxwin
+ return y
+
+ @reshaped
+ def _rmatvec_savetaper(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
+ ncp_sliding_window_view = get_sliding_window_view(x)
+ if self.tapertype is not None:
+ self.taps = to_cupy_conditional(x, self.taps)
+ ywins = ncp_sliding_window_view(x, self.nwin, axis=0)[
+ :: self.nwin - self.nover
+ ].transpose(0, 2, 1)
+ if self.tapertype is not None:
+ ywins = ywins * self.taps
+ if self.simOp:
+ y = self.Op.H @ ywins
+ else:
+ y = ncp.zeros(self.dims, dtype=self.dtype)
+ for iwin0 in range(self.dims[0]):
+ y[iwin0] = self.Op.rmatvec(ywins[iwin0].ravel()).reshape(
+ self.dims[1], self.dims[2]
+ )
+ return y
+
+ @reshaped
+ def _matvec_nosavetaper(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
+ if self.tapertype is not None:
+ self.taps = to_cupy_conditional(x, self.taps)
+ y = ncp.zeros(self.dimsd, dtype=self.dtype)
+ if self.simOp:
+ x = self.Op @ x
+ for iwin0 in range(self.dims[0]):
+ if self.simOp:
+ xxwin = x[iwin0].reshape(self.nwin, self.dimsd[-1])
+ else:
+ xxwin = self.Op.matvec(x[iwin0].ravel()).reshape(
+ self.nwin, self.dimsd[-1]
+ )
+ if self.tapertype is not None:
+ if iwin0 == 0:
+ xxwin = self.taps[0] * xxwin
+ elif iwin0 == self.dims[0] - 1:
+ xxwin = self.taps[-1] * xxwin
+ else:
+ xxwin = self.taps[1] * xxwin
+ y[self.dwin_inends[0][iwin0] : self.dwin_inends[1][iwin0]] += xxwin
+ return y
+
+ @reshaped
+ def _rmatvec_nosavetaper(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
+ ncp_sliding_window_view = get_sliding_window_view(x)
+ if self.tapertype is not None:
+ self.taps = to_cupy_conditional(x, self.taps)
+ ywins = (
+ ncp_sliding_window_view(x, self.nwin, axis=0)[:: self.nwin - self.nover]
+ .transpose(0, 2, 1)
+ .copy()
)
-
- combining = HStack(
- [
- Restriction(dimsd, range(win_in, win_end), axis=0, dtype=Op.dtype).H
- for win_in, win_end in zip(dwin_ins, dwin_ends)
- ]
- )
- Sop = LinearOperator(combining * OOp)
- Sop.dims, Sop.dimsd = (nwins, int(dims[0] // nwins), dims[1]), dimsd
- Sop.name = name
- return Sop
+ if self.simOp:
+ if self.tapertype is not None:
+ for iwin0 in range(self.dims[0]):
+ ywins = self._apply_taper(ywins, iwin0)
+ y = self.Op.H @ ywins
+ else:
+ y = ncp.zeros(self.dims, dtype=self.dtype)
+ for iwin0 in range(self.dims[0]):
+ if self.tapertype is not None:
+ ywins = self._apply_taper(ywins, iwin0)
+ y[iwin0] = self.Op.rmatvec(ywins[iwin0].ravel()).reshape(
+ self.dims[1], self.dims[2]
+ )
+ return y
+
+ def _register_multiplications(self, savetaper: bool) -> None:
+ if savetaper:
+ self._matvec = self._matvec_savetaper
+ self._rmatvec = self._rmatvec_savetaper
+ else:
+ self._matvec = self._matvec_nosavetaper
+ self._rmatvec = self._rmatvec_nosavetaper
diff --git a/pylops/signalprocessing/sliding3d.py b/pylops/signalprocessing/sliding3d.py
index 7d21018c..bf6b773d 100644
--- a/pylops/signalprocessing/sliding3d.py
+++ b/pylops/signalprocessing/sliding3d.py
@@ -6,9 +6,17 @@
import logging
from typing import Tuple
+import numpy as np
+
from pylops import LinearOperator
-from pylops.basicoperators import BlockDiag, Diagonal, HStack, Restriction
from pylops.signalprocessing.sliding2d import _slidingsteps
+from pylops.utils._internal import _value_or_sized_to_tuple
+from pylops.utils.backend import (
+ get_array_module,
+ get_sliding_window_view,
+ to_cupy_conditional,
+)
+from pylops.utils.decorators import reshaped
from pylops.utils.tapers import taper3d
from pylops.utils.typing import InputDimsLike, NDArray
@@ -20,6 +28,7 @@ def sliding3d_design(
nwin: Tuple[int, int],
nover: Tuple[int, int],
nop: Tuple[int, int, int],
+ verb: bool = True,
) -> Tuple[
Tuple[int, int],
Tuple[int, int, int],
@@ -43,6 +52,9 @@ def sliding3d_design(
Number of samples of overlapping part of window.
nop : :obj:`tuple`
Size of model in the transformed domain.
+ verb : :obj:`bool`, optional
+ Verbosity flag. If ``verb==True``, print the data
+ and model windows start-end indices
Returns
-------
@@ -71,35 +83,26 @@ def sliding3d_design(
mwins_inends = ((mwin0_ins, mwin0_ends), (mwin1_ins, mwin1_ends))
# print information about patching
- logging.warning("%d-%d windows required...", nwins0, nwins1)
- logging.warning(
- "data wins - start:%s, end:%s / start:%s, end:%s",
- dwin0_ins,
- dwin0_ends,
- dwin1_ins,
- dwin1_ends,
- )
- logging.warning(
- "model wins - start:%s, end:%s / start:%s, end:%s",
- mwin0_ins,
- mwin0_ends,
- mwin1_ins,
- mwin1_ends,
- )
+ if verb:
+ logging.warning("%d-%d windows required...", nwins0, nwins1)
+ logging.warning(
+ "data wins - start:%s, end:%s / start:%s, end:%s",
+ dwin0_ins,
+ dwin0_ends,
+ dwin1_ins,
+ dwin1_ends,
+ )
+ logging.warning(
+ "model wins - start:%s, end:%s / start:%s, end:%s",
+ mwin0_ins,
+ mwin0_ends,
+ mwin1_ins,
+ mwin1_ends,
+ )
return nwins, dims, mwins_inends, dwins_inends
-def Sliding3D(
- Op: LinearOperator,
- dims: InputDimsLike,
- dimsd: InputDimsLike,
- nwin: Tuple[int, int],
- nover: Tuple[int, int],
- nop: Tuple[int, int, int],
- tapertype: str = "hanning",
- nproc: int = 1,
- name: str = "P",
-) -> LinearOperator:
+class Sliding3D(LinearOperator):
"""3D Sliding transform operator.w
Apply a transform operator ``Op`` repeatedly to patches of the model
@@ -121,6 +124,12 @@ def Sliding3D(
``nover``, it is recommended to first run ``sliding3d_design`` to obtain
the corresponding ``dims`` and number of windows.
+ .. note:: Two kind of operators ``Op`` can be provided: the first
+ applies a single transformation to each window separately; the second
+ applies the transformation to all of the windows at the same time. This
+ is directly inferred during initialization when the following condition
+ holds ``Op.shape[1] == np.prod(dims)``.
+
.. warning:: Depending on the choice of `nwin` and `nover` as well as the
size of the data, sliding windows may not cover the entire data.
The start and end indices of each window will be displayed and returned
@@ -145,9 +154,14 @@ def Sliding3D(
to spatial axes in the data
tapertype : :obj:`str`, optional
Type of taper (``hanning``, ``cosine``, ``cosinesquare`` or ``None``)
+ savetaper : :obj:`bool`, optional
+ .. versionadded:: 2.3.0
+
+ Save all tapers and apply them in one go (``True``) or save unique tapers and apply them one by one (``False``).
+ The first option is more computationally efficient, whilst the second is more memory efficient.
nproc : :obj:`int`, optional
- Number of processes used to evaluate the N operators in parallel
- using ``multiprocessing``. If ``nproc=1``, work in serial mode.
+ *Deprecated*, will be removed in v3.0.0. Simply kept for
+ back-compatibility with previous implementation
name : :obj:`str`, optional
.. versionadded:: 2.0.0
@@ -165,66 +179,287 @@ def Sliding3D(
shape (``dims``).
"""
- # data windows
- dwin0_ins, dwin0_ends = _slidingsteps(dimsd[0], nwin[0], nover[0])
- dwin1_ins, dwin1_ends = _slidingsteps(dimsd[1], nwin[1], nover[1])
- nwins0 = len(dwin0_ins)
- nwins1 = len(dwin1_ins)
- nwins = nwins0 * nwins1
-
- # check windows
- if nwins * Op.shape[1] // dims[2] != dims[0] * dims[1]:
- raise ValueError(
- f"Model shape (dims={dims}) is not consistent with chosen "
- f"number of windows. Run sliding3d_design to identify the "
- f"correct number of windows for the current "
- "model size..."
+
+ def __init__(
+ self,
+ Op: LinearOperator,
+ dims: InputDimsLike,
+ dimsd: InputDimsLike,
+ nwin: Tuple[int, int],
+ nover: Tuple[int, int],
+ nop: Tuple[int, int, int],
+ tapertype: str = "hanning",
+ savetaper: bool = True,
+ nproc: int = 1,
+ name: str = "P",
+ ) -> None:
+
+ dims: Tuple[int, ...] = _value_or_sized_to_tuple(dims)
+ dimsd: Tuple[int, ...] = _value_or_sized_to_tuple(dimsd)
+
+ # data windows
+ dwin0_ins, dwin0_ends = _slidingsteps(dimsd[0], nwin[0], nover[0])
+ dwin1_ins, dwin1_ends = _slidingsteps(dimsd[1], nwin[1], nover[1])
+ self.dwins_inends = ((dwin0_ins, dwin0_ends), (dwin1_ins, dwin1_ends))
+ nwins0 = len(dwin0_ins)
+ nwins1 = len(dwin1_ins)
+ nwins = nwins0 * nwins1
+ self.nwin = nwin
+ self.nover = nover
+
+ # model windows
+ mwin0_ins, mwin0_ends = _slidingsteps(dims[0], nop[0], 0)
+ mwin1_ins, mwin1_ends = _slidingsteps(dims[1], nop[1], 0)
+ self.mwins_inends = ((mwin0_ins, mwin0_ends), (mwin1_ins, mwin1_ends))
+
+ # check windows
+ if nwins * Op.shape[1] // dims[2] != dims[0] * dims[1] and Op.shape[
+ 1
+ ] != np.prod(dims):
+ raise ValueError(
+ f"Model shape (dims={dims}) is not consistent with chosen "
+ f"number of windows. Run sliding3d_design to identify the "
+ f"correct number of windows for the current "
+ "model size..."
+ )
+
+ # create tapers
+ self.tapertype = tapertype
+ self.savetaper = savetaper
+ if self.tapertype is not None:
+ tap = taper3d(dimsd[2], nwin, nover, tapertype=tapertype).astype(Op.dtype)
+ # topmost tapers
+ taptop = tap.copy()
+ taptop[: nover[0]] = tap[nwin[0] // 2]
+ # bottommost tapers
+ tapbottom = tap.copy()
+ tapbottom[-nover[0] :] = tap[nwin[0] // 2]
+ # leftmost tapers
+ tapleft = tap.copy()
+ tapleft[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis]
+ # rightmost tapers
+ tapright = tap.copy()
+ tapright[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis]
+ # lefttopcorner taper
+ taplefttop = tap.copy()
+ taplefttop[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis]
+ taplefttop[: nover[0]] = taplefttop[nwin[0] // 2]
+ # righttopcorner taper
+ taprighttop = tap.copy()
+ taprighttop[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis]
+ taprighttop[: nover[0]] = taprighttop[nwin[0] // 2]
+ # leftbottomcorner taper
+ tapleftbottom = tap.copy()
+ tapleftbottom[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis]
+ tapleftbottom[-nover[0] :] = tapleftbottom[nwin[0] // 2]
+ # rightbottomcorner taper
+ taprightbottom = tap.copy()
+ taprightbottom[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis]
+ taprightbottom[-nover[0] :] = taprightbottom[nwin[0] // 2]
+
+ if self.savetaper:
+ taps = [
+ tap,
+ ] * nwins
+
+ for itap in range(0, nwins1):
+ taps[itap] = taptop
+ for itap in range(nwins - nwins1, nwins):
+ taps[itap] = tapbottom
+ for itap in range(0, nwins, nwins1):
+ taps[itap] = tapleft
+ for itap in range(nwins1 - 1, nwins, nwins1):
+ taps[itap] = tapright
+ taps[0] = taplefttop
+ taps[nwins1 - 1] = taprighttop
+ taps[nwins - nwins1] = tapleftbottom
+ taps[nwins - 1] = taprightbottom
+ self.taps = np.vstack(taps).reshape(
+ nwins0, nwins1, nwin[0], nwin[1], dimsd[2]
+ )
+ else:
+ taps = [
+ taplefttop,
+ taptop,
+ taprighttop,
+ tapleft,
+ tap,
+ tapright,
+ tapleftbottom,
+ tapbottom,
+ taprightbottom,
+ ]
+ self.taps = np.vstack(taps).reshape(3, 3, nwin[0], nwin[1], dimsd[2])
+ # check if operator is applied to all windows simultaneously
+ self.simOp = False
+ if Op.shape[1] == np.prod(dims):
+ self.simOp = True
+ self.Op = Op
+
+ super().__init__(
+ dtype=Op.dtype,
+ dims=(
+ nwins0,
+ nwins1,
+ int(dims[0] // nwins0),
+ int(dims[1] // nwins1),
+ dims[2],
+ ),
+ dimsd=dimsd,
+ clinear=False,
+ name=name,
)
- # create tapers
- if tapertype is not None:
- tap = taper3d(dimsd[2], nwin, nover, tapertype=tapertype).astype(Op.dtype)
-
- # transform to apply
- if tapertype is None:
- OOp = BlockDiag([Op for _ in range(nwins)], nproc=nproc)
- else:
- OOp = BlockDiag(
- [Diagonal(tap.ravel(), dtype=Op.dtype) * Op for _ in range(nwins)],
- nproc=nproc,
+ self._register_multiplications(self.savetaper)
+
+ def _apply_taper(self, ywins, iwin0, iwin1):
+ if iwin0 == 0 and iwin1 == 0:
+ ywins[0, 0] = self.taps[0, 0] * ywins[0, 0]
+ elif iwin0 == 0 and iwin1 == self.dims[1] - 1:
+ ywins[0, -1] = self.taps[0, -1] * ywins[0, -1]
+ elif iwin0 == 0:
+ ywins[0, iwin1] = self.taps[0, 1] * ywins[0, iwin1]
+ elif iwin0 == self.dims[0] - 1 and iwin1 == 0:
+ ywins[-1, 0] = self.taps[-1, 0] * ywins[-1, 0]
+ elif iwin0 == self.dims[0] - 1 and iwin1 == self.dims[1] - 1:
+ ywins[-1, -1] = self.taps[-1, -1] * ywins[-1, -1]
+ elif iwin0 == self.dims[0] - 1:
+ ywins[-1, iwin1] = self.taps[-1, 1] * ywins[-1, iwin1]
+ elif iwin1 == 0:
+ ywins[iwin0, 0] = self.taps[1, 0] * ywins[iwin0, 0]
+ elif iwin1 == self.dims[1] - 1:
+ ywins[iwin0, -1] = self.taps[1, -1] * ywins[iwin0, -1]
+ else:
+ ywins[iwin0, iwin1] = self.taps[1, 1] * ywins[iwin0, iwin1]
+ return ywins
+
+ @reshaped
+ def _matvec_savetaper(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
+ if self.tapertype is not None:
+ self.taps = to_cupy_conditional(x, self.taps)
+ y = ncp.zeros(self.dimsd, dtype=self.dtype)
+ if self.simOp:
+ x = self.Op @ x
+ for iwin0 in range(self.dims[0]):
+ for iwin1 in range(self.dims[1]):
+ if self.simOp:
+ xx = x[iwin0, iwin1].reshape(
+ self.nwin[0], self.nwin[1], self.dimsd[-1]
+ )
+ else:
+ xx = self.Op.matvec(x[iwin0, iwin1].ravel()).reshape(
+ self.nwin[0], self.nwin[1], self.dimsd[-1]
+ )
+ if self.tapertype is not None:
+ xxwin = self.taps[iwin0, iwin1] * xx
+ else:
+ xxwin = xx
+ y[
+ self.dwins_inends[0][0][iwin0] : self.dwins_inends[0][1][iwin0],
+ self.dwins_inends[1][0][iwin1] : self.dwins_inends[1][1][iwin1],
+ ] += xxwin
+ return y
+
+ @reshaped
+ def _rmatvec_savetaper(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
+ ncp_sliding_window_view = get_sliding_window_view(x)
+ if self.tapertype is not None:
+ self.taps = to_cupy_conditional(x, self.taps)
+ ywins = ncp_sliding_window_view(x, self.nwin, axis=(0, 1))[
+ :: self.nwin[0] - self.nover[0], :: self.nwin[1] - self.nover[1]
+ ].transpose(0, 1, 3, 4, 2)
+ if self.tapertype is not None:
+ ywins = ywins * self.taps
+ if self.simOp:
+ y = self.Op.H @ ywins
+ else:
+ y = ncp.zeros(self.dims, dtype=self.dtype)
+ for iwin0 in range(self.dims[0]):
+ for iwin1 in range(self.dims[1]):
+ y[iwin0, iwin1] = self.Op.rmatvec(
+ ywins[iwin0, iwin1].ravel()
+ ).reshape(self.dims[2], self.dims[3], self.dims[4])
+ return y
+
+ @reshaped
+ def _matvec_nosavetaper(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
+ if self.tapertype is not None:
+ self.taps = to_cupy_conditional(x, self.taps)
+ y = ncp.zeros(self.dimsd, dtype=self.dtype)
+ if self.simOp:
+ x = self.Op @ x
+ for iwin0 in range(self.dims[0]):
+ for iwin1 in range(self.dims[1]):
+ if self.simOp:
+ xxwin = x[iwin0, iwin1].reshape(
+ self.nwin[0], self.nwin[1], self.dimsd[-1]
+ )
+ else:
+ xxwin = self.Op.matvec(x[iwin0, iwin1].ravel()).reshape(
+ self.nwin[0], self.nwin[1], self.dimsd[-1]
+ )
+ if self.tapertype is not None:
+ if iwin0 == 0 and iwin1 == 0:
+ xxwin = self.taps[0, 0] * xxwin
+ elif iwin0 == 0 and iwin1 == self.dims[1] - 1:
+ xxwin = self.taps[0, -1] * xxwin
+ elif iwin0 == 0:
+ xxwin = self.taps[0, 1] * xxwin
+ elif iwin0 == self.dims[0] - 1 and iwin1 == 0:
+ xxwin = self.taps[-1, 0] * xxwin
+ elif iwin0 == self.dims[0] - 1 and iwin1 == self.dims[1] - 1:
+ xxwin = self.taps[-1, -1] * xxwin
+ elif iwin0 == self.dims[0] - 1:
+ xxwin = self.taps[-1, 1] * xxwin
+ elif iwin1 == 0:
+ xxwin = self.taps[1, 0] * xxwin
+ elif iwin1 == self.dims[1] - 1:
+ xxwin = self.taps[1, -1] * xxwin
+ else:
+ xxwin = self.taps[1, 1] * xxwin
+ y[
+ self.dwins_inends[0][0][iwin0] : self.dwins_inends[0][1][iwin0],
+ self.dwins_inends[1][0][iwin1] : self.dwins_inends[1][1][iwin1],
+ ] += xxwin
+ return y
+
+ @reshaped
+ def _rmatvec_nosavetaper(self, x: NDArray) -> NDArray:
+ ncp = get_array_module(x)
+ ncp_sliding_window_view = get_sliding_window_view(x)
+ if self.tapertype is not None:
+ self.taps = to_cupy_conditional(x, self.taps)
+ ywins = (
+ ncp_sliding_window_view(x, self.nwin, axis=(0, 1))[
+ :: self.nwin[0] - self.nover[0], :: self.nwin[1] - self.nover[1]
+ ]
+ .transpose(0, 1, 3, 4, 2)
+ .copy()
)
+ if self.simOp:
+ if self.tapertype is not None:
+ for iwin0 in range(self.dims[0]):
+ for iwin1 in range(self.dims[1]):
+ ywins = self._apply_taper(ywins, iwin0, iwin1)
+ y = self.Op.H @ ywins
+ else:
+ y = ncp.zeros(self.dims, dtype=self.dtype)
+ for iwin0 in range(self.dims[0]):
+ for iwin1 in range(self.dims[1]):
+ if self.tapertype is not None:
+ ywins = self._apply_taper(ywins, iwin0, iwin1)
+ y[iwin0, iwin1] = self.Op.rmatvec(
+ ywins[iwin0, iwin1].ravel()
+ ).reshape(self.dims[2], self.dims[3], self.dims[4])
+ return y
- hstack = HStack(
- [
- Restriction(
- (nwin[0], dimsd[1], dimsd[2]),
- range(win_in, win_end),
- axis=1,
- dtype=Op.dtype,
- ).H
- for win_in, win_end in zip(dwin1_ins, dwin1_ends)
- ]
- )
-
- combining1 = BlockDiag([hstack] * nwins0)
- combining0 = HStack(
- [
- Restriction(
- dimsd,
- range(win_in, win_end),
- axis=0,
- dtype=Op.dtype,
- ).H
- for win_in, win_end in zip(dwin0_ins, dwin0_ends)
- ]
- )
- Sop = LinearOperator(combining0 * combining1 * OOp)
- Sop.dims, Sop.dimsd = (
- nwins0,
- nwins1,
- int(dims[0] // nwins0),
- int(dims[1] // nwins1),
- dims[2],
- ), dimsd
- Sop.name = name
- return Sop
+ def _register_multiplications(self, savetaper: bool) -> None:
+ if savetaper:
+ self._matvec = self._matvec_savetaper
+ self._rmatvec = self._rmatvec_savetaper
+ else:
+ self._matvec = self._matvec_nosavetaper
+ self._rmatvec = self._rmatvec_nosavetaper
diff --git a/pylops/torchoperator.py b/pylops/torchoperator.py
index 5e41f67f..1c4dc2da 100644
--- a/pylops/torchoperator.py
+++ b/pylops/torchoperator.py
@@ -14,7 +14,7 @@
else:
torch_message = (
"Torch package not installed. In order to be able to use"
- 'the twoway module run "pip install torch" or'
+ 'the torchoperator module run "pip install torch" or'
'"conda install -c pytorch torch".'
)
from pylops.utils.typing import TensorTypeLike
diff --git a/pylops/utils/backend.py b/pylops/utils/backend.py
index 1cef8bdb..50da7faa 100644
--- a/pylops/utils/backend.py
+++ b/pylops/utils/backend.py
@@ -7,15 +7,22 @@
"get_oaconvolve",
"get_correlate",
"get_add_at",
+ "get_sliding_window_view",
"get_block_diag",
"get_toeplitz",
"get_csc_matrix",
"get_sparse_eye",
"get_lstsq",
+ "get_sp_fft",
"get_complex_dtype",
"get_real_dtype",
"to_numpy",
"to_cupy_conditional",
+ "inplace_set",
+ "inplace_add",
+ "inplace_multiply",
+ "inplace_divide",
+ "randn",
]
from types import ModuleType
@@ -44,6 +51,14 @@
from cupyx.scipy.sparse import csc_matrix as cp_csc_matrix
from cupyx.scipy.sparse import eye as cp_eye
+if deps.jax_enabled:
+ import jax
+ import jax.numpy as jnp
+ from jax.scipy.linalg import block_diag as jnp_block_diag
+ from jax.scipy.linalg import toeplitz as jnp_toeplitz
+ from jax.scipy.signal import convolve as j_convolve
+ from jax.scipy.signal import fftconvolve as j_fftconvolve
+
def get_module(backend: str = "numpy") -> ModuleType:
"""Returns correct numerical module based on backend string
@@ -51,21 +66,23 @@ def get_module(backend: str = "numpy") -> ModuleType:
Parameters
----------
backend : :obj:`str`, optional
- Backend used for dot test computations (``numpy`` or ``cupy``). This
+ Backend used for dot test computations (``numpy`` or ``cupy`` or ``jax``). This
parameter will be used to choose how to create the random vectors.
Returns
-------
mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ Module to be used to process array (:mod:`numpy` or :mod:`cupy` or :mod:`jax`)
"""
if backend == "numpy":
ncp = np
elif backend == "cupy":
ncp = cp
+ elif backend == "jax":
+ ncp = jnp
else:
- raise ValueError("backend must be numpy or cupy")
+ raise ValueError("backend must be numpy, cupy, or jax")
return ncp
@@ -75,12 +92,12 @@ def get_module_name(mod: ModuleType) -> str:
Parameters
----------
mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ Module to be used to process array (:mod:`numpy` or :mod:`cupy` or :mod:`jax`)
Returns
-------
backend : :obj:`str`, optional
- Backend used for dot test computations (``numpy`` or ``cupy``). This
+ Backend used for dot test computations (``numpy`` or ``cupy`` or ``jax``). This
parameter will be used to choose how to create the random vectors.
"""
@@ -88,8 +105,10 @@ def get_module_name(mod: ModuleType) -> str:
backend = "numpy"
elif mod == cp:
backend = "cupy"
+ elif mod == jnp:
+ backend = "jax"
else:
- raise ValueError("module must be numpy or cupy")
+ raise ValueError("module must be numpy, cupy, or jax")
return backend
@@ -98,17 +117,23 @@ def get_array_module(x: npt.ArrayLike) -> ModuleType:
Parameters
----------
- x : :obj:`numpy.ndarray`
+ x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array`
Array
Returns
-------
mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ Module to be used to process array
+ (:mod:`numpy`, :mod:`cupy`, or , :mod:`jax`)
"""
- if deps.cupy_enabled:
- return cp.get_array_module(x)
+ if deps.cupy_enabled or deps.jax_enabled:
+ if deps.jax_enabled and isinstance(x, jnp.ndarray):
+ return jnp
+ elif deps.cupy_enabled:
+ return cp.get_array_module(x)
+ else:
+ return np
else:
return np
@@ -118,22 +143,24 @@ def get_convolve(x: npt.ArrayLike) -> Callable:
Parameters
----------
- x : :obj:`numpy.ndarray`
+ x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array`
Array
Returns
-------
- mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ f : :obj:`func`
+ Function to be used to process array
"""
- if not deps.cupy_enabled:
- return convolve
-
- if cp.get_array_module(x) == np:
- return convolve
+ if deps.cupy_enabled or deps.jax_enabled:
+ if deps.jax_enabled and isinstance(x, jnp.ndarray):
+ return j_convolve
+ elif deps.cupy_enabled and cp.get_array_module(x) == cp:
+ return cp_convolve
+ else:
+ return convolve
else:
- return cp_convolve
+ return convolve
def get_fftconvolve(x: npt.ArrayLike) -> Callable:
@@ -141,22 +168,24 @@ def get_fftconvolve(x: npt.ArrayLike) -> Callable:
Parameters
----------
- x : :obj:`numpy.ndarray`
+ x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array`
Array
Returns
-------
- mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ f : :obj:`func`
+ Function to be used to process array
"""
- if not deps.cupy_enabled:
- return fftconvolve
-
- if cp.get_array_module(x) == np:
- return fftconvolve
+ if deps.cupy_enabled or deps.jax_enabled:
+ if deps.jax_enabled and isinstance(x, jnp.ndarray):
+ return j_fftconvolve
+ elif deps.cupy_enabled and cp.get_array_module(x) == cp:
+ return cp_fftconvolve
+ else:
+ return fftconvolve
else:
- return cp_fftconvolve
+ return fftconvolve
def get_oaconvolve(x: npt.ArrayLike) -> Callable:
@@ -164,22 +193,28 @@ def get_oaconvolve(x: npt.ArrayLike) -> Callable:
Parameters
----------
- x : :obj:`numpy.ndarray`
+ x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array`
Array
Returns
-------
- mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ f : :obj:`func`
+ Function to be used to process array
"""
- if not deps.cupy_enabled:
- return oaconvolve
-
- if cp.get_array_module(x) == np:
- return oaconvolve
+ if deps.cupy_enabled or deps.jax_enabled:
+ if deps.jax_enabled and isinstance(x, jnp.ndarray):
+ raise NotImplementedError(
+ "oaconvolve not implemented in "
+ "jax. Consider using a different"
+ "option..."
+ )
+ elif deps.cupy_enabled and cp.get_array_module(x) == cp:
+ return cp_oaconvolve
+ else:
+ return oaconvolve
else:
- return cp_oaconvolve
+ return oaconvolve
def get_correlate(x: npt.ArrayLike) -> Callable:
@@ -187,22 +222,24 @@ def get_correlate(x: npt.ArrayLike) -> Callable:
Parameters
----------
- x : :obj:`numpy.ndarray`
+ x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array`
Array
Returns
-------
- mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ f : :obj:`func`
+ Function to be used to process array
"""
- if not deps.cupy_enabled:
- return correlate
-
- if cp.get_array_module(x) == np:
- return correlate
+ if deps.cupy_enabled or deps.jax_enabled:
+ if deps.jax_enabled and isinstance(x, jnp.ndarray):
+ return jax.scipy.signal.correlate
+ elif deps.cupy_enabled and cp.get_array_module(x) == cp:
+ return cp_correlate
+ else:
+ return correlate
else:
- return cp_correlate
+ return correlate
def get_add_at(x: npt.ArrayLike) -> Callable:
@@ -210,13 +247,13 @@ def get_add_at(x: npt.ArrayLike) -> Callable:
Parameters
----------
- x : :obj:`numpy.ndarray`
+ x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array`
Array
Returns
-------
- mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ f : :obj:`func`
+ Function to be used to process array
"""
if not deps.cupy_enabled:
@@ -228,27 +265,52 @@ def get_add_at(x: npt.ArrayLike) -> Callable:
return cupyx.scatter_add
-def get_block_diag(x: npt.ArrayLike) -> Callable:
- """Returns correct block_diag module based on input
+def get_sliding_window_view(x: npt.ArrayLike) -> Callable:
+ """Returns correct sliding_window_view module based on input
Parameters
----------
- x : :obj:`numpy.ndarray`
+ x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array`
Array
Returns
-------
- mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ f : :obj:`func`
+ Function to be used to process array
"""
if not deps.cupy_enabled:
- return block_diag
+ return np.lib.stride_tricks.sliding_window_view
if cp.get_array_module(x) == np:
- return block_diag
+ return np.lib.stride_tricks.sliding_window_view
else:
- return cp_block_diag
+ return cp.lib.stride_tricks.sliding_window_view
+
+
+def get_block_diag(x: npt.ArrayLike) -> Callable:
+ """Returns correct block_diag module based on input
+
+ Parameters
+ ----------
+ x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray` or :obj:`jax.Array`
+ Array
+
+ Returns
+ -------
+ f : :obj:`func`
+ Function to be used to process array
+
+ """
+ if deps.cupy_enabled or deps.jax_enabled:
+ if deps.jax_enabled and isinstance(x, jnp.ndarray):
+ return jnp_block_diag
+ elif deps.cupy_enabled and cp.get_array_module(x) == cp:
+ return cp_block_diag
+ else:
+ return block_diag
+ else:
+ return block_diag
def get_toeplitz(x: npt.ArrayLike) -> Callable:
@@ -261,17 +323,19 @@ def get_toeplitz(x: npt.ArrayLike) -> Callable:
Returns
-------
- mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ f : :obj:`func`
+ Function to be used to process array
"""
- if not deps.cupy_enabled:
- return toeplitz
-
- if cp.get_array_module(x) == np:
- return toeplitz
+ if deps.cupy_enabled or deps.jax_enabled:
+ if deps.jax_enabled and isinstance(x, jnp.ndarray):
+ return jnp_toeplitz
+ elif deps.cupy_enabled and cp.get_array_module(x) == cp:
+ return cp_toeplitz
+ else:
+ return toeplitz
else:
- return cp_toeplitz
+ return toeplitz
def get_csc_matrix(x: npt.ArrayLike) -> Callable:
@@ -284,8 +348,8 @@ def get_csc_matrix(x: npt.ArrayLike) -> Callable:
Returns
-------
- mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ f : :obj:`func`
+ Function to be used to process array
"""
if not deps.cupy_enabled:
@@ -307,8 +371,8 @@ def get_sparse_eye(x: npt.ArrayLike) -> Callable:
Returns
-------
- mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ f : :obj:`func`
+ Function to be used to process array
"""
if not deps.cupy_enabled:
@@ -330,8 +394,8 @@ def get_lstsq(x: npt.ArrayLike) -> Callable:
Returns
-------
- mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ f : :obj:`func`
+ Function to be used to process array
"""
if not deps.cupy_enabled:
@@ -353,8 +417,8 @@ def get_sp_fft(x: npt.ArrayLike) -> Callable:
Returns
-------
- mod : :obj:`func`
- Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
+ f : :obj:`func`
+ Function to be used to process array
"""
if not deps.cupy_enabled:
@@ -409,7 +473,7 @@ def to_numpy(x: NDArray) -> NDArray:
Returns
-------
- x : :obj:`cupy.ndarray`
+ x : :obj:`numpy.ndarray`
Converted array
"""
@@ -437,5 +501,138 @@ def to_cupy_conditional(x: npt.ArrayLike, y: npt.ArrayLike) -> NDArray:
"""
if deps.cupy_enabled:
if cp.get_array_module(x) == cp and cp.get_array_module(y) == np:
- y = cp.asarray(y)
+ with cp.cuda.Device(x.device):
+ y = cp.asarray(y)
return y
+
+
+def inplace_set(x: npt.ArrayLike, y: npt.ArrayLike, idx: list) -> NDArray:
+ """Perform inplace set based on input
+
+ Parameters
+ ----------
+ x : :obj:`numpy.ndarray` or :obj:`jax.Array`
+ Array to sum
+ y : :obj:`numpy.ndarray` or :obj:`jax.Array`
+ Output array
+ idx : :obj:`list`
+ Indices to sum at
+
+ Returns
+ -------
+ y : :obj:`numpy.ndarray` or :obj:`jax.Array`
+ Output array
+
+ """
+ if deps.jax_enabled and isinstance(x, jnp.ndarray):
+ y = y.at[idx].set(x)
+ return y
+ else:
+ y[idx] = x
+ return y
+
+
+def inplace_add(x: npt.ArrayLike, y: npt.ArrayLike, idx: list) -> NDArray:
+ """Perform inplace add based on input
+
+ Parameters
+ ----------
+ x : :obj:`numpy.ndarray` or :obj:`jax.Array`
+ Array to sum
+ y : :obj:`numpy.ndarray` or :obj:`jax.Array`
+ Output array
+ idx : :obj:`list`
+ Indices to sum at
+
+ Returns
+ -------
+ y : :obj:`numpy.ndarray` or :obj:`jax.Array`
+ Output array
+
+ """
+ if deps.jax_enabled and isinstance(x, jnp.ndarray):
+ y = y.at[idx].add(x)
+ return y
+ else:
+ y[idx] += x
+ return y
+
+
+def inplace_multiply(x: npt.ArrayLike, y: npt.ArrayLike, idx: list) -> NDArray:
+ """Perform inplace multiplication based on input
+
+ Parameters
+ ----------
+ x : :obj:`numpy.ndarray` or :obj:`jax.Array`
+ Array to sum
+ y : :obj:`numpy.ndarray` or :obj:`jax.Array`
+ Output array
+ idx : :obj:`list`
+ Indices to multiply at
+
+ Returns
+ -------
+ y : :obj:`numpy.ndarray` or :obj:`jax.Array`
+ Output array
+
+ """
+ if deps.jax_enabled and isinstance(x, jnp.ndarray):
+ y = y.at[idx].multiply(x)
+ return y
+ else:
+ y[idx] *= x
+ return y
+
+
+def inplace_divide(x: npt.ArrayLike, y: npt.ArrayLike, idx: list) -> NDArray:
+ """Perform inplace division based on input
+
+ Parameters
+ ----------
+ x : :obj:`numpy.ndarray` or :obj:`jax.Array`
+ Array to sum
+ y : :obj:`numpy.ndarray` or :obj:`jax.Array`
+ Output array
+ idx : :obj:`list`
+ Indices to divide at
+
+ Returns
+ -------
+ y : :obj:`numpy.ndarray` or :obj:`jax.Array`
+ Output array
+
+ """
+ if deps.jax_enabled and isinstance(x, jnp.ndarray):
+ y = y.at[idx].divide(x)
+ return y
+ else:
+ y[idx] /= x
+ return y
+
+
+def randn(*n: int, backend: str = "numpy") -> NDArray:
+ """Returns randomly generated number
+
+ Parameters
+ ----------
+ *n : :obj:`int`
+ Number of samples to generate in each dimension
+ backend : :obj:`str`, optional
+ Backend used for dot test computations (``numpy`` or ``cupy``). This
+ parameter will be used to choose how to create the random vectors.
+
+ Returns
+ -------
+ x : :obj:`numpy.ndarray` or :obj:`jax.Array`
+ Generated array
+
+ """
+ if backend == "numpy":
+ x = np.random.randn(*n)
+ elif backend == "cupy":
+ x = cp.random.randn(*n)
+ elif backend == "jax":
+ x = jnp.array(np.random.randn(*n))
+ else:
+ raise ValueError("backend must be numpy, cupy, or jax")
+ return x
diff --git a/pylops/utils/deps.py b/pylops/utils/deps.py
index 3497ce86..d0c53409 100644
--- a/pylops/utils/deps.py
+++ b/pylops/utils/deps.py
@@ -1,5 +1,6 @@
__all__ = [
"cupy_enabled",
+ "jax_enabled",
"devito_enabled",
"dtcwt_enabled",
"ucurv_enabled",
@@ -52,6 +53,34 @@ def cupy_import(message: Optional[str] = None) -> str:
return cupy_message
+def jax_import(message: Optional[str] = None) -> str:
+ jax_test = (
+ util.find_spec("jax") is not None and int(os.getenv("JAX_PYLOPS", 1)) == 1
+ )
+ if jax_test:
+ try:
+ import_module("jax") # noqa: F401
+
+ jax_message = None
+ except (ImportError, ModuleNotFoundError) as e:
+ jax_message = (
+ f"Failed to import jax, Falling back to numpy (error: {e}). "
+ "Please ensure your environment is set up correctly "
+ "for more details visit 'https://jax.readthedocs.io/en/latest/installation.html'"
+ )
+ print(UserWarning(jax_message))
+ else:
+ jax_message = (
+ "Jax package not installed or os.getenv('JAX_PYLOPS') == 0. "
+ f"In order to be able to use {message} "
+ "ensure 'os.getenv('JAX_PYLOPS') == 1' and run "
+ "'pip install jax'; "
+ "for more details visit 'https://jax.readthedocs.io/en/latest/installation.html'"
+ )
+
+ return jax_message
+
+
def devito_import(message: Optional[str] = None) -> str:
if devito_enabled:
try:
@@ -211,15 +240,18 @@ def sympy_import(message: Optional[str] = None) -> str:
# Set package availability booleans
-# cupy: the package is imported to check everything is working correctly,
-# if not the package is disabled. We do this here as this library is used as drop-in
-# replacement for many numpy and scipy routines when cupy arrays are provided.
+# cupy and jax: the package is imported to check everything is working correctly,
+# if not the package is disabled. We do this here as these libraries are used as drop-in
+# replacement for many numpy and scipy routines when cupy/jax arrays are provided.
# all other libraries: we simply check if the package is available and postpone its import
# to check everything is working correctly when a user tries to create an operator that requires
# such a package
cupy_enabled: bool = (
True if (cupy_import() is None and int(os.getenv("CUPY_PYLOPS", 1)) == 1) else False
)
+jax_enabled: bool = (
+ True if (jax_import() is None and int(os.getenv("JAX_PYLOPS", 1)) == 1) else False
+)
devito_enabled = util.find_spec("devito") is not None
dtcwt_enabled = util.find_spec("dtcwt") is not None
ucurv_enabled = util.find_spec("ucurv") is not None
diff --git a/pylops/utils/dottest.py b/pylops/utils/dottest.py
index ed77b995..c8a198ca 100644
--- a/pylops/utils/dottest.py
+++ b/pylops/utils/dottest.py
@@ -4,7 +4,7 @@
import numpy as np
-from pylops.utils.backend import get_module, to_numpy
+from pylops.utils.backend import get_module, randn, to_numpy
def dottest(
@@ -93,13 +93,13 @@ def dottest(
# make u and v vectors
rdtype = np.ones(1, Op.dtype).real.dtype
- u = ncp.random.randn(nc).astype(rdtype)
+ u = randn(nc, backend=backend).astype(rdtype)
if complexflag not in (0, 2):
- u = u + 1j * ncp.random.randn(nc).astype(rdtype)
+ u = u + 1j * randn(nc, backend=backend).astype(rdtype)
- v = ncp.random.randn(nr).astype(rdtype)
+ v = randn(nr, backend=backend).astype(rdtype)
if complexflag not in (0, 1):
- v = v + 1j * ncp.random.randn(nr).astype(rdtype)
+ v = v + 1j * randn(nr, backend=backend).astype(rdtype)
y = Op.matvec(u) # Op * u
x = Op.rmatvec(v) # Op'* v
diff --git a/pylops/utils/signalprocessing.py b/pylops/utils/signalprocessing.py
index 0d1fb573..09bb23b0 100644
--- a/pylops/utils/signalprocessing.py
+++ b/pylops/utils/signalprocessing.py
@@ -298,8 +298,8 @@ def dip_estimate(
Notes
-----
- Thin wrapper around ``pylops.utils.signalprocessing.dip_estimate`` with ``slopes==True``.
- See the Notes of ``pylops.utils.signalprocessing.dip_estimate`` for details.
+ Thin wrapper around ``pylops.utils.signalprocessing.slope_estimate`` with ``dips=True``.
+ See the Notes of ``pylops.utils.signalprocessing.slope_estimate`` for details.
.. [1] Van Vliet, L. J., Verbeek, P. W., "Estimators for orientation and
anisotropy in digitized images", Journal ASCI Imaging Workshop. 1995.
diff --git a/pylops/utils/tapers.py b/pylops/utils/tapers.py
index 52c95e8e..a15a4f32 100644
--- a/pylops/utils/tapers.py
+++ b/pylops/utils/tapers.py
@@ -75,6 +75,8 @@ def cosinetaper(
square : :obj:`bool`, optional
Cosine square taper (``True``) or Cosine taper (``False``)
exponent : :obj:`float`, optional
+ .. versionadded:: 2.3.0
+
Exponent to apply to Cosine taper. If provided, takes precedence over ``square``
Returns
diff --git a/pylops/waveeqprocessing/_twoway.py b/pylops/waveeqprocessing/_twoway.py
new file mode 100644
index 00000000..81a6498a
--- /dev/null
+++ b/pylops/waveeqprocessing/_twoway.py
@@ -0,0 +1,43 @@
+from examples.seismic.utils import PointSource
+
+
+class _CustomSource(PointSource):
+ """Custom source
+
+ This class creates a Devito symbolic object that encapsulates a set of
+ sources with a user defined source signal wavelet ``wav``
+
+ Parameters
+ ----------
+ name : :obj:`str`
+ Name for the resulting symbol.
+ grid : :obj:`devito.types.grid.Grid`
+ The computational domain.
+ time_range : :obj:`examples.seismic.source.TimeAxis`
+ TimeAxis(start, step, num) object.
+ wav : :obj:`numpy.ndarray`
+ Wavelet of size
+
+ """
+
+ __rkwargs__ = PointSource.__rkwargs__ + ["wav"]
+
+ @classmethod
+ def __args_setup__(cls, *args, **kwargs):
+ kwargs.setdefault("npoint", 1)
+
+ return super().__args_setup__(*args, **kwargs)
+
+ def __init_finalize__(self, *args, **kwargs):
+ super().__init_finalize__(*args, **kwargs)
+
+ self.wav = kwargs.get("wav")
+
+ if not self.alias:
+ for p in range(kwargs["npoint"]):
+ self.data[:, p] = self.wavelet
+
+ @property
+ def wavelet(self):
+ """Return user-provided wavelet"""
+ return self.wav
diff --git a/pylops/waveeqprocessing/blending.py b/pylops/waveeqprocessing/blending.py
index adca6d93..2bc31c65 100644
--- a/pylops/waveeqprocessing/blending.py
+++ b/pylops/waveeqprocessing/blending.py
@@ -9,7 +9,7 @@
from pylops import LinearOperator
from pylops.basicoperators import BlockDiag, HStack, Pad
from pylops.signalprocessing import Shift
-from pylops.utils.backend import get_array_module
+from pylops.utils.backend import get_array_module, inplace_add, inplace_set
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, NDArray
@@ -76,12 +76,7 @@ def __init__(
self.dt = dt
self.times = times
self.shiftall = shiftall
- if np.max(self.times) // dt == np.max(self.times) / dt:
- # do not add extra sample as no shift will be applied
- self.nttot = int(np.max(self.times) / self.dt + self.nt)
- else:
- # add 1 extra sample at the end
- self.nttot = int(np.max(self.times) / self.dt + self.nt + 1)
+ self.nttot = int(np.max(self.times) / self.dt + self.nt + 1)
if not self.shiftall:
# original implementation, where each source is shifted indipendently
self.PadOp = Pad((self.nr, self.nt), ((0, 0), (0, 1)), dtype=self.dtype)
@@ -143,7 +138,11 @@ def _matvec_smallrecs(self, x: NDArray) -> NDArray:
self.ns, self.nr, self.nt + 1
)
for i, shift_int in enumerate(self.shifts):
- blended_data[:, shift_int : shift_int + self.nt + 1] += shifted_data[i]
+ blended_data = inplace_add(
+ shifted_data[i],
+ blended_data,
+ (slice(None, None), slice(shift_int, shift_int + self.nt + 1)),
+ )
return blended_data
@reshaped
@@ -151,7 +150,11 @@ def _rmatvec_smallrecs(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
shifted_data = ncp.zeros((self.ns, self.nr, self.nt + 1), dtype=self.dtype)
for i, shift_int in enumerate(self.shifts):
- shifted_data[i, :, :] = x[:, shift_int : shift_int + self.nt + 1]
+ shifted_data = inplace_set(
+ x[:, shift_int : shift_int + self.nt + 1],
+ shifted_data,
+ (i, slice(None, None), slice(None, None)),
+ )
deblended_data = self.PadOp._rmatvec(
self.ShiftOp._rmatvec(shifted_data.ravel())
).reshape(self.dims)
@@ -170,7 +173,11 @@ def _matvec_largerecs(self, x: NDArray) -> NDArray:
.matvec(self.PadOp.matvec(x[i, :, :].ravel()))
.reshape(self.ShiftOps[i].dimsd)
)
- blended_data[:, shift_int : shift_int + self.nt + 1] += shifted_data
+ blended_data = inplace_add(
+ shifted_data,
+ blended_data,
+ (slice(None, None), slice(shift_int, shift_int + self.nt + 1)),
+ )
return blended_data
@reshaped
@@ -186,7 +193,11 @@ def _rmatvec_largerecs(self, x: NDArray) -> NDArray:
x[:, shift_int : shift_int + self.nt + 1].ravel()
)
).reshape(self.PadOp.dims)
- deblended_data[i, :, :] = shifted_data
+ deblended_data = inplace_set(
+ shifted_data,
+ deblended_data,
+ (i, slice(None, None), slice(None, None)),
+ )
return deblended_data
def _register_multiplications(self) -> None:
diff --git a/pylops/waveeqprocessing/kirchhoff.py b/pylops/waveeqprocessing/kirchhoff.py
index 06c29251..bdd5be87 100644
--- a/pylops/waveeqprocessing/kirchhoff.py
+++ b/pylops/waveeqprocessing/kirchhoff.py
@@ -288,8 +288,11 @@ def __init__(
)
self.rix = np.tile((recs[0] - x[0]) // dx, (ns, 1)).astype(int).ravel()
elif self.ndims == 3:
- # TODO: 3D normalized distances
- raise NotImplementedError("dynamic=True currently not available in 3D")
+ # TODO: compute 3D indices for aperture filter
+ # currently no aperture filter in 3D... just make indices 0
+ # so check if always passed
+ self.six = np.zeros(nr * ns)
+ self.rix = np.zeros(nr * ns)
# compute traveltime and distances
self.travsrcrec = True # use separate tables for src and rec traveltimes
@@ -362,8 +365,26 @@ def __init__(
trav_recs_grad[0], trav_recs_grad[1]
).reshape(np.prod(dims), nr)
else:
- # TODO: 3D
- raise NotImplementedError("dynamic=True currently not available in 3D")
+ trav_srcs_grad = np.concatenate(
+ [trav_srcs_grad[i][np.newaxis] for i in range(3)]
+ )
+ trav_recs_grad = np.concatenate(
+ [trav_recs_grad[i][np.newaxis] for i in range(3)]
+ )
+ self.angle_srcs = (
+ np.sign(trav_srcs_grad[1])
+ * np.arccos(
+ trav_srcs_grad[-1]
+ / np.sqrt(np.sum(trav_srcs_grad**2, axis=0))
+ )
+ ).reshape(np.prod(dims), ns)
+ self.angle_recs = (
+ np.sign(trav_srcs_grad[1])
+ * np.arccos(
+ trav_recs_grad[-1]
+ / np.sqrt(np.sum(trav_recs_grad**2, axis=0))
+ )
+ ).reshape(np.prod(dims), nr)
# pre-compute traveltime indices if total traveltime is used
if not self.travsrcrec:
@@ -386,6 +407,12 @@ def __init__(
# define aperture
# if aperture=None, we want to ensure the check is always matched (no aperture limits...)
+ # if aperture!=None in 3d, force to None as aperture checks are not yet implemented
+ if aperture is not None and self.ndims == 3:
+ aperture = None
+ warnings.warn(
+ "Aperture is forced to None as currently not implemented in 3D"
+ )
if aperture is not None:
warnings.warn(
"Aperture is currently defined as ratio of offset over depth, "
@@ -608,10 +635,10 @@ def _traveltime_table(
# compute traveltime gradients at image points
trav_srcs_grad = np.gradient(
- trav_srcs.reshape(*dims, ns), axis=np.arange(ndims)
+ trav_srcs.reshape(*dims, ns), *dsamp, axis=np.arange(ndims)
)
trav_recs_grad = np.gradient(
- trav_recs.reshape(*dims, nr), axis=np.arange(ndims)
+ trav_recs.reshape(*dims, nr), *dsamp, axis=np.arange(ndims)
)
return (
diff --git a/pylops/waveeqprocessing/mdd.py b/pylops/waveeqprocessing/mdd.py
index 0202b624..90e2cba3 100644
--- a/pylops/waveeqprocessing/mdd.py
+++ b/pylops/waveeqprocessing/mdd.py
@@ -242,6 +242,7 @@ def MDC(
conj=conj,
prescaled=prescaled,
args_FFT={"engine": fftengine},
+ args_FFT1={"engine": fftengine},
args_Fredholm1={"usematmul": usematmul},
)
MOp.name = name
diff --git a/pylops/waveeqprocessing/oneway.py b/pylops/waveeqprocessing/oneway.py
index b41779db..14e5f7f1 100644
--- a/pylops/waveeqprocessing/oneway.py
+++ b/pylops/waveeqprocessing/oneway.py
@@ -235,6 +235,8 @@ def Deghosting(
zrec : :obj:`float`
Depth of receiver array
kind : :obj:`str`, optional
+ .. versionadded:: 2.3.0
+
Type of data (``p`` or ``vz``)
pd : :obj:`np.ndarray`, optional
Direct arrival to be subtracted from ``p``
diff --git a/pylops/waveeqprocessing/twoway.py b/pylops/waveeqprocessing/twoway.py
index 9b74d726..f74de122 100644
--- a/pylops/waveeqprocessing/twoway.py
+++ b/pylops/waveeqprocessing/twoway.py
@@ -1,6 +1,6 @@
__all__ = ["AcousticWave2D"]
-from typing import Tuple
+from typing import Any, NewType, Tuple
import numpy as np
@@ -15,6 +15,12 @@
from examples.seismic import AcquisitionGeometry, Model
from examples.seismic.acoustic import AcousticWaveSolver
+ from ._twoway import _CustomSource
+else:
+ AcousticWaveSolver = Any
+
+AcousticWaveSolverType = NewType("AcousticWaveSolver", AcousticWaveSolver)
+
class AcousticWave2D(LinearOperator):
"""Devito Acoustic propagator.
@@ -38,9 +44,9 @@ class AcousticWave2D(LinearOperator):
rec_z : :obj:`numpy.ndarray` or :obj:`float`
Receiver z-coordinates in m
t0 : :obj:`float`
- Initial time
+ Initial time in ms
tn : :obj:`int`
- Number of time samples
+ Final time in ms
src_type : :obj:`str`
Source type
space_order : :obj:`int`, optional
@@ -79,7 +85,7 @@ def __init__(
rec_x: NDArray,
rec_z: NDArray,
t0: float,
- tn: int,
+ tn: float,
src_type: str = "Ricker",
space_order: int = 6,
nbl: int = 20,
@@ -155,7 +161,7 @@ def _create_geometry(
rec_x: NDArray,
rec_z: NDArray,
t0: float,
- tn: int,
+ tn: float,
src_type: str,
f0: float = 20.0,
) -> None:
@@ -174,7 +180,7 @@ def _create_geometry(
t0 : :obj:`float`
Initial time
tn : :obj:`int`
- Number of time samples
+ Final time in ms
src_type : :obj:`str`
Source type
f0 : :obj:`float`, optional
@@ -201,6 +207,28 @@ def _create_geometry(
f0=None if f0 is None else f0 * 1e-3,
)
+ def updatesrc(self, wav):
+ """Update source wavelet
+
+ This routines is used to allow users to pass a custom source
+ wavelet to replace the source wavelet generated when the
+ object is initialized
+
+ Parameters
+ ----------
+ wav : :obj:`numpy.ndarray`
+ Wavelet
+
+ """
+ wav_padded = np.pad(wav, (0, self.geometry.nt - len(wav)))
+
+ self.wav = _CustomSource(
+ name="src",
+ grid=self.model.grid,
+ wav=wav_padded,
+ time_range=self.geometry.time_axis,
+ )
+
def _srcillumination_oneshot(self, isrc: int) -> Tuple[NDArray, NDArray]:
"""Source wavefield and illumination for one shot
@@ -229,8 +257,15 @@ def _srcillumination_oneshot(self, isrc: int) -> Tuple[NDArray, NDArray]:
)
solver = AcousticWaveSolver(self.model, geometry, space_order=self.space_order)
+ # assign source location to source object with custom wavelet
+ if hasattr(self, "wav"):
+ self.wav.coordinates.data[0, :] = self.geometry.src_positions[isrc, :]
+
# source wavefield
- u0 = solver.forward(save=True)[1]
+ u0 = solver.forward(
+ save=True, src=None if not hasattr(self, "wav") else self.wav
+ )[1]
+
# source illumination
src_ill = self._crop_model((u0.data**2).sum(axis=0), self.model.nbl)
return u0, src_ill
@@ -255,13 +290,13 @@ def srcillumination_allshots(self, savewav: bool = False) -> None:
self.src_wavefield.append(src_wav)
self.src_illumination += src_ill
- def _born_oneshot(self, isrc: int, dm: NDArray) -> NDArray:
+ def _born_oneshot(self, solver: AcousticWaveSolverType, dm: NDArray) -> NDArray:
"""Born modelling for one shot
Parameters
----------
- isrc : :obj:`int`
- Index of source to model
+ solver : :obj:`AcousticWaveSolver`
+ Devito's solver object.
dm : :obj:`np.ndarray`
Model perturbation
@@ -271,25 +306,19 @@ def _born_oneshot(self, isrc: int, dm: NDArray) -> NDArray:
Data
"""
- # create geometry for single source
- geometry = AcquisitionGeometry(
- self.model,
- self.geometry.rec_positions,
- self.geometry.src_positions[isrc, :],
- self.geometry.t0,
- self.geometry.tn,
- f0=self.geometry.f0,
- src_type=self.geometry.src_type,
- )
# set perturbation
dmext = np.zeros(self.model.grid.shape, dtype=np.float32)
dmext[self.model.nbl : -self.model.nbl, self.model.nbl : -self.model.nbl] = dm
- # solve
- solver = AcousticWaveSolver(self.model, geometry, space_order=self.space_order)
- d = solver.jacobian(dmext)[0]
- d = d.resample(geometry.dt).data[:][: geometry.nt].T
+ # assign source location to source object with custom wavelet
+ if hasattr(self, "wav"):
+ self.wav.coordinates.data[0, :] = solver.geometry.src_positions[:]
+
+ d = solver.jacobian(dmext, src=None if not hasattr(self, "wav") else self.wav)[
+ 0
+ ]
+ d = d.resample(solver.geometry.dt).data[:][: solver.geometry.nt].T
return d
def _born_allshots(self, dm: NDArray) -> NDArray:
@@ -306,11 +335,26 @@ def _born_allshots(self, dm: NDArray) -> NDArray:
Data for all shots
"""
+ # create geometry for single source
+ geometry = AcquisitionGeometry(
+ self.model,
+ self.geometry.rec_positions,
+ self.geometry.src_positions[0, :],
+ self.geometry.t0,
+ self.geometry.tn,
+ f0=self.geometry.f0,
+ src_type=self.geometry.src_type,
+ )
+
+ # solve
+ solver = AcousticWaveSolver(self.model, geometry, space_order=self.space_order)
+
nsrc = self.geometry.src_positions.shape[0]
dtot = []
for isrc in range(nsrc):
- d = self._born_oneshot(isrc, dm)
+ solver.geometry.src_positions = self.geometry.src_positions[isrc, :]
+ d = self._born_oneshot(solver, dm)
dtot.append(d)
dtot = np.array(dtot).reshape(nsrc, d.shape[0], d.shape[1])
return dtot
@@ -347,11 +391,18 @@ def _bornadj_oneshot(self, isrc, dobs):
solver = AcousticWaveSolver(self.model, geometry, space_order=self.space_order)
+ # assign source location to source object with custom wavelet
+ if hasattr(self, "wav"):
+ self.wav.coordinates.data[0, :] = self.geometry.src_positions[isrc, :]
+
# source wavefield
if hasattr(self, "src_wavefield"):
u0 = self.src_wavefield[isrc]
else:
- u0 = solver.forward(save=True)[1]
+ u0 = solver.forward(
+ save=True, src=None if not hasattr(self, "wav") else self.wav
+ )[1]
+
# adjoint modelling (reverse wavefield plus imaging condition)
model = solver.jacobian_adjoint(
rec=recs, u=u0, checkpointing=self.checkpointing
diff --git a/pylops/waveeqprocessing/wavedecomposition.py b/pylops/waveeqprocessing/wavedecomposition.py
index 7d926d36..715fb2c1 100644
--- a/pylops/waveeqprocessing/wavedecomposition.py
+++ b/pylops/waveeqprocessing/wavedecomposition.py
@@ -156,6 +156,7 @@ def _obliquity3D(
critical: float = 100.0,
ntaper: int = 10,
composition: bool = True,
+ fftengine: str = "scipy",
backend: str = "numpy",
dtype: DTypeLike = "complex128",
) -> Tuple[LinearOperator, LinearOperator]:
@@ -187,6 +188,9 @@ def _obliquity3D(
composition : :obj:`bool`, optional
Create obliquity factor for composition (``True``) or
decomposition (``False``)
+ fftengine : :obj:`str`, optional
+ Engine used for fft computation (``numpy`` or ``scipy``). Choose
+ ``numpy`` when working with cupy and jax arrays.
backend : :obj:`str`, optional
Backend used for creation of obliquity factor operator
(``numpy`` or ``cupy``)
@@ -203,7 +207,11 @@ def _obliquity3D(
"""
# create Fourier operator
FFTop = FFTND(
- dims=[nr[0], nr[1], nt], nffts=nffts, sampling=[dr[0], dr[1], dt], dtype=dtype
+ dims=[nr[0], nr[1], nt],
+ nffts=nffts,
+ sampling=[dr[0], dr[1], dt],
+ engine=fftengine,
+ dtype=dtype,
)
# create obliquity operator
@@ -547,6 +555,7 @@ def UpDownComposition3D(
critical: float = 100.0,
ntaper: int = 10,
scaling: float = 1.0,
+ fftengine: str = "scipy",
backend: str = "numpy",
dtype: DTypeLike = "complex128",
name: str = "U",
@@ -588,6 +597,11 @@ def UpDownComposition3D(
angle
scaling : :obj:`float`, optional
Scaling to apply to the operator (see Notes for more details)
+ fftengine : :obj:`str`, optional
+ .. versionadded:: 2.3.0
+
+ Engine used for fft computation (``numpy`` or ``scipy``). Choose
+ ``numpy`` when working with cupy and jax arrays.
backend : :obj:`str`, optional
Backend used for creation of obliquity factor operator
(``numpy`` or ``cupy``)
@@ -638,6 +652,7 @@ def UpDownComposition3D(
critical=critical,
ntaper=ntaper,
composition=True,
+ fftengine=fftengine,
backend=backend,
dtype=dtype,
)
diff --git a/pyproject.toml b/pyproject.toml
index 5e435cc7..01b44d01 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,7 +31,7 @@ classifiers = [
]
dependencies = [
"numpy >= 1.21.0",
- "scipy >= 1.4.0",
+ "scipy >= 1.11.0",
]
dynamic = ["version"]
diff --git a/pytests/test_dtcwt.py b/pytests/test_dtcwt.py
index b0cf2b61..979a7f76 100644
--- a/pytests/test_dtcwt.py
+++ b/pytests/test_dtcwt.py
@@ -3,6 +3,9 @@
from pylops.signalprocessing import DTCWT
+# currently test only if numpy<2.0.0 is installed...
+np_version = np.__version__.split(".")
+
par1 = {"ny": 10, "nx": 10, "dtype": "float64"}
par2 = {"ny": 50, "nx": 50, "dtype": "float64"}
@@ -17,6 +20,8 @@ def sequential_array(shape):
@pytest.mark.parametrize("par", [(par1), (par2)])
def test_dtcwt1D_input1D(par):
"""Test for DTCWT with 1D input"""
+ if int(np_version[0]) >= 2:
+ return
t = sequential_array((par["ny"],))
@@ -31,6 +36,8 @@ def test_dtcwt1D_input1D(par):
@pytest.mark.parametrize("par", [(par1), (par2)])
def test_dtcwt1D_input2D(par):
"""Test for DTCWT with 2D input (forward-inverse pair)"""
+ if int(np_version[0]) >= 2:
+ return
t = sequential_array(
(
@@ -50,6 +57,8 @@ def test_dtcwt1D_input2D(par):
@pytest.mark.parametrize("par", [(par1), (par2)])
def test_dtcwt1D_input3D(par):
"""Test for DTCWT with 3D input (forward-inverse pair)"""
+ if int(np_version[0]) >= 2:
+ return
t = sequential_array((par["ny"], par["ny"], par["ny"]))
@@ -64,6 +73,9 @@ def test_dtcwt1D_input3D(par):
@pytest.mark.parametrize("par", [(par1), (par2)])
def test_dtcwt1D_birot(par):
"""Test for DTCWT birot (forward-inverse pair)"""
+ if int(np_version[0]) >= 2:
+ return
+
birots = ["antonini", "legall", "near_sym_a", "near_sym_b"]
t = sequential_array(
diff --git a/pytests/test_dwts.py b/pytests/test_dwts.py
index 0fca4526..09f567dc 100755
--- a/pytests/test_dwts.py
+++ b/pytests/test_dwts.py
@@ -3,11 +3,20 @@
from numpy.testing import assert_array_almost_equal
from scipy.sparse.linalg import lsqr
-from pylops.signalprocessing import DWT, DWT2D
+from pylops.signalprocessing import DWT, DWT2D, DWTND
from pylops.utils import dottest
par1 = {"ny": 7, "nx": 9, "nt": 10, "imag": 0, "dtype": "float32"} # real
par2 = {"ny": 7, "nx": 9, "nt": 10, "imag": 1j, "dtype": "complex64"} # complex
+par3 = {"ny": 7, "nx": 9, "nz": 9, "nt": 10, "imag": 0, "dtype": "float32"} # real 4D
+par4 = {
+ "ny": 7,
+ "nx": 9,
+ "nz": 9,
+ "nt": 10,
+ "imag": 1j,
+ "dtype": "complex64",
+} # complex 4D
np.random.seed(10)
@@ -133,3 +142,56 @@ def test_DWT2D_3dsignal(par):
assert_array_almost_equal(x.ravel(), xadj, decimal=8)
assert_array_almost_equal(x.ravel(), xinv, decimal=8)
+
+
+@pytest.mark.parametrize("par", [(par3), (par4)])
+def test_DWTND_3dsignal(par):
+ """Dot-test and inversion for DWTND operator for 3d signal"""
+ DWTop = DWTND(
+ dims=(par["nt"], par["nx"], par["ny"]), axes=(0, 1, 2), wavelet="haar", level=3
+ )
+ x = np.random.normal(0.0, 1.0, (par["nt"], par["nx"], par["ny"])) + par[
+ "imag"
+ ] * np.random.normal(0.0, 1.0, (par["nt"], par["nx"], par["ny"]))
+
+ assert dottest(
+ DWTop, DWTop.shape[0], DWTop.shape[1], complexflag=0 if par["imag"] == 0 else 3
+ )
+
+ y = DWTop * x.ravel()
+ xadj = DWTop.H * y # adjoint is same as inverse for dwt
+ xinv = lsqr(DWTop, y, damp=1e-10, iter_lim=10, atol=1e-8, btol=1e-8, show=0)[0]
+
+ assert_array_almost_equal(x.ravel(), xadj, decimal=8)
+ assert_array_almost_equal(x.ravel(), xinv, decimal=8)
+
+
+@pytest.mark.parametrize("par", [(par3), (par4)])
+def test_DWTND_4dsignal(par):
+ """Dot-test and inversion for DWTND operator for 4d signal"""
+ for axes in [(0, 1, 2), (0, 2, 3), (1, 2, 3), (0, 1, 3), (0, 1, 2, 3)]:
+ DWTop = DWTND(
+ dims=(par["nt"], par["nx"], par["ny"], par["nz"]),
+ axes=axes,
+ wavelet="haar",
+ level=3,
+ )
+ x = np.random.normal(
+ 0.0, 1.0, (par["nt"], par["nx"], par["ny"], par["nz"])
+ ) + par["imag"] * np.random.normal(
+ 0.0, 1.0, (par["nt"], par["nx"], par["ny"], par["nz"])
+ )
+
+ assert dottest(
+ DWTop,
+ DWTop.shape[0],
+ DWTop.shape[1],
+ complexflag=0 if par["imag"] == 0 else 3,
+ )
+
+ y = DWTop * x.ravel()
+ xadj = DWTop.H * y # adjoint is same as inverse for dwt
+ xinv = lsqr(DWTop, y, damp=1e-10, iter_lim=10, atol=1e-8, btol=1e-8, show=0)[0]
+
+ assert_array_almost_equal(x.ravel(), xadj, decimal=8)
+ assert_array_almost_equal(x.ravel(), xinv, decimal=8)
diff --git a/pytests/test_jaxoperator.py b/pytests/test_jaxoperator.py
new file mode 100755
index 00000000..86de4e8d
--- /dev/null
+++ b/pytests/test_jaxoperator.py
@@ -0,0 +1,53 @@
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+from numpy.testing import assert_array_almost_equal, assert_array_equal
+
+from pylops import JaxOperator, MatrixMult
+
+par1 = {"ny": 11, "nx": 11, "dtype": np.float32} # square
+par2 = {"ny": 21, "nx": 11, "dtype": np.float32} # overdetermined
+
+np.random.seed(0)
+
+
+@pytest.mark.parametrize("par", [(par1)])
+def test_JaxOperator(par):
+ """Apply forward and adjoint and compare with native pylops."""
+ M = np.random.normal(0.0, 1.0, (par["ny"], par["nx"])).astype(par["dtype"])
+ Mop = MatrixMult(jnp.array(M), dtype=par["dtype"])
+ Jop = JaxOperator(Mop)
+
+ x = np.random.normal(0.0, 1.0, par["nx"]).astype(par["dtype"])
+ xjnp = jnp.array(x)
+
+ # pylops operator
+ y = Mop * x
+ xadj = Mop.H * y
+
+ # jax operator
+ yjnp = Jop * xjnp
+ xadjnp = Jop.rmatvecad(xjnp, yjnp)
+
+ assert_array_equal(y, np.array(yjnp))
+ assert_array_equal(xadj, np.array(xadjnp))
+
+
+@pytest.mark.parametrize("par", [(par1)])
+def test_TorchOperator_batch(par):
+ """Apply forward for input with multiple samples
+ (= batch) and flattened arrays"""
+
+ M = np.random.normal(0.0, 1.0, (par["ny"], par["nx"])).astype(par["dtype"])
+ Mop = MatrixMult(jnp.array(M), dtype=par["dtype"])
+ Jop = JaxOperator(Mop)
+ auto_batch_matvec = jax.vmap(Jop._matvec)
+
+ x = np.random.normal(0.0, 1.0, (4, par["nx"])).astype(par["dtype"])
+ xjnp = jnp.array(x)
+
+ y = Mop.matmat(x.T).T
+ yjnp = auto_batch_matvec(xjnp)
+
+ assert_array_almost_equal(y, np.array(yjnp), decimal=5)
diff --git a/pytests/test_leastsquares.py b/pytests/test_leastsquares.py
index c0ee9944..84013a27 100755
--- a/pytests/test_leastsquares.py
+++ b/pytests/test_leastsquares.py
@@ -93,12 +93,12 @@ def test_NormalEquationsInversion(par):
# normal equations with regularization
xinv = normal_equations_inversion(
- Gop, y, [Reg], epsI=1e-5, epsRs=[1e-8], x0=x0, **dict(maxiter=200, tol=1e-10)
+ Gop, y, [Reg], epsI=1e-5, epsRs=[1e-8], x0=x0, **dict(maxiter=200, atol=1e-10)
)[0]
assert_array_almost_equal(x, xinv, decimal=3)
# normal equations with weight
xinv = normal_equations_inversion(
- Gop, y, None, Weight=Weigth, epsI=1e-5, x0=x0, **dict(maxiter=200, tol=1e-10)
+ Gop, y, None, Weight=Weigth, epsI=1e-5, x0=x0, **dict(maxiter=200, atol=1e-10)
)[0]
assert_array_almost_equal(x, xinv, decimal=3)
# normal equations with weight and small regularization
@@ -110,7 +110,7 @@ def test_NormalEquationsInversion(par):
epsI=1e-5,
epsRs=[1e-8],
x0=x0,
- **dict(maxiter=200, tol=1e-10)
+ **dict(maxiter=200, atol=1e-10)
)[0]
assert_array_almost_equal(x, xinv, decimal=3)
# normal equations with weight and small normal regularization
@@ -123,7 +123,7 @@ def test_NormalEquationsInversion(par):
epsI=1e-5,
epsNRs=[1e-8],
x0=x0,
- **dict(maxiter=200, tol=1e-10)
+ **dict(maxiter=200, atol=1e-10)
)[0]
assert_array_almost_equal(x, xinv, decimal=3)
@@ -192,7 +192,7 @@ def test_WeightedInversion(par):
y = Gop * x
xne = normal_equations_inversion(
- Gop, y, None, Weight=Weigth, **dict(maxiter=5, tol=1e-10)
+ Gop, y, None, Weight=Weigth, **dict(maxiter=5, atol=1e-10)
)[0]
xreg = regularized_inversion(
Gop, y, None, Weight=Weigth1, **dict(damp=0, iter_lim=5, show=0)
diff --git a/pytests/test_linearoperator.py b/pytests/test_linearoperator.py
index e6b4fd58..6673b2b7 100755
--- a/pytests/test_linearoperator.py
+++ b/pytests/test_linearoperator.py
@@ -122,7 +122,7 @@ def test_sparse(par):
D = np.diag(diag)
Dop = Diagonal(diag, dtype=par["dtype"])
S = Dop.tosparse()
- assert_array_equal(S.A, D)
+ assert_array_equal(S.toarray(), D)
@pytest.mark.parametrize("par", [(par1), (par2), (par1j)])
diff --git a/pytests/test_patching.py b/pytests/test_patching.py
index 25656724..de61243e 100755
--- a/pytests/test_patching.py
+++ b/pytests/test_patching.py
@@ -25,6 +25,7 @@
"novert": 0,
# "winst": 2,
"tapertype": None,
+ "savetaper": True,
} # no overlap, no taper
par2 = {
"ny": 6,
@@ -43,6 +44,7 @@
"novert": 0,
# "winst": 2,
"tapertype": "hanning",
+ "savetaper": True,
} # no overlap, with taper
par3 = {
"ny": 6,
@@ -61,8 +63,28 @@
"novert": 2,
# "winst": 4,
"tapertype": None,
+ "savetaper": True,
} # overlap, no taper
par4 = {
+ "ny": 6,
+ "nx": 7,
+ "nt": 10,
+ "npy": 15,
+ "nwiny": 7,
+ "novery": 3,
+ # "winsy": 3,
+ "npx": 13,
+ "nwinx": 5,
+ "noverx": 2,
+ # "winsx": 3,
+ "npt": 10,
+ "nwint": 4,
+ "novert": 2,
+ # "winst": 4,
+ "tapertype": None,
+ "savetaper": False,
+} # overlap, no taper (non saved
+par5 = {
"ny": 6,
"nx": 7,
"nt": 10,
@@ -79,10 +101,30 @@
"novert": 2,
# "winst": 4,
"tapertype": "hanning",
+ "savetaper": True,
} # overlap, with taper
+par6 = {
+ "ny": 6,
+ "nx": 7,
+ "nt": 10,
+ "npy": 15,
+ "nwiny": 7,
+ "novery": 3,
+ # "winsy": 3,
+ "npx": 13,
+ "nwinx": 5,
+ "noverx": 2,
+ # "winsx": 3,
+ "npt": 10,
+ "nwint": 4,
+ "novert": 2,
+ # "winst": 4,
+ "tapertype": "hanning",
+ "savetaper": False,
+} # overlap, with taper (non saved)
-@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4)])
+@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)])
def test_Patch2D(par):
"""Dot-test and inverse for Patch2D operator"""
Op = MatrixMult(np.ones((par["nwiny"] * par["nwint"], par["ny"] * par["nt"])))
@@ -101,6 +143,7 @@ def test_Patch2D(par):
nover=(par["novery"], par["novert"]),
nop=(par["ny"], par["nt"]),
tapertype=par["tapertype"],
+ savetaper=par["savetaper"],
)
assert dottest(
Pop,
@@ -134,6 +177,7 @@ def test_Patch2D_scalings(par):
nover=(par["novery"], par["novert"]),
nop=(par["ny"], par["nt"]),
tapertype=par["tapertype"],
+ savetaper=par["savetaper"],
scalings=scalings,
)
assert dottest(
@@ -148,7 +192,7 @@ def test_Patch2D_scalings(par):
assert_array_almost_equal(x.ravel(), xinv)
-@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4)])
+@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)])
def test_Patch3D(par):
"""Dot-test and inverse for Patch3D operator"""
Op = MatrixMult(
@@ -179,6 +223,7 @@ def test_Patch3D(par):
nover=(par["novery"], par["noverx"], par["novert"]),
nop=(par["ny"], par["nx"], par["nt"]),
tapertype=par["tapertype"],
+ savetaper=par["savetaper"],
)
assert dottest(
Pop,
diff --git a/pytests/test_sliding.py b/pytests/test_sliding.py
index 6750bbdf..55a62199 100755
--- a/pytests/test_sliding.py
+++ b/pytests/test_sliding.py
@@ -22,6 +22,7 @@
"noverx": 0,
# "winsx": 2,
"tapertype": None,
+ "savetaper": True,
} # no overlap, no taper
par2 = {
"ny": 6,
@@ -36,6 +37,7 @@
"noverx": 0,
# "winsx": 2,
"tapertype": "hanning",
+ "savetaper": True,
} # no overlap, with taper
par3 = {
"ny": 6,
@@ -50,8 +52,24 @@
"noverx": 2,
# "winsx": 4,
"tapertype": None,
+ "savetaper": True,
} # overlap, no taper
par4 = {
+ "ny": 6,
+ "nx": 7,
+ "nt": 10,
+ "npy": 15,
+ "nwiny": 7,
+ "novery": 3,
+ # "winsy": 3,
+ "npx": 10,
+ "nwinx": 4,
+ "noverx": 2,
+ # "winsx": 4,
+ "tapertype": None,
+ "savetaper": False,
+} # overlap, no taper (non saved)
+par5 = {
"ny": 6,
"nx": 7,
"nt": 10,
@@ -64,10 +82,26 @@
"noverx": 2,
# "winsx": 4,
"tapertype": "hanning",
+ "savetaper": True,
} # overlap, with taper
+par6 = {
+ "ny": 6,
+ "nx": 7,
+ "nt": 10,
+ "npy": 15,
+ "nwiny": 7,
+ "novery": 3,
+ # "winsy": 3,
+ "npx": 10,
+ "nwinx": 4,
+ "noverx": 2,
+ # "winsx": 4,
+ "tapertype": "hanning",
+ "savetaper": False,
+} # overlap, with taper (non saved)
-@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4)])
+@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)])
def test_Sliding1D(par):
"""Dot-test and inverse for Sliding1D operator"""
Op = MatrixMult(np.ones((par["nwiny"], par["ny"])))
@@ -83,6 +117,7 @@ def test_Sliding1D(par):
nwin=par["nwiny"],
nover=par["novery"],
tapertype=par["tapertype"],
+ savetaper=par["savetaper"],
)
assert dottest(Slid, par["npy"], par["ny"] * nwins)
x = np.ones(par["ny"] * nwins)
@@ -92,7 +127,7 @@ def test_Sliding1D(par):
assert_array_almost_equal(x.ravel(), xinv)
-@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4)])
+@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)])
def test_Sliding2D(par):
"""Dot-test and inverse for Sliding2D operator"""
Op = MatrixMult(np.ones((par["nwiny"] * par["nt"], par["ny"] * par["nt"])))
@@ -107,6 +142,7 @@ def test_Sliding2D(par):
nwin=par["nwiny"],
nover=par["novery"],
tapertype=par["tapertype"],
+ savetaper=par["savetaper"],
)
assert dottest(Slid, par["npy"] * par["nt"], par["ny"] * par["nt"] * nwins)
x = np.ones((par["ny"] * nwins, par["nt"]))
@@ -116,7 +152,7 @@ def test_Sliding2D(par):
assert_array_almost_equal(x.ravel(), xinv)
-@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4)])
+@pytest.mark.parametrize("par", [(par1), (par2), (par3), (par4), (par5), (par6)])
def test_Sliding3D(par):
"""Dot-test and inverse for Sliding3D operator"""
Op = MatrixMult(
@@ -140,6 +176,7 @@ def test_Sliding3D(par):
nover=(par["novery"], par["noverx"]),
nop=(par["ny"], par["nx"]),
tapertype=par["tapertype"],
+ savetaper=par["savetaper"],
)
assert dottest(
Slid,
diff --git a/pytests/test_sparsity.py b/pytests/test_sparsity.py
index b4ef5a30..ef4c0d6e 100644
--- a/pytests/test_sparsity.py
+++ b/pytests/test_sparsity.py
@@ -5,6 +5,9 @@
from pylops.basicoperators import FirstDerivative, Identity, MatrixMult
from pylops.optimization.sparsity import fista, irls, ista, omp, spgl1, splitbregman
+# currently test spgl1 only if numpy<2.0.0 is installed...
+np_version = np.__version__.split(".")
+
par1 = {
"ny": 11,
"nx": 11,
@@ -412,6 +415,6 @@ def test_SplitBregman(par):
x0=x0 if par["x0"] else None,
restart=False,
show=False,
- **dict(iter_lim=5, damp=1e-3)
+ **dict(iter_lim=5, damp=1e-3),
)
assert (np.linalg.norm(x - xinv) / np.linalg.norm(x)) < 1e-1
diff --git a/pytests/test_torchoperator.py b/pytests/test_torchoperator.py
index 38246a20..43f33e3f 100755
--- a/pytests/test_torchoperator.py
+++ b/pytests/test_torchoperator.py
@@ -1,3 +1,5 @@
+import platform
+
import numpy as np
import pytest
import torch
@@ -17,6 +19,11 @@ def test_TorchOperator(par):
must equal the adjoint of operator applied to the same vector, the two
results are also checked to be the same.
"""
+ # temporarily, skip tests on mac as torch seems not to recognized
+ # numpy when v2 is installed
+ if platform.system() == "Darwin":
+ return
+
Dop = MatrixMult(np.random.normal(0.0, 1.0, (par["ny"], par["nx"])))
Top = TorchOperator(Dop, batch=False)
@@ -40,6 +47,11 @@ def test_TorchOperator(par):
@pytest.mark.parametrize("par", [(par1)])
def test_TorchOperator_batch(par):
"""Apply forward for input with multiple samples (= batch) and flattened arrays"""
+ # temporarily, skip tests on mac as torch seems not to recognized
+ # numpy when v2 is installed
+ if platform.system() == "Darwin":
+ return
+
Dop = MatrixMult(np.random.normal(0.0, 1.0, (par["ny"], par["nx"])))
Top = TorchOperator(Dop, batch=True)
@@ -56,6 +68,11 @@ def test_TorchOperator_batch(par):
@pytest.mark.parametrize("par", [(par1)])
def test_TorchOperator_batch_nd(par):
"""Apply forward for input with multiple samples (= batch) and nd-arrays"""
+ # temporarily, skip tests on mac as torch seems not to recognized
+ # numpy when v2 is installed
+ if platform.system() == "Darwin":
+ return
+
Dop = MatrixMult(np.random.normal(0.0, 1.0, (par["ny"], par["nx"])), otherdims=(2,))
Top = TorchOperator(Dop, batch=True, flatten=False)
diff --git a/requirements-dev.txt b/requirements-dev.txt
index cd4b5911..42477b9a 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,6 +1,8 @@
numpy>=1.21.0
-scipy>=1.4.0
+scipy>=1.11.0
+--extra-index-url https://download.pytorch.org/whl/cpu
torch>=1.2.0
+jax
numba
pyfftw
PyWavelets
@@ -18,6 +20,7 @@ docutils<0.18
Sphinx
pydata-sphinx-theme
sphinx-gallery
+sphinxemoji
numpydoc
nbsphinx
image
diff --git a/requirements-doc.txt b/requirements-doc.txt
new file mode 100644
index 00000000..74fea77d
--- /dev/null
+++ b/requirements-doc.txt
@@ -0,0 +1,34 @@
+# Currently we force rdt to use numpy<2.0.0 to build the documentation
+# since the dtcwt and spgl1 are not yet compatible with numpy=2.0.0
+numpy>=1.21.0,<2.0.0
+scipy>=1.11.0
+jax
+--extra-index-url https://download.pytorch.org/whl/cpu
+torch>=1.2.0
+numba
+pyfftw
+PyWavelets
+spgl1
+scikit-fmm
+sympy
+devito
+dtcwt
+matplotlib
+ipython
+pytest
+pytest-runner
+setuptools_scm
+docutils<0.18
+Sphinx
+pydata-sphinx-theme
+sphinx-gallery
+sphinxemoji
+numpydoc
+nbsphinx
+image
+pre-commit
+autopep8
+isort
+black
+flake8
+mypy
diff --git a/tutorials/bayesian.py b/tutorials/bayesian.py
index 42e16a6f..653d6b3d 100755
--- a/tutorials/bayesian.py
+++ b/tutorials/bayesian.py
@@ -21,19 +21,26 @@
Based on the above definition, we construct some prior models in the frequency
domain, convert each of them to the time domain and use such an ensemble
-to estimate the prior mean :math:`\mu_\mathbf{x}` and model
-covariance :math:`\mathbf{C_x}`.
+to estimate the prior mean :math:`\mathbf{x}_0` and model
+covariance :math:`\mathbf{C}_{x_0}`.
-We then create our data by sampling the true signal at certain locations
+We then create our data by sampling the true signal at certain locations
and solve the resconstruction problem within a Bayesian framework. Since we are
assuming gaussianity in our priors, the equation to obtain the posterion mean
-can be derived analytically:
+and covariance can be derived analytically:
.. math::
- \mathbf{x} = \mathbf{x_0} + \mathbf{C}_x \mathbf{R}^T
- (\mathbf{R} \mathbf{C}_x \mathbf{R}^T + \mathbf{C}_y)^{-1} (\mathbf{y} -
+ \mathbf{x} = \mathbf{x_0} + \mathbf{C}_{x_0} \mathbf{R}^T
+ (\mathbf{R} \mathbf{C}_{x_0} \mathbf{R}^T + \mathbf{C}_y)^{-1} (\mathbf{y} -
\mathbf{R} \mathbf{x_0})
+and
+
+.. math::
+ \mathbf{C}_x = \mathbf{C}_{x_0} - \mathbf{C}_{x_0} \mathbf{R}^T
+ (\mathbf{R} \mathbf{C}_x \mathbf{R}^T + \mathbf{C}_y)^{-1}
+ \mathbf{R} \mathbf{C}_{x_0}
+
"""
import matplotlib.pyplot as plt
@@ -80,14 +87,15 @@ def prior_realization(f0, a0, phi0, sigmaf, sigmaa, sigmaphi, dt, nt, nfft):
sigmaa = [0.1, 0.5, 0.6]
phi0 = [-90.0, 0.0, 0.0]
sigmaphi = [0.1, 0.2, 0.4]
-sigmad = 1e-2
+sigmad = 1
+scaling = 100 # Scale by a factor to allow noise std=1
# Prior models
nt = 200
nfft = 2**11
dt = 0.004
t = np.arange(nt) * dt
-xs = np.array(
+xs = scaling * np.array(
[
prior_realization(f0, a0, phi0, sigmaf, sigmaa, sigmaphi, dt, nt, nfft)
for _ in range(nreals)
@@ -95,7 +103,10 @@ def prior_realization(f0, a0, phi0, sigmaf, sigmaa, sigmaphi, dt, nt, nfft):
)
# True model (taken as one possible realization)
-x = prior_realization(f0, a0, phi0, [0, 0, 0], [0, 0, 0], [0, 0, 0], dt, nt, nfft)
+x = scaling * prior_realization(
+ f0, a0, phi0, [0, 0, 0], [0, 0, 0], [0, 0, 0], dt, nt, nfft
+)
+
###############################################################################
# We have now a set of prior models in time domain. We can easily use sample
@@ -110,7 +121,7 @@ def prior_realization(f0, a0, phi0, sigmaf, sigmaa, sigmaphi, dt, nt, nfft):
N = 30 # lenght of decorrelation
diags = np.array([Cm[i, i - N : i + N + 1] for i in range(N, nt - N)])
diag_ave = np.average(diags, axis=0)
-# add a taper at the end to avoid edge effects
+# add a taper at the start and end to avoid edge effects
diag_ave *= np.hamming(2 * N + 1)
fig, ax = plt.subplots(1, 1, figsize=(12, 4))
@@ -157,65 +168,107 @@ def prior_realization(f0, a0, phi0, sigmaf, sigmaa, sigmaphi, dt, nt, nfft):
ynmask = Rop.mask(x + n)
###############################################################################
-# First we apply the Bayesian inversion equation
-xbayes = x0 + Cm_op * Rop.H * (
+# First, since the problem is rather small, we construct the dense version of
+# all our matrices and we compute the analytical posterior mean and covariance
+
+Cm = Cm_op.todense()
+Cd = Cd_op.todense()
+R = Rop.todense()
+
+# Bayesian analytical solution
+xpost_ana = x0 + Cm @ R.T @ (np.linalg.solve(R @ Cm @ R.T + Cd, yn - R @ x0))
+Cmpost_ana = Cm - Cm @ R.T @ (np.linalg.solve(R @ Cm @ R.T + Cd, R @ Cm))
+
+###############################################################################
+# Next we solve the same Bayesian inversion equation iteratively. We will see
+# that provided we use enough iterations we can retrieve the same values of
+# the analytical posterior mean
+xpost_iter = x0 + Cm_op * Rop.H * (
lsqr(Rop * Cm_op * Rop.H + Cd_op, yn - Rop * x0, iter_lim=400)[0]
)
-# Visualize
+###############################################################################
+# But what is the problem did not allow creating dense matrices for both the
+# operator and the input covariance matrices. In this case, we can resort to the
+# Randomize-Then-Optimize algorithm of Bardsley et al., 2014, which simply solves
+# the same problem that we solved to find the MAP solution repeatedly by adding
+# random noise to the data. It can be shown that the sample mean and covariance
+# of the solutions of the different perturbed problems provide a good
+# approximation for the true posterior mean and covariance.
+
+# RTO number of solutions
+nreals = 1000
+
+xrto = []
+for ireal in range(nreals):
+ yreal = yn + Rop * np.random.normal(0, sigmad, nt)
+ xrto.append(
+ x0
+ + Cm_op
+ * Rop.H
+ * (lsqr(Rop * Cm_op * Rop.H + Cd_op, yreal - Rop * x0, iter_lim=400))[0]
+ )
+
+xrto = np.array(xrto)
+xpost_rto = np.average(xrto, axis=0)
+Cmpost_rto = ((xrto - xpost_rto).T @ (xrto - xpost_rto)) / nreals
+
+###############################################################################
+# Finally we visualize the different results
+
+# Means
fig, ax = plt.subplots(1, 1, figsize=(12, 5))
ax.plot(t, x, "k", lw=6, label="true")
+ax.plot(t, xpost_ana, "r", lw=7, label="bayesian inverse (ana)")
+ax.plot(t, xpost_iter, "g", lw=5, label="bayesian inverse (iter)")
+ax.plot(t, xpost_rto, "b", lw=3, label="bayesian inverse (rto)")
ax.plot(t, ymask, ".k", ms=25, label="available samples")
ax.plot(t, ynmask, ".r", ms=25, label="available noisy samples")
-ax.plot(t, xbayes, "r", lw=3, label="bayesian inverse")
ax.legend()
-ax.set_title("Signal")
+ax.set_title("Mean reconstruction")
ax.set_xlim(0, 0.8)
-plt.tight_layout()
-###############################################################################
-# So far we have been able to estimate our posterion mean. What about its
-# uncertainties (i.e., posterion covariance)?
-#
-# In real-life applications it is very difficult (if not impossible)
-# to directly compute the posterior covariance matrix. It is much more
-# useful to create a set of models that sample the posterion probability.
-# We can do that by solving our problem several times using different prior
-# realizations as starting guesses:
-
-xpost = [
- x0
- + Cm_op
- * Rop.H
- * (lsqr(Rop * Cm_op * Rop.H + Cd_op, yn - Rop * x0, iter_lim=400)[0])
- for x0 in xs[:30]
-]
-xpost = np.array(xpost)
-
-x0post = np.average(xpost, axis=0)
-Cm_post = ((xpost - x0post).T @ (xpost - x0post)) / nreals
-
-# Visualize
+# RTO realizations
fig, ax = plt.subplots(1, 1, figsize=(12, 5))
ax.plot(t, x, "k", lw=6, label="true")
-ax.plot(t, xpost.T, "--r", lw=1)
-ax.plot(t, x0post, "r", lw=3, label="bayesian inverse")
+ax.plot(t, xrto[::10].T, "--b", lw=0.5)
+ax.plot(t, xpost_rto, "b", lw=3, label="bayesian inverse (rto)")
ax.plot(t, ymask, ".k", ms=25, label="available samples")
ax.plot(t, ynmask, ".r", ms=25, label="available noisy samples")
ax.legend()
-ax.set_title("Signal")
+ax.set_title("RTO realizations")
ax.set_xlim(0, 0.8)
-fig, ax = plt.subplots(1, 1, figsize=(5, 4))
-im = ax.imshow(
- Cm_post, interpolation="nearest", cmap="seismic", extent=(t[0], t[-1], t[-1], t[0])
+# Covariances
+fig, axs = plt.subplots(1, 2, figsize=(12, 4))
+axs[0].imshow(
+ Cmpost_ana,
+ interpolation="nearest",
+ cmap="seismic",
+ vmin=-5e-1,
+ vmax=2,
+ extent=(t[0], t[-1], t[-1], t[0]),
+)
+axs[0].set_title(r"$\mathbf{C}_m^{post,ANA}$")
+axs[0].axis("tight")
+
+axs[1].imshow(
+ Cmpost_rto,
+ interpolation="nearest",
+ cmap="seismic",
+ vmin=-5e-1,
+ vmax=2,
+ extent=(t[0], t[-1], t[-1], t[0]),
)
-ax.set_title(r"$\mathbf{C}_m^{posterior}$")
-ax.axis("tight")
+axs[1].set_title(r"$\mathbf{C}_m^{post,RTO}$")
+axs[1].axis("tight")
plt.tight_layout()
###############################################################################
# Note that here we have been able to compute a sample posterior covariance
-# from its estimated samples. By displaying it we can see how both the overall
+# from its estimated samples. By displaying it we can see how both the overall
# variances and the correlation between different parameters have become
-# narrower compared to their prior counterparts.
+# narrower compared to their prior counterparts. Moreover, whilst the RTO
+# covariance seems to be slightly under-estimated, this represents an appealing
+# alternative to the closed-form solution for large-scale problems under
+# Gaussian assumptions.
diff --git a/tutorials/ilsm.py b/tutorials/ilsm.py
index b4b016f3..1f394bfc 100755
--- a/tutorials/ilsm.py
+++ b/tutorials/ilsm.py
@@ -1,5 +1,5 @@
r"""
-20. Image Domain Least-squares migration
+19. Image Domain Least-squares migration
========================================
Seismic migration is the process by which seismic data are manipulated to create
an image of the subsurface reflectivity.
diff --git a/tutorials/jaxop.py b/tutorials/jaxop.py
new file mode 100755
index 00000000..c7a30d40
--- /dev/null
+++ b/tutorials/jaxop.py
@@ -0,0 +1,103 @@
+r"""
+21. JAX Operator
+================
+This tutorial is aimed at introducing the :class:`pylops.JaxOperator` operator. This
+represents the entry-point to the JAX backend of PyLops.
+
+More specifically, by wrapping any of PyLops' operators into a
+:class:`pylops.JaxOperator` one can:
+
+- apply forward, adjoint and use any of PyLops solver with JAX arrays;
+- enable automatic differentiation;
+- enable automatic vectorization.
+
+Moreover, both the forward and adjoint are internally just-in-time compiled
+to enable any further optimization provided by JAX.
+
+In this example we will consider a :class:`pylops.MatrixMult` operator and
+showcase how to use it in conjunction with :class:`pylops.JaxOperator`
+to enable the different JAX functionalities mentioned above.
+
+"""
+import jax
+import jax.numpy as jnp
+import matplotlib.pyplot as plt
+import numpy as np
+
+import pylops
+
+plt.close("all")
+np.random.seed(10)
+
+###############################################################################
+# Let's start by creating a :class:`pylops.MatrixMult` operator. We will then
+# perform the dot-test as well as apply the forward and adjoint operations to
+# JAX arrays.
+
+n = 4
+G = np.random.normal(0, 1, (n, n)).astype("float32")
+Gopjax = pylops.JaxOperator(pylops.MatrixMult(jnp.array(G), dtype="float32"))
+
+# dottest
+pylops.utils.dottest(Gopjax, n, n, backend="jax", verb=True, atol=1e-3)
+
+# forward
+xjnp = jnp.ones(n, dtype="float32")
+yjnp = Gopjax @ xjnp
+
+# adjoint
+xadjjnp = Gopjax.H @ yjnp
+
+###############################################################################
+# We can now use one of PyLops solvers to invert the operator
+
+xcgls = pylops.optimization.basic.cgls(
+ Gopjax, yjnp, x0=jnp.zeros(n), niter=100, tol=1e-10, show=True
+)[0]
+print("Inverse: ", xcgls)
+
+###############################################################################
+# Let's see how we can empower the automatic differentiation capabilities
+# of JAX to obtain the adjoint of our operator without having to implement it.
+# Although in PyLops the adjoint of any of operators is hand-written (and
+# optimized), it may be useful in some cases to quickly implement the forward
+# pass of a new operator and get the adjoint for free. This could be extremely
+# beneficial during the prototyping stage of an operator before embarking in
+# implementing an efficient hand-written adjoint.
+
+xadjjnpad = Gopjax.rmatvecad(xjnp, yjnp)
+
+print("Hand-written Adjoint: ", xadjjnp)
+print("AD Adjoint: ", xadjjnpad)
+
+###############################################################################
+# And more in general how we can combine any of JAX native operations with a
+# PyLops operator.
+
+
+def fun(x):
+ y = Gopjax(x)
+ loss = jnp.sum(y)
+ return loss
+
+
+xgrad = jax.grad(fun)(xjnp)
+print("Grad: ", xgrad)
+
+###############################################################################
+# We turn now our attention to automatic vectorization, which is very useful
+# if we want to apply the same operator to multiple vectors. In PyLops we can
+# easily do so by using the ``matmat`` and ``rmatmat`` methods, however under
+# the hood what these methods do is to simply run a for...loop and call the
+# corresponding ``matvec`` / ``rmatvec`` methods multiple times. On the other
+# hand, JAX is able to automatically add a batch axis at the beginning of
+# operator. Moreover, this can be seamlessly combined with `jax.jit` to
+# further improve performance.
+
+auto_batch_matvec = jax.jit(jax.vmap(Gopjax._matvec))
+xs = jnp.stack([xjnp, xjnp])
+ys = auto_batch_matvec(xs)
+
+print("Original output: ", yjnp)
+print("AV Output 1: ", ys[0])
+print("AV Output 1: ", ys[1])
diff --git a/tutorials/torchop.py b/tutorials/torchop.py
index c555573d..9e73d7b3 100755
--- a/tutorials/torchop.py
+++ b/tutorials/torchop.py
@@ -1,6 +1,6 @@
r"""
-19. Automatic Differentiation
-=============================
+20. Torch Operator
+==================
This tutorial focuses on the use of :class:`pylops.TorchOperator` to allow performing
Automatic Differentiation (AD) on chains of operators which can be: