Releases: jax-ml/jax
JAX release v0.4.33
This is a patch release on top of jax 0.4.32 that fixes two bugs found in that
release.
A TPU-only data corruption bug was found in the version of libtpu pinned by
JAX 0.4.32. The bug manifested only if multiple TPU slices were present in the
same job, for example, when training on multiple v5e slices.
This release fixes that issue by pinning a fixed version of `libtpu-nightly`.
This release also fixes an inaccurate result for F64 tanh on CPU (#23590).
Jaxlib release v0.4.32
WARNING: This release has been yanked from PyPI because of a data corruption bug on TPU if there are multiple TPU slices in the job
JAX release v0.4.32
WARNING: This release has been yanked from PyPI because of a data corruption bug on TPU if there are multiple TPU slices in the job
Jaxlib release v0.4.31
jaxlib-v0.4.31 jaxlib version 0.4.31
JAX release v0.4.31
jax-v0.4.31 jax version 0.4.31
Jaxlib release v0.4.30
jaxlib-v0.4.30 jaxlib version 0.4.30
JAX release v0.4.30
jax-v0.4.30 jax version 0.4.30
Jaxlib release v0.4.29
Bug fixes
- Fixed a bug where XLA sharded some concatenation operations incorrectly,
which manifested as an incorrect output for cumulative reductions (#21403).
- Fixed a bug where XLA:CPU miscompiled certain matmul fusions
(openxla/xla#13301).
- Fixes a compiler crash on GPU (#21396).
Deprecations
- `jax.tree.map(f, None, non-None)` now emits a `DeprecationWarning`, and will
raise an error in a future version of jax. `None` is only a tree-prefix of
itself. To preserve the current behavior, you can ask `jax.tree.map` to treat
`None` as a leaf value by writing:
`jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`.
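The `is_leaf` workaround from the deprecation note can be tried on a small pair of dicts. This is a minimal sketch: `f`, `a`, and `b` are hypothetical placeholders chosen for illustration, with `a` holding `None` exactly where `b` holds an array.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # Hypothetical combining function standing in for `f` in the note.
    return x + y


# `a` has None where `b` has an array: the pattern that now warns.
a = {"w": jnp.ones(3), "b": None}
b = {"w": jnp.full(3, 2.0), "b": jnp.zeros(3)}

# is_leaf is applied to the first tree, so None in `a` becomes a leaf and is
# paired with b["b"]; the lambda then returns None instead of calling f.
out = jax.tree.map(
    lambda x, y: None if x is None else f(x, y),
    a,
    b,
    is_leaf=lambda x: x is None,
)

print(out["b"])  # None
print(out["w"])  # [3. 3. 3.]
```

Because the tree structure is taken from the first argument, marking `None` as a leaf there is what lets the second tree contribute a value at that position.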
JAX v0.4.29
Changes
- We anticipate that this will be the last release of JAX and jaxlib
supporting a monolithic CUDA jaxlib. Future releases will use the CUDA
plugin jaxlib (e.g. `pip install jax[cuda12]`).
- JAX now requires ml_dtypes version 0.4.0 or newer.
- Removed backwards-compatibility support for old usage of the
`jax.experimental.export` API. It is no longer possible to use
`from jax.experimental.export import export`; use
`from jax.experimental import export` instead.
The removed functionality has been deprecated since 0.4.24.
Deprecations
- `jax.sharding.XLACompatibleSharding` is deprecated. Please use
`jax.sharding.Sharding`.
- `jax.experimental.Exported.in_shardings` has been renamed to
`jax.experimental.Exported.in_shardings_hlo`. The same applies to `out_shardings`.
The old names will be removed after 3 months.
- Removed a number of previously-deprecated APIs:
  - from `jax.core`: `non_negative_dim`, `DimSize`, `Shape`
  - from `jax.lax`: `tie_in`
  - from `jax.nn`: `normalize`
  - from `jax.interpreters.xla`: `backend_specific_translations`,
    `translations`, `register_translation`, `xla_destructure`,
    `TranslationRule`, `TranslationContext`, `XlaOp`.
- The `tol` argument of `jax.numpy.linalg.matrix_rank` is being
deprecated and will soon be removed. Use `rtol` instead.
- The `rcond` argument of `jax.numpy.linalg.pinv` is being
deprecated and will soon be removed. Use `rtol` instead.
- The deprecated `jax.config` submodule has been removed. To configure JAX,
use `import jax` and then reference the config object via `jax.config`.
- `jax.random` APIs no longer accept batched keys, although some previously
did so unintentionally. Going forward, we recommend explicit use of
`jax.vmap` in such cases.
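For the `matrix_rank` and `pinv` deprecations, migrating is a matter of passing the cutoff through the `rtol` keyword. A minimal sketch, assuming a small rank-deficient matrix `M` chosen for illustration; the value `1e-5` is arbitrary, and the exact way each function scales `rtol` should be checked in the API docs rather than inferred from this example.

```python
import jax.numpy as jnp

# A rank-deficient 3x3 matrix: the third row is the sum of the first two.
M = jnp.array([[1.0, 0.0, 1.0],
               [0.0, 1.0, 1.0],
               [1.0, 1.0, 2.0]])

# Pass the cutoff as `rtol` instead of the deprecated `tol` (matrix_rank)
# or `rcond` (pinv).
rank = jnp.linalg.matrix_rank(M, rtol=1e-5)
M_pinv = jnp.linalg.pinv(M, rtol=1e-5)

print(int(rank))  # 2
```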
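For the `jax.random` change, the recommended pattern is to write the sampling call for a single key and map it over the batch with `jax.vmap`. A minimal sketch, assuming typed keys from `jax.random.key`; the batch size of 4 and the sample shape `(3,)` are arbitrary illustrative choices.

```python
import jax

# Split one root key into a batch of 4 typed keys (shape (4,)).
keys = jax.random.split(jax.random.key(0), 4)

# Map the single-key API over the batch explicitly with jax.vmap,
# rather than passing the batched `keys` directly to jax.random.normal.
samples = jax.vmap(lambda k: jax.random.normal(k, (3,)))(keys)

print(samples.shape)  # (4, 3)
```

Each row of `samples` is drawn from a different key, so the rows are independent draws.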
New Functionality
- Added `jax.experimental.Exported.in_shardings_jax` to construct
shardings that can be used with the JAX APIs from the HloShardings
that are stored in the `Exported` objects.
jaxlib v0.4.28
Bug fixes
- Fixes a memory corruption bug in the type name of Array and JIT Python
objects in Python 3.10 or earlier.
- Fixed a warning `'+ptx84' is not a recognized feature for this target`
under CUDA 12.4.
- Fixed a slow compilation problem on CPU.
Changes
- The Windows build is now built with Clang instead of MSVC.