diff --git a/.gitignore b/.gitignore
index 42f9cc53..0365f10d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -17,6 +17,7 @@ build
 dist
 pylops.egg-info/
 .eggs/
+__pycache__
 
 # setuptools_scm generated
 # pylops/version.py
diff --git a/docs/source/conf.py b/docs/source/conf.py
index c5e6536d..75680cac 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -1,8 +1,10 @@
 # -*- coding: utf-8 -*-
-import sys
-import os
 import datetime
+import os
+import sys
+
 from sphinx_gallery.sorting import ExampleTitleSortKey
+
 from pylops import __version__
 
 # Sphinx needs to be able to import the package to use autodoc and get the version number
@@ -37,6 +39,8 @@
     "matplotlib": ("https://matplotlib.org/", None),
     "pyfftw": ("https://pyfftw.readthedocs.io/en/latest/", None),
     "spgl1": ("https://spgl1.readthedocs.io/en/latest/", None),
+    "pymc": ("https://www.pymc.io/", None),
+    "arviz": ("https://python.arviz.org/en/latest/", None),
 }
 
 # Generate autodoc stubs with summaries from code
@@ -103,9 +107,7 @@
 # These enable substitutions using |variable| in the rst files
 rst_epilog = """
 .. |year| replace:: {year}
-""".format(
-    year=year
-)
+""".format(year=year)
 html_static_path = ["_static"]
 html_last_updated_fmt = "%b %d, %Y"
 html_title = "PyLops"
@@ -122,15 +124,15 @@
 # Theme config
 html_theme = "pydata_sphinx_theme"
 html_theme_options = {
-    "logo_only": True,
-    "display_version": True,
+    # "logo_only": True,
+    # "display_version": True,
     "logo": {
         "image_light": "pylops_b.png",
         "image_dark": "pylops.png",
     }
 }
 html_css_files = [
-    'css/custom.css',
+    "css/custom.css",
 ]
 
 html_context = {
diff --git a/docs/source/gpu.rst b/docs/source/gpu.rst
index ecc87120..a37597e9 100755
--- a/docs/source/gpu.rst
+++ b/docs/source/gpu.rst
@@ -462,5 +462,5 @@ Again, the code is almost unchanged apart from the fact that we now use ``jax``
 
 .. note::
 
-   More examples for the CuPy and JAX backends be found `here `_
-   and `here `_.
\ No newline at end of file
+   More examples for the CuPy and JAX backends can be found `here `__
+   and `here `__.
\ No newline at end of file
diff --git a/examples/plot_bayeslinearregr.py b/examples/plot_bayeslinearregr.py
new file mode 100644
index 00000000..42417eb1
--- /dev/null
+++ b/examples/plot_bayeslinearregr.py
@@ -0,0 +1,235 @@
+r"""
+Bayesian Linear Regression
+==========================
+
+In the :ref:`sphx_glr_gallery_plot_linearregr.py` example, we
+performed linear regression by applying a variety of solvers to the
+:py:class:`pylops.LinearRegression` operator.
+
+In this example, we will apply linear regression the Bayesian way.
+In Bayesian inference, we are not looking for a "best" estimate
+of the linear regression parameters; rather, we are looking for
+all possible parameters and their associated (posterior) probability,
+that is, how likely it is that those are the parameters that generated our data.
+
+To do this, we will leverage the probabilistic programming library
+`PyMC `_.
+
+In the Bayesian formulation, we write the problem in the following manner:
+
+    .. math::
+        y_i \sim N(x_0 + x_1 t_i, \sigma) \qquad \forall i=0,1,\ldots,N-1
+
+where :math:`x_0` is the intercept and :math:`x_1` is the gradient.
+This notation means that the obtained measurements :math:`y_i` are normally distributed around
+mean :math:`x_0 + x_1 t_i` with a standard deviation of :math:`\sigma`.
+We can also express this problem in a matrix form, which makes it clear that we
+can use a PyLops operator to describe this relationship.
+
+    .. math::
+        \mathbf{y} \sim N(\mathbf{A} \mathbf{x}, \sigma)
+
+In this example, we will combine the Bayesian power of PyMC with the linear language of
+PyLops.
+"""
+
+import arviz as az
+import matplotlib.pyplot as plt
+import numpy as np
+import pymc as pm
+
+import pylops
+
+plt.close("all")
+np.random.seed(10)
+
+###############################################################################
+# Define the input parameters: number of samples along the t-axis (``N``),
+# linear regression coefficients (``x``), and standard deviation of noise
+# to be added to data (``sigma``).
+N = 30
+x = np.array([1.0, 0.5])
+sigma = 0.25
+
+###############################################################################
+# Let's create the time axis and initialize the
+# :py:class:`pylops.LinearRegression` operator
+t = np.linspace(0, 1, N)
+LRop = pylops.LinearRegression(t, dtype=t.dtype)
+
+###############################################################################
+# We can then apply the operator in forward mode to compute our data points
+# along the x-axis (``y``). We will also generate some random Gaussian noise
+# and create a noisy version of the data (``yn``).
+y = LRop @ x
+yn = y + np.random.normal(0, sigma, N)
+
+###############################################################################
+# The deterministic solution is to solve the
+# :math:`\mathbf{y} = \mathbf{A} \mathbf{x}` problem in a least-squares sense.
+# Using PyLops, the ``/`` operator solves this problem iteratively (i.e., using
+# :py:func:`scipy.sparse.linalg.lsqr`).
+# In Bayesian terminology, this estimate is known as the maximum likelihood
+# estimate (MLE).
+x_mle = LRop / yn
+noise_mle = np.sqrt(np.sum((yn - LRop @ x_mle) ** 2) / (N - 1))
+
+###############################################################################
+# Alternatively, we may regularize the problem. In this case, to condition
+# the solution towards smaller-magnitude parameters, we can use a regularized
+# least-squares approach. Since the weight is pretty small, we expect the
+# result to be very similar to the one above.
+sigma_prior = 20
+eps = 1 / np.sqrt(2) / sigma_prior
+x_map, *_ = pylops.optimization.basic.lsqr(LRop, yn, damp=eps)
+noise_map = np.sqrt(np.sum((yn - LRop @ x_map) ** 2) / (N - 1))
+
+###############################################################################
+# Let's plot the best-fitting lines for the noise-free and noisy data
+fig, ax = plt.subplots(figsize=(8, 4))
+for est, est_label, c in zip(
+    [x, x_mle, x_map], ["True", "MLE", "MAP"], ["k", "C0", "C1"]
+):
+    ax.plot(
+        np.array([t.min(), t.max()]),
+        np.array([t.min(), t.max()]) * est[1] + est[0],
+        color=c,
+        ls="--" if est_label == "MAP" else "-",
+        lw=4,
+        label=rf"{est_label}: $x_0$ = {est[0]:.2f}, $x_1$ = {est[1]:.2f}",
+    )
+ax.scatter(t, y, c="r", s=70)
+ax.scatter(t, yn, c="g", s=70)
+ax.legend()
+fig.tight_layout()
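The damped ``lsqr`` call above solves a Tikhonov-regularized least-squares problem that, for this small system, also has a closed-form solution. A minimal editorial sketch (not part of the patch) of that cross-check, reusing ``LRop``, ``yn``, ``eps`` and ``x_map`` from the example:

# Sketch: the MAP estimate from damped LSQR should match the closed-form
# Tikhonov solution x = (A^T A + eps^2 I)^{-1} A^T y, with A the dense
# matrix behind LRop.
A = LRop.todense()
x_map_closed = np.linalg.solve(A.T @ A + eps**2 * np.eye(A.shape[1]), A.T @ yn)
print(np.allclose(x_map, x_map_closed, atol=1e-4))  # expected: True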
+
+###############################################################################
+# Let's solve this problem the Bayesian way, which consists in obtaining the
+# posterior probability :math:`p(\mathbf{x}\,|\,\mathbf{y})` via Bayes theorem:
+#
+# .. math::
+#    \underbrace{p(\mathbf{x} \,|\, \mathbf{y})}_{\text{posterior}}
+#    \propto \overbrace{p(\mathbf{y} \,|\, \mathbf{x})}^{\text{likelihood}}\;
+#    \overbrace{p(\mathbf{x})}^{\text{prior}}
+#
+# To do so, we need to define the priors and the likelihood.
+#
+# As hinted above, priors in Bayesian analysis can be interpreted as the
+# probabilistic equivalent to regularization. Finding the maximum a posteriori
+# (MAP) estimate of a least-squares problem with a Gaussian prior on the
+# parameters is equivalent to applying a Tikhonov (L2) regularization to these
+# parameters. A Laplace prior is equivalent to a sparse (L1) regularization.
+# In addition, the weight of the regularization is controlled by the "scale" of
+# the prior distribution; the standard deviation (in the case of a Gaussian)
+# is inversely proportional to the strength of the regularization. So, if we use
+# the same ``sigma_prior`` as above for the standard deviation of our prior
+# distribution, we should obtain the same MAP estimate. In practice, in Bayesian
+# analysis we are not only interested in point estimates like the MAP, but
+# rather in the whole posterior distribution. If you only want the MAP, there
+# are better methods to obtain it, such as the one shown above.
+#
+# In this problem we will use weak, not very informative priors, letting them
+# accept a wide range of probable values. This is equivalent to
+# setting the "weights" to be small, as shown above:
+#
+# .. math::
+#    x_0 \sim N(0, 20)
+#
+#    x_1 \sim N(0, 20)
+#
+#    \sigma \sim \text{HalfCauchy}(10)
+#
+# The (log) likelihood in Bayesian analysis is the equivalent of the cost
+# function in deterministic inverse problems. In this case we have already
+# seen this likelihood:
+#
+# .. math::
+#    p(\mathbf{y}\,|\,\mathbf{x}) \sim N(\mathbf{A}\mathbf{x}, \sigma)
+#
+
+# Construct a PyTensor `Op` which can be used in a PyMC model.
+pytensor_lrop = pylops.PyTensorOperator(LRop)
+dims = pytensor_lrop.dims  # Inherits dims, dimsd and shape from LRop
+
+# Construct the PyMC model
+with pm.Model() as model:
+    y_data = pm.Data("y_data", yn)
+
+    # Define priors
+    sp = pm.HalfCauchy("σ", beta=10)
+    xp = pm.Normal("x", 0, sigma=sigma_prior, shape=dims)
+    mu = pm.Deterministic("mu", pytensor_lrop(xp))
+
+    # Define likelihood
+    likelihood = pm.Normal("y", mu=mu, sigma=sp, observed=y_data)
+
+    # Inference!
+    idata = pm.sample(500, tune=200, chains=2)
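Beyond the trace plot discussed below, the posterior can also be summarized numerically with ArviZ. A minimal editorial sketch (not part of the patch), reusing the ``idata`` object sampled above:

# Sketch: posterior means, standard deviations and credible intervals
# for the sampled latent variables.
print(az.summary(idata, var_names=["x", "σ"], hdi_prob=0.95))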
+
+###############################################################################
+# The plot below is known as the "trace" plot. The left column displays the
+# posterior distributions of all latent variables in the model. The top-left
+# plot has multiple colored posteriors, one for each parameter of the latent
+# vector :math:`\mathbf{x}`. The bottom-left plot displays the posterior of the
+# estimated noise :math:`\sigma`.
+#
+# In these plots there are multiple distributions of the same color and
+# multiple line styles. Each of these represents a "chain". A chain is a single
+# run of a Monte Carlo algorithm. Generally, Monte Carlo methods run several
+# chains to ensure that all regions of the posterior distribution are sampled.
+# These chains are shown in the right-hand plots.
+
+axes = az.plot_trace(idata, figsize=(10, 7), var_names=["~mu"])
+axes[0, 0].axvline(x[0], label="True Intercept", lw=2, color="k")
+axes[0, 0].axvline(x_map[0], label="Intercept MAP", lw=2, color="C0", ls="--")
+axes[0, 0].axvline(x[1], label="True Slope", lw=2, color="darkgray")
+axes[0, 0].axvline(x_map[1], label="Slope MAP", lw=2, color="C1", ls="--")
+axes[0, 1].axhline(x[0], label="True Intercept", lw=2, color="k")
+axes[0, 1].axhline(x_map[0], label="Intercept MAP", lw=2, color="C0", ls="--")
+axes[0, 1].axhline(x[1], label="True Slope", lw=2, color="darkgray")
+axes[0, 1].axhline(x_map[1], label="Slope MAP", lw=2, color="C1", ls="--")
+axes[1, 0].axvline(sigma, label="True Sigma", lw=2, color="k")
+axes[1, 0].axvline(noise_map, label="Sigma MAP", lw=2, color="C0", ls="--")
+axes[1, 1].axhline(sigma, label="True Sigma", lw=2, color="k")
+axes[1, 1].axhline(noise_map, label="Sigma MAP", lw=2, color="C0", ls="--")
+for ax in axes.ravel():
+    ax.legend()
+ax.get_figure().tight_layout()
+
+###############################################################################
+# With this model, we can obtain an uncertainty measurement via the High Density
+# Interval (HDI). To do that, we need to sample the "posterior predictive", that
+# is, the posterior distribution of the data, given the model. This samples the
+# latent vectors from their posteriors (above) and uses the model to construct
+# realizations of the data from those samples. They represent
+# what the model thinks the data should look like, given everything it has
+# already seen.
+
+with model:
+    pm.sample_posterior_predictive(idata, extend_inferencedata=True)
+
+###############################################################################
+# sphinx_gallery_thumbnail_number = 3
+fig, ax = plt.subplots(figsize=(8, 4))
+az.plot_hdi(
+    t,
+    idata.posterior_predictive["y"],
+    fill_kwargs={"label": "95% HDI"},
+    hdi_prob=0.95,
+    ax=ax,
+)
+for est, est_label, c in zip(
+    [x, x_mle, x_map], ["True", "MLE", "MAP"], ["k", "C0", "C1"]
+):
+    ax.plot(
+        np.array([t.min(), t.max()]),
+        np.array([t.min(), t.max()]) * est[1] + est[0],
+        color=c,
+        ls="--" if est_label == "MAP" else "-",
+        lw=4,
+        label=rf"{est_label}: $x_0$ = {est[0]:.2f}, $x_1$ = {est[1]:.2f}",
+    )
+ax.scatter(t, y, c="r", s=70)
+ax.scatter(t, yn, c="g", s=70)
+ax.legend()
+fig.tight_layout()
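The 95% band drawn by ``az.plot_hdi`` above can also be extracted numerically. A small editorial sketch (not part of the patch), reusing ``idata`` from the example; the exact indexing of the returned dataset is an assumption based on current ArviZ behaviour:

# Sketch: highest-density interval of the posterior predictive and its
# width at each sample along t.
hdi_y = az.hdi(idata.posterior_predictive, hdi_prob=0.95)["y"]
print((hdi_y.sel(hdi="higher") - hdi_y.sel(hdi="lower")).values)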
diff --git a/examples/plot_slopeest.py b/examples/plot_slopeest.py
index 89938a4a..22cf9096 100755
--- a/examples/plot_slopeest.py
+++ b/examples/plot_slopeest.py
@@ -14,7 +14,7 @@
 precondition sparsity-promoting inverse problems.
 
 We will show examples of a variety of different settings, including a comparison
-with the original implementation in [1].
+with the original implementation in [1]_.
 
 .. [1] van Vliet, L. J., Verbeek, P. W., "Estimators for orientation and
     anisotropy in digitized images", Journal ASCI Imaging Workshop. 1995.
@@ -145,7 +145,7 @@
 
 ###############################################################################
 # Concentric circles
 # ------------------
-# The original paper by van Vliet and Verbeek [1] has an example with concentric
+# The original paper by van Vliet and Verbeek [1]_ has an example with concentric
 # circles. We recover their original images and compare our implementation with
 # theirs.
@@ -215,7 +215,7 @@ def rgb2gray(rgb):
 ###############################################################################
 # Core samples
 # ------------------
-# The original paper by van Vliet and Verbeek [1] also has an example with images
+# The original paper by van Vliet and Verbeek [1]_ also has an example with images
 # of core samples. Since the original paper does not have a scale with which to
 # plot the angles, we have chosen ours it to match their image as closely as
 # possible.
diff --git a/pylops/__init__.py b/pylops/__init__.py
index 7672fda4..b1ee04c8 100755
--- a/pylops/__init__.py
+++ b/pylops/__init__.py
@@ -48,6 +48,7 @@
 from .config import *
 from .linearoperator import *
 from .torchoperator import *
+from .pytensoroperator import *
 from .jaxoperator import *
 from .basicoperators import *
 from . import (
diff --git a/pylops/basicoperators/matrixmult.py b/pylops/basicoperators/matrixmult.py
index a5f713ab..66b96c70 100644
--- a/pylops/basicoperators/matrixmult.py
+++ b/pylops/basicoperators/matrixmult.py
@@ -79,12 +79,14 @@ def __init__(
         else:
             otherdims = _value_or_sized_to_array(otherdims)
             self.otherdims = np.array(otherdims, dtype=int)
-            dims, dimsd = np.insert(self.otherdims, 0, self.A.shape[1]), np.insert(
-                self.otherdims, 0, self.A.shape[0]
+            dims, dimsd = (
+                np.insert(self.otherdims, 0, self.A.shape[1]),
+                np.insert(self.otherdims, 0, self.A.shape[0]),
+            )
+            self.dimsflatten, self.dimsdflatten = (
+                np.insert([np.prod(self.otherdims)], 0, self.A.shape[1]),
+                np.insert([np.prod(self.otherdims)], 0, self.A.shape[0]),
             )
-            self.dimsflatten, self.dimsdflatten = np.insert(
-                [np.prod(self.otherdims)], 0, self.A.shape[1]
-            ), np.insert([np.prod(self.otherdims)], 0, self.A.shape[0])
             self.reshape = True
 
         explicit = False
@@ -138,7 +140,7 @@ def inv(self) -> NDArray:
         r"""Return the inverse of :math:`\mathbf{A}`.
 
         Returns
-        ----------
+        -------
         Ainv : :obj:`numpy.ndarray`
             Inverse matrix.
 
diff --git a/pylops/basicoperators/regression.py b/pylops/basicoperators/regression.py
index dc51ded5..1160fe9b 100644
--- a/pylops/basicoperators/regression.py
+++ b/pylops/basicoperators/regression.py
@@ -124,7 +124,7 @@ def apply(self, t: npt.ArrayLike, x: NDArray) -> NDArray:
             Regression coefficients
 
         Returns
-        ----------
+        -------
         y : :obj:`numpy.ndarray`
             Values along y-axis
 
diff --git a/pylops/basicoperators/restriction.py b/pylops/basicoperators/restriction.py
index c2fceab7..e27610bf 100644
--- a/pylops/basicoperators/restriction.py
+++ b/pylops/basicoperators/restriction.py
@@ -189,7 +189,7 @@ def mask(self, x: NDArray) -> NDArray:
             Input array (can be either flattened or not)
 
         Returns
-        ----------
+        -------
         y : :obj:`numpy.ma.core.MaskedArray`
             Masked array.
 
diff --git a/pylops/pytensoroperator.py b/pylops/pytensoroperator.py
new file mode 100644
index 00000000..cdd04c04
--- /dev/null
+++ b/pylops/pytensoroperator.py
@@ -0,0 +1,85 @@
+import pylops
+from pylops.utils import deps
+
+pytensor_message = deps.pytensor_import("the pytensor module")
+
+if pytensor_message is not None:
+
+    class PyTensorOperator:
+        """PyTensor Op which applies a PyLops Linear Operator, including gradient support.
+
+        This class "converts" a PyLops `LinearOperator` class into a PyTensor `Op`.
+        This applies the `LinearOperator` in "forward-mode" in `self.perform`, and applies
+        its adjoint when computing the vector-Jacobian product (`self.grad`), as that is
+        the analytically correct gradient for linear operators. This class should pass
+        `pytensor.gradient.verify_grad`.
+
+        Parameters
+        ----------
+        LOp : pylops.LinearOperator
+        """
+
+        def __init__(self, LOp: pylops.LinearOperator) -> None:
+            if not deps.pytensor_enabled:
+                raise NotImplementedError(pytensor_message)
+
+else:
+    import pytensor.tensor as pt
+    from pytensor.graph.basic import Apply
+    from pytensor.graph.op import Op
+
+    class _PyTensorOperatorNoGrad(Op):
+        """PyTensor Op which applies a PyLops Linear Operator, excluding gradient support.
+
+        This class "converts" a PyLops `LinearOperator` class into a PyTensor `Op`.
+        This applies the `LinearOperator` in "forward-mode" in `self.perform`.
+
+        Parameters
+        ----------
+        LOp : pylops.LinearOperator
+        """
+
+        __props__ = ("dims", "dimsd", "shape")
+
+        def __init__(self, LOp: pylops.LinearOperator) -> None:
+            self._LOp = LOp
+            self.dims = self._LOp.dims
+            self.dimsd = self._LOp.dimsd
+            self.shape = self._LOp.shape
+            super().__init__()
+
+        def make_node(self, x) -> Apply:
+            x = pt.as_tensor_variable(x)
+            inputs = [x]
+            outputs = [pt.tensor(dtype=x.type.dtype, shape=self._LOp.dimsd)]
+            return Apply(self, inputs, outputs)
+
+        def perform(
+            self, node: Apply, inputs: list, output_storage: list[list[None]]
+        ) -> None:
+            (x,) = inputs
+            (yt,) = output_storage
+            yt[0] = self._LOp @ x
+
+    class PyTensorOperator(_PyTensorOperatorNoGrad):
+        """PyTensor Op which applies a PyLops Linear Operator, including gradient support.
+
+        This class "converts" a PyLops `LinearOperator` class into a PyTensor `Op`.
+        This applies the `LinearOperator` in "forward-mode" in `self.perform`, and applies
+        its adjoint when computing the vector-Jacobian product (`self.grad`), as that is
+        the analytically correct gradient for linear operators. This class should pass
+        `pytensor.gradient.verify_grad`.
+
+        Parameters
+        ----------
+        LOp : pylops.LinearOperator
+        """
+
+        def __init__(self, LOp: pylops.LinearOperator) -> None:
+            super().__init__(LOp)
+            self._gradient_op = _PyTensorOperatorNoGrad(self._LOp.H)
+
+        def grad(
+            self, inputs: list[pt.TensorVariable], output_grads: list[pt.TensorVariable]
+        ):
+            return [self._gradient_op(output_grads[0])]
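The ``grad`` method above returns the adjoint operator applied to the upstream gradient, which is the correct vector-Jacobian product for a linear map. A small editorial sketch (not part of the patch) that checks this numerically; it assumes ``pytensor`` and a PyLops build containing the new ``PyTensorOperator`` are installed, and the matrix sizes are arbitrary:

import numpy as np
import pytensor
import pytensor.tensor as pt

import pylops

# Wrap a small dense matrix as a PyLops operator and as a PyTensor Op.
rng = np.random.default_rng(0)
A = rng.normal(size=(5, 3))
Aop = pylops.MatrixMult(A)
op = pylops.PyTensorOperator(Aop)

# For y = A x, the gradient of v^T y with respect to x is A^H v,
# i.e. the adjoint operator applied to v.
v = rng.normal(size=5)
xvar = pt.dvector("x")
scalar = pt.dot(pt.as_tensor_variable(v), op(xvar))
grad_fn = pytensor.function([xvar], pytensor.grad(scalar, xvar))
print(np.allclose(grad_fn(np.ones(3)), Aop.H @ v))  # expected: True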
diff --git a/pylops/utils/deps.py b/pylops/utils/deps.py
index ecf69a95..df4a0d9e 100644
--- a/pylops/utils/deps.py
+++ b/pylops/utils/deps.py
@@ -10,6 +10,7 @@
     "spgl1_enabled",
     "sympy_enabled",
     "torch_enabled",
+    "pytensor_enabled",
 ]
 
 import os
@@ -223,6 +224,23 @@ def sympy_import(message: Optional[str] = None) -> str:
     return sympy_message
 
 
+def pytensor_import(message: Optional[str] = None) -> str:
+    if pytensor_enabled:
+        try:
+            import_module("pytensor")  # noqa: F401
+
+            pytensor_message = None
+        except Exception as e:
+            pytensor_message = f"Failed to import pytensor (error:{e})."
+    else:
+        pytensor_message = (
+            f"pytensor package not installed. In order to be able to use "
+            f"{message} run "
+            f'"pip install pytensor" or "conda install -c conda-forge pytensor".'
+        )
+    return pytensor_message
+
+
 # Set package availability booleans
 # cupy and jax: the package is imported to check everything is working correctly,
 # if not the package is disabled. We do this here as these libraries are used as drop-in
@@ -245,3 +263,4 @@ def sympy_import(message: Optional[str] = None) -> str:
 spgl1_enabled = util.find_spec("spgl1") is not None
 sympy_enabled = util.find_spec("sympy") is not None
 torch_enabled = util.find_spec("torch") is not None
+pytensor_enabled = util.find_spec("pytensor") is not None
diff --git a/pylops/waveeqprocessing/marchenko.py b/pylops/waveeqprocessing/marchenko.py
index bf8e3305..d810c70b 100644
--- a/pylops/waveeqprocessing/marchenko.py
+++ b/pylops/waveeqprocessing/marchenko.py
@@ -301,7 +301,7 @@ def apply_onepoint(
         greens: bool = False,
         dottest: bool = False,
         usematmul: bool = False,
-        **kwargs_solver
+        **kwargs_solver,
     ) -> Union[
         Tuple[NDArray, NDArray, NDArray, NDArray, NDArray],
         Tuple[NDArray, NDArray, NDArray, NDArray],
@@ -341,7 +341,7 @@ def apply_onepoint(
             for numpy and cupy `data`, respectively)
 
         Returns
-        ----------
+        -------
         f1_inv_minus : :obj:`numpy.ndarray`
             Inverted upgoing focusing function of size :math:`[n_r \times n_t]`
         f1_inv_plus : :obj:`numpy.ndarray`
@@ -473,7 +473,7 @@ def apply_onepoint(
             Mop,
             d.ravel(),
             x0=self.ncp.zeros(2 * (2 * self.nt - 1) * self.nr, dtype=self.dtype),
-            **kwargs_solver
+            **kwargs_solver,
         )[0]
 
         f1_inv = f1_inv.reshape(2 * self.nt2, self.nr)
@@ -486,8 +486,9 @@ def apply_onepoint(
             # Create Green's functions
             g_inv = Gop * f1_inv_tot.ravel()
             g_inv = g_inv.reshape(2 * self.nt2, self.ns)
-            g_inv_minus, g_inv_plus = -g_inv[: self.nt2].T, np.fliplr(
-                g_inv[self.nt2 :].T
+            g_inv_minus, g_inv_plus = (
+                -g_inv[: self.nt2].T,
+                np.fliplr(g_inv[self.nt2 :].T),
             )
         if rtm and greens:
             return f1_inv_minus, f1_inv_plus, p0_minus, g_inv_minus, g_inv_plus
@@ -507,7 +508,7 @@ def apply_multiplepoints(
         greens: bool = False,
         dottest: bool = False,
         usematmul: bool = False,
-        **kwargs_solver
+        **kwargs_solver,
     ) -> Union[
         Tuple[NDArray, NDArray, NDArray, NDArray, NDArray],
         Tuple[NDArray, NDArray, NDArray, NDArray],
@@ -548,7 +549,7 @@ def apply_multiplepoints(
             for numpy and cupy `data`, respectively)
 
         Returns
-        ----------
+        -------
         f1_inv_minus : :obj:`numpy.ndarray`
             Inverted upgoing focusing function of size :math:`[n_r \times n_{vs} \times n_t]`
         f1_inv_plus : :obj:`numpy.ndarray`
@@ -695,7 +696,7 @@ def apply_multiplepoints(
             x0=self.ncp.zeros(
                 2 * (2 * self.nt - 1) * self.nr * nvs, dtype=self.dtype
             ),
-            **kwargs_solver
+            **kwargs_solver,
         )[0]
 
         f1_inv = f1_inv.reshape(2 * self.nt2, self.nr, nvs)
diff --git a/pytests/test_oneway.py b/pytests/test_oneway.py
index 48f73a9e..1f1b936c 100755
--- a/pytests/test_oneway.py
+++ b/pytests/test_oneway.py
@@ -96,7 +96,7 @@ def test_PhaseShift_3dsignal(par):
 @pytest.mark.parametrize("par", [(par1), (par2), (par1v), (par2v)])
 def test_Deghosting_2dsignal(par, create_data2D):
     """Deghosting of 2d data"""
-    p2d, p2d_minus = create_data2D(1 if par["kind"] is "p" else -1)
+    p2d, p2d_minus = create_data2D(1 if par["kind"] == "p" else -1)
 
     p2d_minus_inv, p2d_plus_inv = Deghosting(
         p2d,
@@ -111,7 +111,7 @@ def test_Deghosting_2dsignal(par, create_data2D):
         npad=0,
         ntaper=0,
         dtype=par["dtype"],
-        **dict(damp=1e-10, iter_lim=60)
+        **dict(damp=1e-10, iter_lim=60),
     )
 
     assert np.linalg.norm(p2d_minus_inv - p2d_minus) / np.linalg.norm(p2d_minus) < 3e-1
diff --git a/pytests/test_pytensoroperator.py b/pytests/test_pytensoroperator.py
new file mode 100755
index 00000000..9a59bc88
--- /dev/null
+++ b/pytests/test_pytensoroperator.py
@@ -0,0 +1,51 @@
+import numpy as np
+import pytensor
+import pytest
+from numpy.testing import assert_array_equal
+
+from pylops import MatrixMult, PyTensorOperator
+
+par1 = {"ny": 11, "nx": 11, "dtype": np.float32}  # square
+par2 = {"ny": 21, "nx": 11, "dtype": np.float32}  # overdetermined
+
+np.random.seed(0)
+rng = np.random.default_rng()
+
+
+@pytest.mark.parametrize("par", [(par1)])
+def test_PyTensorOperator(par):
+    """Verify output and gradient of PyTensor function obtained from a LinearOperator."""
+    Dop = MatrixMult(np.random.normal(0.0, 1.0, (par["ny"], par["nx"])))
+    pytensor_op = PyTensorOperator(Dop)
+
+    # Check gradient
+    inp = np.random.randn(*pytensor_op.dims)
+    pytensor.gradient.verify_grad(pytensor_op, (inp,), rng=rng)
+
+    # Check value
+    x = pytensor.tensor.dvector()
+    f = pytensor.function([x], pytensor_op(x))
+    out = f(inp)
+    assert_array_equal(out, Dop @ inp)
+
+
+@pytest.mark.parametrize("par", [(par1)])
+def test_PyTensorOperator_nd(par):
+    """Verify output and gradient of PyTensor function obtained from a LinearOperator
+    using an ND-array."""
+    otherdims = rng.choice(range(1, 3), size=rng.choice(range(2, 8)))
+    Dop = MatrixMult(
+        np.random.normal(0.0, 1.0, (par["ny"], par["nx"])), otherdims=otherdims
+    )
+    pytensor_op = PyTensorOperator(Dop)
+
+    # Check gradient
+    inp = np.random.randn(*pytensor_op.dims)
+    pytensor.gradient.verify_grad(pytensor_op, (inp,), rng=rng)
+
+    # Check value
+    tensor = pytensor.tensor.TensorType(dtype="float64", shape=(None,) * inp.ndim)
+    x = tensor()
+    f = pytensor.function([x], pytensor_op(x))
+    out = f(inp)
+    assert_array_equal(out, Dop @ inp)
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 2fca40dd..8eb9f87d 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,4 +1,4 @@
-numpy>=1.21.0
+numpy>=1.21.0,<2
 scipy>=1.11.0
 jax
 numba
@@ -28,3 +28,4 @@ isort
 black
 flake8
 mypy
+pytensor
diff --git a/requirements-doc.txt b/requirements-doc.txt
index 2e62dea7..e13ad06e 100644
--- a/requirements-doc.txt
+++ b/requirements-doc.txt
@@ -3,7 +3,7 @@
 # same reason, we force devito==4.8.7 as later versions of devito
 # require numpy>=2.0.0
 numpy>=1.21.0,<2.0.0
-scipy>=1.11.0
+scipy>=1.11.0,<1.13
 jax
 --extra-index-url https://download.pytorch.org/whl/cpu
 torch>=1.2.0
@@ -34,3 +34,5 @@ isort
 black
 flake8
 mypy
+pytensor
+pymc
diff --git a/requirements-torch.txt b/requirements-torch.txt
index f2c3b105..3f94f19f 100644
--- a/requirements-torch.txt
+++ b/requirements-torch.txt
@@ -1,2 +1,2 @@
 --index-url https://download.pytorch.org/whl/cpu
-torch>=1.2.0
+torch>=1.2.0,<2.5
diff --git a/tutorials/torchop.py b/tutorials/torchop.py
index 9e73d7b3..a18ff2ef 100755
--- a/tutorials/torchop.py
+++ b/tutorials/torchop.py
@@ -14,6 +14,7 @@
 modelling operators.
 
 """
+
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
@@ -30,24 +31,24 @@
 # In this example we consider a simple multidimensional functional:
 #
 # .. math::
-#    \mathbf{y} = \mathbf{A} sin(\mathbf{x})
+#    \mathbf{y} = \mathbf{A} \sin(\mathbf{x})
 #
 # and we use AD to compute the gradient with respect to the input vector
 # evaluated at :math:`\mathbf{x}=\mathbf{x}_0` :
-# :math:`\mathbf{g} = d\mathbf{y} / d\mathbf{x} |_{\mathbf{x}=\mathbf{x}_0}`.
+# :math:`\mathbf{g} = \partial\mathbf{y} / \partial\mathbf{x} |_{\mathbf{x}=\mathbf{x}_0}`.
 #
 # Let's start by defining the Jacobian:
 #
 # .. math::
 #    \textbf{J} = \begin{bmatrix}
-#    dy_1 / dx_1 & ... & dy_1 / dx_M \\
-#    ... & ... & ... \\
-#    dy_N / dx_1 & ... & dy_N / dx_M
+#    \frac{\partial y_1}{\partial x_1} & \cdots & \frac{\partial y_1}{\partial x_M} \\
+#    \vdots & \ddots & \vdots \\
+#    \frac{\partial y_N}{\partial x_1} & \cdots & \frac{\partial y_N}{\partial x_M}
 #    \end{bmatrix} = \begin{bmatrix}
-#    a_{11} cos(x_1) & ... & a_{1M} cos(x_M) \\
-#    ... & ... & ... \\
-#    a_{N1} cos(x_1) & ... & a_{NM} cos(x_M)
-#    \end{bmatrix} = \textbf{A} cos(\mathbf{x})
+#    a_{11} \cos(x_1) & \cdots & a_{1M} \cos(x_M) \\
+#    \vdots & \ddots & \vdots \\
+#    a_{N1} \cos(x_1) & \cdots & a_{NM} \cos(x_M)
+#    \end{bmatrix} = \textbf{A} \cos(\mathbf{x})
 #
 # Since both input and output are multidimensional,
 # PyTorch ``backward`` actually computes the product between the transposed