Skip to content

Commit

Permalink
Finalize deprecation of XLACompatibleSharding
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 681156145
  • Loading branch information
Jake VanderPlas authored and Google-ML-Automation committed Oct 1, 2024
1 parent 6ded1be commit 49ad220
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 12 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
is only a tree-prefix of itself. To preserve the current behavior, you can
ask `jax.tree.map` to treat `None` as a leaf value by writing:
`jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`.
* `jax.sharding.XLACompatibleSharding` has been removed. Please use
`jax.sharding.Sharding`.

* Bug fixes
* Fixed a bug where {func}`jax.numpy.cumsum` would produce incorrect outputs
Expand Down
19 changes: 7 additions & 12 deletions jax/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,16 @@
from jax._src.mesh import AbstractMesh

_deprecations = {
# Added Jun 4, 2024.
# Finalized 2024-10-01; remove after 2025-01-01.
"XLACompatibleSharding": (
(
"jax.sharding.XLACompatibleSharding is deprecated. Use"
" jax.sharding.Sharding instead."
"jax.sharding.XLACompatibleSharding was removed in JAX v0.4.34. "
"Use jax.sharding.Sharding instead."
),
_deprecated_XLACompatibleSharding,
None,
)
}

import typing
if typing.TYPE_CHECKING:
XLACompatibleSharding = _deprecated_XLACompatibleSharding
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr

0 comments on commit 49ad220

Please sign in to comment.