diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 57eb9dbcd46a..3e28d6bbcbba 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -9889,11 +9889,64 @@ def rollaxis(a: ArrayLike, axis: int, start: int = 0) -> Array: return moveaxis(a, axis, start) -@util.implements(np.packbits) @partial(jit, static_argnames=('axis', 'bitorder')) -def packbits( - a: ArrayLike, axis: int | None = None, bitorder: str = "big" -) -> Array: +def packbits(a: ArrayLike, axis: int | None = None, bitorder: str = "big") -> Array: + """Pack array of bits into a uint8 array. + + JAX implementation of :func:`numpy.packbits` + + Args: + a: N-dimensional array of bits to pack. + axis: optional axis along which to pack bits. If not specified, ``a`` will + be flattened. + bitorder: ``"big"`` (default) or ``"little"``: specify whether the bit order + is big-endian or little-endian. + + Returns: + A uint8 array of packed values. + + See also: + - :func:`jax.numpy.unpackbits`: inverse of ``packbits``. + + Examples: + Packing bits in one dimension: + + >>> bits = jnp.array([0, 0, 0, 0, 0, 1, 1, 1]) + >>> jnp.packbits(bits) + Array([7], dtype=uint8) + >>> 0b00000111 # equivalent bit-wise representation: + 7 + + Optionally specifying little-endian convention: + + >>> jnp.packbits(bits, bitorder="little") + Array([224], dtype=uint8) + >>> 0b11100000 # equivalent bit-wise representation + 224 + + If the number of bits is not a multiple of 8, it will be right-padded + with zeros: + + >>> jnp.packbits(jnp.array([1, 0, 1])) + Array([160], dtype=uint8) + >>> jnp.packbits(jnp.array([1, 0, 1, 0, 0, 0, 0, 0])) + Array([160], dtype=uint8) + + For a multi-dimensional input, bits may be packed along a specified axis: + + >>> a = jnp.array([[1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0], + ... [0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1]]) + >>> vals = jnp.packbits(a, axis=1) + >>> vals + Array([[212, 150], + [ 69, 207]], dtype=uint8) + + The inverse of ``packbits`` is provided by :func:`~jax.numpy.unpackbits`: + + >>> jnp.unpackbits(vals, axis=1) + Array([[1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0], + [0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1]], dtype=uint8) + """ util.check_arraylike("packbits", a) arr = asarray(a) if not (issubdtype(arr.dtype, integer) or issubdtype(arr.dtype, bool_)): @@ -9920,7 +9973,6 @@ def packbits( return swapaxes(packed, axis, -1) -@util.implements(np.unpackbits) @partial(jit, static_argnames=('axis', 'count', 'bitorder')) def unpackbits( a: ArrayLike, @@ -9928,6 +9980,67 @@ def unpackbits( count: int | None = None, bitorder: str = "big", ) -> Array: + """Unpack the bits in a uint8 array. + + JAX implementation of :func:`numpy.unpackbits`. + + Args: + a: N-dimensional array of type ``uint8``. + axis: optional axis along which to unpack. If not specified, ``a`` will + be flattened + count: specify the number of bits to unpack (if positive) or the number + of bits to trim from the end (if negative). + bitorder: ``"big"`` (default) or ``"little"``: specify whether the bit order + is big-endian or little-endian. + + Returns: + a uint8 array of unpacked bits. + + See also: + - :func:`jax.numpy.packbits`: this inverse of ``unpackbits``. + + Examples: + Unpacking bits from a scalar: + + >>> jnp.unpackbits(jnp.uint8(27)) # big-endian by default + Array([0, 0, 0, 1, 1, 0, 1, 1], dtype=uint8) + >>> jnp.unpackbits(jnp.uint8(27), bitorder="little") + Array([1, 1, 0, 1, 1, 0, 0, 0], dtype=uint8) + + Compare this to the Python binary representation: + + >>> 0b00011011 + 27 + + Unpacking bits along an axis: + + >>> vals = jnp.array([[154], + ... [ 49]], dtype='uint8') + >>> bits = jnp.unpackbits(vals, axis=1) + >>> bits + Array([[1, 0, 0, 1, 1, 0, 1, 0], + [0, 0, 1, 1, 0, 0, 0, 1]], dtype=uint8) + + Using :func:`~jax.numpy.packbits` to invert this: + + >>> jnp.packbits(bits, axis=1) + Array([[154], + [ 49]], dtype=uint8) + + The ``count`` keyword lets ``unpackbits`` serve as an inverse of ``packbits`` + in cases where not all bits are present: + + >>> bits = jnp.array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1]) # 11 bits + >>> vals = jnp.packbits(bits) + >>> vals + Array([219, 96], dtype=uint8) + >>> jnp.unpackbits(vals) # 16 zero-padded bits + Array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0], dtype=uint8) + >>> jnp.unpackbits(vals, count=11) # specify 11 output bits + Array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1], dtype=uint8) + >>> jnp.unpackbits(vals, count=-5) # specify 5 bits to be trimmed + Array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1], dtype=uint8) + """ util.check_arraylike("unpackbits", a) arr = asarray(a) if _dtype(a) != uint8: