Skip to content

Commit

Permalink
Merge pull request #24048 from jakevdp:packbits-doc
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 681093674
  • Loading branch information
Google-ML-Automation committed Oct 1, 2024
2 parents 23c46ff + ae374e0 commit 9ad7e2e
Showing 1 changed file with 118 additions and 5 deletions.
123 changes: 118 additions & 5 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9889,11 +9889,64 @@ def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array:
return moveaxis(a, axis, start)


@util.implements(np.packbits)
@partial(jit, static_argnames=('axis', 'bitorder'))
def packbits(
a: ArrayLike, axis: int | None = None, bitorder: str = "big"
) -> Array:
def packbits(a: ArrayLike, axis: int | None = None, bitorder: str = "big") -> Array:
"""Pack array of bits into a uint8 array.
JAX implementation of :func:`numpy.packbits`
Args:
a: N-dimensional array of bits to pack.
axis: optional axis along which to pack bits. If not specified, ``a`` will
be flattened.
bitorder: ``"big"`` (default) or ``"little"``: specify whether the bit order
is big-endian or little-endian.
Returns:
A uint8 array of packed values.
See also:
- :func:`jax.numpy.unpackbits`: inverse of ``packbits``.
Examples:
Packing bits in one dimension:
>>> bits = jnp.array([0, 0, 0, 0, 0, 1, 1, 1])
>>> jnp.packbits(bits)
Array([7], dtype=uint8)
>>> 0b00000111 # equivalent bit-wise representation:
7
Optionally specifying little-endian convention:
>>> jnp.packbits(bits, bitorder="little")
Array([224], dtype=uint8)
>>> 0b11100000 # equivalent bit-wise representation
224
If the number of bits is not a multiple of 8, it will be right-padded
with zeros:
>>> jnp.packbits(jnp.array([1, 0, 1]))
Array([160], dtype=uint8)
>>> jnp.packbits(jnp.array([1, 0, 1, 0, 0, 0, 0, 0]))
Array([160], dtype=uint8)
For a multi-dimensional input, bits may be packed along a specified axis:
>>> a = jnp.array([[1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0],
... [0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1]])
>>> vals = jnp.packbits(a, axis=1)
>>> vals
Array([[212, 150],
[ 69, 207]], dtype=uint8)
The inverse of ``packbits`` is provided by :func:`~jax.numpy.unpackbits`:
>>> jnp.unpackbits(vals, axis=1)
Array([[1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0],
[0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1]], dtype=uint8)
"""
util.check_arraylike("packbits", a)
arr = asarray(a)
if not (issubdtype(arr.dtype, integer) or issubdtype(arr.dtype, bool_)):
Expand All @@ -9920,14 +9973,74 @@ def packbits(
return swapaxes(packed, axis, -1)


@util.implements(np.unpackbits)
@partial(jit, static_argnames=('axis', 'count', 'bitorder'))
def unpackbits(
a: ArrayLike,
axis: int | None = None,
count: int | None = None,
bitorder: str = "big",
) -> Array:
"""Unpack the bits in a uint8 array.
JAX implementation of :func:`numpy.unpackbits`.
Args:
a: N-dimensional array of type ``uint8``.
axis: optional axis along which to unpack. If not specified, ``a`` will
be flattened
count: specify the number of bits to unpack (if positive) or the number
of bits to trim from the end (if negative).
bitorder: ``"big"`` (default) or ``"little"``: specify whether the bit order
is big-endian or little-endian.
Returns:
a uint8 array of unpacked bits.
See also:
- :func:`jax.numpy.packbits`: this inverse of ``unpackbits``.
Examples:
Unpacking bits from a scalar:
>>> jnp.unpackbits(jnp.uint8(27)) # big-endian by default
Array([0, 0, 0, 1, 1, 0, 1, 1], dtype=uint8)
>>> jnp.unpackbits(jnp.uint8(27), bitorder="little")
Array([1, 1, 0, 1, 1, 0, 0, 0], dtype=uint8)
Compare this to the Python binary representation:
>>> 0b00011011
27
Unpacking bits along an axis:
>>> vals = jnp.array([[154],
... [ 49]], dtype='uint8')
>>> bits = jnp.unpackbits(vals, axis=1)
>>> bits
Array([[1, 0, 0, 1, 1, 0, 1, 0],
[0, 0, 1, 1, 0, 0, 0, 1]], dtype=uint8)
Using :func:`~jax.numpy.packbits` to invert this:
>>> jnp.packbits(bits, axis=1)
Array([[154],
[ 49]], dtype=uint8)
The ``count`` keyword lets ``unpackbits`` serve as an inverse of ``packbits``
in cases where not all bits are present:
>>> bits = jnp.array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1]) # 11 bits
>>> vals = jnp.packbits(bits)
>>> vals
Array([219, 96], dtype=uint8)
>>> jnp.unpackbits(vals) # 16 zero-padded bits
Array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0], dtype=uint8)
>>> jnp.unpackbits(vals, count=11) # specify 11 output bits
Array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1], dtype=uint8)
>>> jnp.unpackbits(vals, count=-5) # specify 5 bits to be trimmed
Array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1], dtype=uint8)
"""
util.check_arraylike("unpackbits", a)
arr = asarray(a)
if _dtype(a) != uint8:
Expand Down

0 comments on commit 9ad7e2e

Please sign in to comment.