From 8071089271bc93c62f775e826c3fd0d452f7e7d3 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 29 Jan 2024 03:26:45 -0800 Subject: [PATCH] Explicitly convert jax.numpy.meshgrid outputs to list The return type of several jax.numpy APIs will change from list to tuple in an upcoming JAX version, following a similar change in NumPy 2.0. PiperOrigin-RevId: 602334680 --- dm_pix/_src/augment.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dm_pix/_src/augment.py b/dm_pix/_src/augment.py index d749f61..bd44e02 100644 --- a/dm_pix/_src/augment.py +++ b/dm_pix/_src/augment.py @@ -231,8 +231,9 @@ def elastic_deformation( sigma=sigma, kernel_size=kernel_size) * alpha - meshgrid = jnp.meshgrid(*[jnp.arange(size) for size in single_channel_shape], - indexing="ij") + meshgrid = list( + jnp.meshgrid( + *[jnp.arange(size) for size in single_channel_shape], indexing="ij")) meshgrid[0] += shift_map_i meshgrid[1] += shift_map_j