Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: jax integration #590

Merged
merged 25 commits into from
Aug 9, 2024
Merged

Feature: jax integration #590

merged 25 commits into from
Aug 9, 2024

Conversation

mrava87
Copy link
Collaborator

@mrava87 mrava87 commented Jul 3, 2024

Motivation

This PR introduces a new backend in PyLops to enable using JAX arrays.

As a by-product of JAX-enabled operators, we inherit JAX features like jit, automatic differentiation, and automatic vectorization.

Highlights

  • Created new JaxOperator
  • Modified most of the operators and methods in LinearOperator to enable JAX integration
  • Cleaned up backend module with new logic to detect whether np,cp, or jnp methods should be used based on the input type
  • Added a new tutorial named jaxop
  • Revamped gpu.rst documentation page

@mrava87 mrava87 mentioned this pull request Jul 3, 2024
mrava87 added 3 commits July 3, 2024 22:11
Since sliding and patching use sliding_window_view that is currently
not available in JAX, we will not support them in the jax backend
for the moment.
Copy link

@VascoSch92 VascoSch92 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey, great job :-)

I left just some nit ;-)

pylops/avo/poststack.py Outdated Show resolved Hide resolved
pylops/jaxoperator.py Outdated Show resolved Hide resolved
pylops/jaxoperator.py Outdated Show resolved Hide resolved
pylops/jaxoperator.py Outdated Show resolved Hide resolved
pylops/signalprocessing/fredholm1.py Outdated Show resolved Hide resolved
@mrava87
Copy link
Collaborator Author

mrava87 commented Jul 4, 2024

Hey, great job :-)

I left just some nit ;-)

Thanks!

@mrava87 mrava87 requested a review from cako July 20, 2024 19:01
@mrava87
Copy link
Collaborator Author

mrava87 commented Jul 21, 2024

@cako you may want to look at this as supporting document https://github.com/PyLops/pylops_notebooks/blob/master/developement-cupy/Timing_CupyJAX.ipynb. It contains timing for most of the methods ported to the Jax backend and a comparison with numpy and cupy

pylops/jaxoperator.py Outdated Show resolved Hide resolved
pylops/jaxoperator.py Outdated Show resolved Hide resolved
cako added 4 commits August 4, 2024 20:53
* Uses positional arguments instead of `n` as int | tuple, which is the correct usage with `np.random.randn`
* Corrects input/output types
Copy link
Collaborator

@cako cako left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mrava87 nicely done! Going to leave it as approved, but please have a look at some of the comments and my commits.

By the way I ran the notebook... seems like Jax is generally slower than CuPy? Am I reading this wrong?

image

@mrava87
Copy link
Collaborator Author

mrava87 commented Aug 9, 2024

@mrava87 nicely done! Going to leave it as approved, but please have a look at some of the comments and my commits.

By the way I ran the notebook... seems like Jax is generally slower than CuPy? Am I reading this wrong?

image

This is also what I see when running this both locally and on colab... my guess/suspicion is that when the operator has a limited number of steps all calling np/cp, cupy is already very well optimized so the jit of jax does not really do much... and for some reason the equivalent jax.numpy methods are apparently slower... in one case, where the operator matvec/rmatvec has a for...loop (NonStationaryConvolve1D), then jax seems to shine...

I read a lot about jax being very optimized for GPUs/TPUs so this was also an exercise to compare it with cupy, but so far what I observe is somehow that cupy is better ;)

@mrava87 mrava87 merged commit 21e590b into PyLops:dev Aug 9, 2024
13 checks passed
@mrava87 mrava87 deleted the feature-jax branch August 9, 2024 21:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants