diff --git a/README.md b/README.md index 4383124f5..07a10d55e 100644 --- a/README.md +++ b/README.md @@ -154,7 +154,7 @@ def update_rule(param, update): for images, labels in input_dataset: grads = jax.grad(loss_fn_t.apply)(params, images, labels) - params = jax.tree_util.tree_map(update_rule, params, grads) + params = jax.tree.map(update_rule, params, grads) ``` The core of Haiku is `hk.transform`. The `transform` function allows you to @@ -218,9 +218,9 @@ grads = jax.grad(loss_fn_t.apply)(params, images, labels) ### Training The training loop in this example is very simple. One detail to note is the use -of `jax.tree_util.tree_map` to apply the `sgd` function across all matching -entries in `params` and `grads`. The result has the same structure as the -previous `params` and can again be used with `apply`. +of `jax.tree.map` to apply the `sgd` function across all matching entries in +`params` and `grads`. The result has the same structure as the previous `params` +and can again be used with `apply`. ## Installation @@ -412,7 +412,7 @@ params = loss_fn_t.init(rng, sample_image, sample_label) # Replicate params onto all devices. num_devices = jax.local_device_count() -params = jax.tree_util.tree_map(lambda x: np.stack([x] * num_devices), params) +params = jax.tree.map(lambda x: np.stack([x] * num_devices), params) def make_superbatch(): """Constructs a superbatch, i.e. one batch of data per device.""" diff --git a/docs/notebooks/basics.ipynb b/docs/notebooks/basics.ipynb index 718c59378..33dcdcd85 100644 --- a/docs/notebooks/basics.ipynb +++ b/docs/notebooks/basics.ipynb @@ -339,7 +339,7 @@ } ], "source": [ - "mutated_params = jax.tree_util.tree_map(lambda x: x+1., params)\n", + "mutated_params = jax.tree.map(lambda x: x+1., params)\n", "print(f'Mutated params \\n : {mutated_params}')\n", "mutated_output = forward_without_rng.apply(x=sample_x, params=mutated_params)\n", "print(f'Output with mutated params \\n {mutated_output}')" diff --git a/docs/notebooks/build_your_own_haiku.ipynb b/docs/notebooks/build_your_own_haiku.ipynb index 3cf7dbd86..6183c4874 100644 --- a/docs/notebooks/build_your_own_haiku.ipynb +++ b/docs/notebooks/build_your_own_haiku.ipynb @@ -365,7 +365,7 @@ "source": [ "# Make printing parameters a little more readable\n", "def parameter_shapes(params):\n", - " return jax.tree_util.tree_map(lambda p: p.shape, params)\n", + " return jax.tree.map(lambda p: p.shape, params)\n", "\n", "\n", "class Linear:\n", @@ -763,7 +763,7 @@ "@jax.jit\n", "def update(params, x, y):\n", " grads = jax.grad(loss_fn)(params, x, y)\n", - " return jax.tree_util.tree_map(\n", + " return jax.tree.map(\n", " lambda p, g: p - LEARNING_RATE * g, params, grads\n", " )" ] diff --git a/docs/notebooks/non_trainable.ipynb b/docs/notebooks/non_trainable.ipynb index 5708f1985..28f217f57 100644 --- a/docs/notebooks/non_trainable.ipynb +++ b/docs/notebooks/non_trainable.ipynb @@ -268,7 +268,7 @@ " return -jnp.sum(labels * jax.nn.log_softmax(logits)) / labels.shape[0]\n", "\n", "def sgd_step(params, grads, *, lr):\n", - " return jax.tree_util.tree_map(lambda p, g: p - g * lr, params, grads)\n", + " return jax.tree.map(lambda p, g: p - g * lr, params, grads)\n", "\n", "def train_step(trainable_params, non_trainable_params, x, y):\n", " # NOTE: We will only compute gradients wrt `trainable_params`.\n", diff --git a/docs/notebooks/parameter_sharing.ipynb b/docs/notebooks/parameter_sharing.ipynb index c2e6e8837..dbf75edb3 100644 --- a/docs/notebooks/parameter_sharing.ipynb +++ 
b/docs/notebooks/parameter_sharing.ipynb @@ -66,7 +66,7 @@ "\n", "def parameter_shapes(params):\n", " \"\"\"Make printing parameters a little more readable.\"\"\"\n", - " return jax.tree_util.tree_map(lambda p: p.shape, params)\n", + " return jax.tree.map(lambda p: p.shape, params)\n", "\n", "\n", "def transform_and_print_shapes(fn, x_shape=(2, 3)):\n", @@ -80,8 +80,8 @@ " print(parameter_shapes(params))\n", "\n", "def assert_all_equal(params_1, params_2):\n", - " assert all(jax.tree_util.tree_leaves(\n", - " jax.tree_util.tree_map(lambda a, b: (a == b).all(), params_1, params_2)))" + " assert all(jax.tree.leaves(\n", + " jax.tree.map(lambda a, b: (a == b).all(), params_1, params_2)))" ] }, { diff --git a/docs/notebooks/transforms.ipynb b/docs/notebooks/transforms.ipynb index d7a9e6f2b..d55a48e31 100644 --- a/docs/notebooks/transforms.ipynb +++ b/docs/notebooks/transforms.ipynb @@ -236,7 +236,7 @@ "init, apply = hk.transform(eval_shape_net) \n", "params = init(jax.random.PRNGKey(777), jnp.ones((100, 100)))\n", "apply(params, jax.random.PRNGKey(777), jnp.ones((100, 100)))\n", - "jax.tree_util.tree_map(lambda x: x.shape, params)" + "jax.tree.map(lambda x: x.shape, params)" ] }, { diff --git a/examples/imagenet/dataset.py b/examples/imagenet/dataset.py index 104200b0c..67e824bdf 100644 --- a/examples/imagenet/dataset.py +++ b/examples/imagenet/dataset.py @@ -154,11 +154,12 @@ def cast_fn(batch): def _device_put_sharded(sharded_tree, devices): - leaves, treedef = jax.tree_util.tree_flatten(sharded_tree) + leaves, treedef = jax.tree.flatten(sharded_tree) n = leaves[0].shape[0] - return jax.device_put_sharded([ - jax.tree_util.tree_unflatten(treedef, [l[i] for l in leaves]) - for i in range(n)], devices) + return jax.device_put_sharded( + [jax.tree.unflatten(treedef, [l[i] for l in leaves]) for i in range(n)], + devices, + ) def double_buffer(ds: Iterable[Batch]) -> Iterator[Batch]: diff --git a/examples/imagenet/train.py b/examples/imagenet/train.py index 5d3fac886..71ff0fd3f 100644 --- a/examples/imagenet/train.py +++ b/examples/imagenet/train.py @@ -227,7 +227,7 @@ def evaluate( # Params/state are sharded per-device during training. We just need the copy # from the first device (since we do not pmap evaluation at the moment). - params, state = jax.tree_util.tree_map(lambda x: x[0], (params, state)) + params, state = jax.tree.map(lambda x: x[0], (params, state)) test_dataset = dataset.load(split, is_training=False, batch_dims=[FLAGS.eval_batch_size], @@ -323,8 +323,9 @@ def main(argv): # Log progress at fixed intervals. if step_num and step_num % log_every == 0: - train_scalars = jax.tree_util.tree_map( - lambda v: np.mean(v).item(), jax.device_get(train_scalars)) + train_scalars = jax.tree.map( + lambda v: np.mean(v).item(), jax.device_get(train_scalars) + ) logging.info('[Train %s/%s] %s', step_num, num_train_steps, train_scalars) diff --git a/examples/impala/actor.py b/examples/impala/actor.py index 3db086099..5755810fd 100644 --- a/examples/impala/actor.py +++ b/examples/impala/actor.py @@ -84,7 +84,7 @@ def unroll(self, rng_key, frame_count: int, params: hk.Params, # Pack the trajectory and reset parent state. trajectory = jax.device_get(self._traj) - trajectory = jax.tree_util.tree_map(lambda *xs: np.stack(xs), *trajectory) + trajectory = jax.tree.map(lambda *xs: np.stack(xs), *trajectory) self._timestep = timestep self._agent_state = agent_state # Keep the bootstrap timestep for next trajectory. 
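The `dataset.py` and `actor.py` hunks above rely on two pytree idioms that the rename leaves untouched: the multi-tree form of `map` (one leaf from each input tree per call) and the `flatten`/`unflatten` round trip used to rebuild one tree per device. Below is a minimal sketch of both, using a toy dict pytree rather than the real `dm_env` timesteps, and assuming a JAX version that already exposes the `jax.tree` namespace:

```python
import jax
import numpy as np

# Two "trajectory steps" with identical structure; plain dicts stand in for
# the dm_env timesteps / agent outputs used by the IMPALA actor.
step_a = {"obs": np.zeros([4]), "reward": np.array(0.0)}
step_b = {"obs": np.ones([4]), "reward": np.array(1.0)}

# Multi-tree map: the lambda receives the corresponding leaf from every input
# tree, so stacking adds a leading [T, ...] axis leaf by leaf.
trajectory = jax.tree.map(lambda *xs: np.stack(xs), step_a, step_b)
assert trajectory["obs"].shape == (2, 4)

# flatten/unflatten round trip, as in _device_put_sharded: slice a per-index
# shard out of every leaf and rebuild one tree per index.
leaves, treedef = jax.tree.flatten(trajectory)
per_step = [jax.tree.unflatten(treedef, [l[i] for l in leaves]) for i in range(2)]
assert float(per_step[1]["reward"]) == 1.0
```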
diff --git a/examples/impala/agent.py b/examples/impala/agent.py index a406fa7ab..59bafda6d 100644 --- a/examples/impala/agent.py +++ b/examples/impala/agent.py @@ -61,11 +61,11 @@ def __init__(self, num_actions: int, obs_spec: Nest, @functools.partial(jax.jit, static_argnums=0) def initial_params(self, rng_key): """Initializes the agent params given the RNG key.""" - dummy_inputs = jax.tree_util.tree_map( - lambda t: np.zeros(t.shape, t.dtype), self._obs_spec) + dummy_inputs = jax.tree.map( + lambda t: np.zeros(t.shape, t.dtype), self._obs_spec + ) dummy_inputs = util.preprocess_step(dm_env.restart(dummy_inputs)) - dummy_inputs = jax.tree_util.tree_map( - lambda t: t[None, None, ...], dummy_inputs) + dummy_inputs = jax.tree.map(lambda t: t[None, None, ...], dummy_inputs) return self._init_fn(rng_key, dummy_inputs, self.initial_state(1)) @functools.partial(jax.jit, static_argnums=(0, 1)) @@ -84,15 +84,13 @@ def step( ) -> tuple[AgentOutput, Nest]: """For a given single-step, unbatched timestep, output the chosen action.""" # Pad timestep, state to be [T, B, ...] and [B, ...] respectively. - timestep = jax.tree_util.tree_map(lambda t: t[None, None, ...], timestep) - state = jax.tree_util.tree_map(lambda t: t[None, ...], state) + timestep = jax.tree.map(lambda t: t[None, None, ...], timestep) + state = jax.tree.map(lambda t: t[None, ...], state) net_out, next_state = self._apply_fn(params, timestep, state) # Remove the padding from above. - net_out = jax.tree_util.tree_map( - lambda t: jnp.squeeze(t, axis=(0, 1)), net_out) - next_state = jax.tree_util.tree_map( - lambda t: jnp.squeeze(t, axis=0), next_state) + net_out = jax.tree.map(lambda t: jnp.squeeze(t, axis=(0, 1)), net_out) + next_state = jax.tree.map(lambda t: jnp.squeeze(t, axis=0), next_state) # Sample an action and return. action = hk.multinomial(rng_key, net_out.policy_logits, num_samples=1) action = jnp.squeeze(action, axis=-1) diff --git a/examples/impala/haiku_nets.py b/examples/impala/haiku_nets.py index ca8119dd4..899be1a34 100644 --- a/examples/impala/haiku_nets.py +++ b/examples/impala/haiku_nets.py @@ -144,7 +144,7 @@ def initial_state(self, batch_size): return self._core.initial_state(batch_size) def __call__(self, x: dm_env.TimeStep, state): - x = jax.tree_util.tree_map(lambda t: t[None, ...], x) + x = jax.tree.map(lambda t: t[None, ...], x) return self.unroll(x, state) def unroll(self, x, state): diff --git a/examples/impala/learner.py b/examples/impala/learner.py index e3a42c9dc..6be1330e2 100644 --- a/examples/impala/learner.py +++ b/examples/impala/learner.py @@ -93,18 +93,17 @@ def _loss( trajectories: util.Transition, ) -> tuple[jax.Array, dict[str, jax.Array]]: """Compute vtrace-based actor-critic loss.""" - initial_state = jax.tree_util.tree_map( - lambda t: t[0], trajectories.agent_state) + initial_state = jax.tree.map(lambda t: t[0], trajectories.agent_state) learner_outputs = self._agent.unroll(theta, trajectories.timestep, initial_state) v_t = learner_outputs.values[1:] # Remove bootstrap timestep from non-timesteps. - _, actor_out, _ = jax.tree_util.tree_map(lambda t: t[:-1], trajectories) - learner_outputs = jax.tree_util.tree_map(lambda t: t[:-1], learner_outputs) + _, actor_out, _ = jax.tree.map(lambda t: t[:-1], trajectories) + learner_outputs = jax.tree.map(lambda t: t[:-1], learner_outputs) v_tm1 = learner_outputs.values # Get the discount, reward, step_type from the *next* timestep. 
- timestep = jax.tree_util.tree_map(lambda t: t[1:], trajectories.timestep) + timestep = jax.tree.map(lambda t: t[1:], trajectories.timestep) discounts = timestep.discount * self._discount_factor rewards = timestep.reward if self._max_abs_reward > 0: @@ -179,8 +178,7 @@ def host_to_device_worker(self): assert len(batch) == self._batch_size # Prepare for consumption, then put batch onto device. - stacked_batch = jax.tree_util.tree_map( - lambda *xs: np.stack(xs, axis=1), *batch) + stacked_batch = jax.tree.map(lambda *xs: np.stack(xs, axis=1), *batch) self._device_q.put(jax.device_put(stacked_batch)) # Clean out the built-up batch. diff --git a/examples/impala_lite.py b/examples/impala_lite.py index 9ceb89c6a..2d59fd6f0 100644 --- a/examples/impala_lite.py +++ b/examples/impala_lite.py @@ -80,7 +80,7 @@ def step( timestep: dm_env.TimeStep, ) -> tuple[jax.Array, jax.Array]: """Steps on a single observation.""" - timestep = jax.tree_util.tree_map(lambda t: jnp.expand_dims(t, 0), timestep) + timestep = jax.tree.map(lambda t: jnp.expand_dims(t, 0), timestep) logits, _ = self._net(params, timestep) logits = jnp.squeeze(logits, axis=0) action = hk.multinomial(rng, logits, num_samples=1) @@ -102,13 +102,12 @@ def loss(self, params: hk.Params, trajs: Transition) -> jax.Array: baseline_tp1 = baseline_with_bootstrap[1:] # Remove bootstrap timestep from non-observations. - _, actions, behavior_logits = jax.tree_util.tree_map( - lambda t: t[:-1], trajs) + _, actions, behavior_logits = jax.tree.map(lambda t: t[:-1], trajs) learner_logits = learner_logits[:-1] # Shift step_type/reward/discount back by one, so that actions match the # timesteps caused by the action. - timestep = jax.tree_util.tree_map(lambda t: t[1:], trajs.timestep) + timestep = jax.tree.map(lambda t: t[1:], trajs.timestep) discount = timestep.discount * self._discount # The step is uninteresting if we transitioned LAST -> FIRST. mask = jnp.not_equal(timestep.step_type, int(dm_env.StepType.FIRST)) @@ -149,7 +148,7 @@ def preprocess_step(ts: dm_env.TimeStep) -> dm_env.TimeStep: ts = ts._replace(reward=0.) if ts.discount is None: ts = ts._replace(discount=1.) - return jax.tree_util.tree_map(np.asarray, ts) + return jax.tree.map(np.asarray, ts) def run_actor( @@ -180,7 +179,7 @@ def run_actor( state.reward) # Stack and send the trajectory. - stacked_traj = jax.tree_util.tree_map(lambda *ts: np.stack(ts), *traj) + stacked_traj = jax.tree.map(lambda *ts: np.stack(ts), *traj) enqueue_traj(stacked_traj) # Reset the trajectory, keeping the last timestep. traj = traj[-1:] @@ -222,8 +221,7 @@ def run(*, trajectories_per_actor, num_actors, unroll_len): # Initialize the optimizer state. sample_ts = env.reset() sample_ts = preprocess_step(sample_ts) - ts_with_batch = jax.tree_util.tree_map( - lambda t: np.expand_dims(t, 0), sample_ts) + ts_with_batch = jax.tree.map(lambda t: np.expand_dims(t, 0), sample_ts) params = jax.jit(net.init)(jax.random.PRNGKey(428), ts_with_batch) opt_state = opt.init(params) @@ -236,7 +234,7 @@ def dequeue(): batch = [] for _ in range(batch_size): batch.append(q.get()) - batch = jax.tree_util.tree_map(lambda *ts: np.stack(ts, axis=1), *batch) + batch = jax.tree.map(lambda *ts: np.stack(ts, axis=1), *batch) return jax.device_put(batch) # Start the actors. 
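Since the IMPALA examples above are migrated purely by renaming the call site (plus formatter reflow), it may help to note that recent JAX releases document `jax.tree.map` and `jax.tree.leaves` as aliases of their `jax.tree_util` counterparts. A small equivalence check follows, with a toy parameter dict standing in for Haiku's real `params`; treat the exact JAX version that introduced the `jax.tree` namespace as an assumption:

```python
import jax
import jax.numpy as jnp

# Toy parameters and gradients; not Haiku's actual parameter structures.
params = {"linear": {"w": jnp.ones([3, 2]), "b": jnp.zeros([2])}}
grads = jax.tree.map(lambda p: 0.1 * jnp.ones_like(p), params)

# The SGD-style update from the hunks above, written with both spellings.
new_alias = jax.tree.map(lambda p, g: p - 0.01 * g, params, grads)
new_util = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

# Leaf by leaf the results are identical, so the rename is behavior-preserving.
assert all(
    bool(jnp.array_equal(a, b))
    for a, b in zip(jax.tree.leaves(new_alias), jax.tree.leaves(new_util))
)
```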
diff --git a/examples/mnist.py b/examples/mnist.py index 0df34d0f9..29b1cf2a2 100644 --- a/examples/mnist.py +++ b/examples/mnist.py @@ -80,7 +80,8 @@ def loss(params: hk.Params, batch: Batch) -> jax.Array: labels = jax.nn.one_hot(batch.label, NUM_CLASSES) l2_regulariser = 0.5 * sum( - jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(params)) + jnp.sum(jnp.square(p)) for p in jax.tree.leaves(params) + ) log_likelihood = jnp.sum(labels * jax.nn.log_softmax(logits)) return -log_likelihood / batch_size + 1e-4 * l2_regulariser diff --git a/examples/mnist_gan.ipynb b/examples/mnist_gan.ipynb index aceb83c86..2ec949ef9 100644 --- a/examples/mnist_gan.ipynb +++ b/examples/mnist_gan.ipynb @@ -203,7 +203,7 @@ "\n", "\n", "def tree_shape(xs):\n", - " return jax.tree_util.tree_map(lambda x: x.shape, xs)\n", + " return jax.tree.map(lambda x: x.shape, xs)\n", "\n", "\n", "def sparse_softmax_cross_entropy(logits, labels):\n", diff --git a/examples/mnist_pruning.py b/examples/mnist_pruning.py index 2bb95e5b4..e07ae469f 100644 --- a/examples/mnist_pruning.py +++ b/examples/mnist_pruning.py @@ -137,8 +137,7 @@ def apply_mask(params: hk.Params, masks: Sequence[hk.Params], module_sparsity, params) pruned_params = [] for value, mask in zip(params_to_prune, masks): - pruned_params.append( - jax.tree_util.tree_map(lambda x, y: x * y, value, mask)) + pruned_params.append(jax.tree.map(lambda x, y: x * y, value, mask)) params = hk.data_structures.merge(*pruned_params, params_no_prune) return params @@ -155,7 +154,7 @@ def map_fn(x: jax.Array, sparsity: float) -> jax.Array: for tree, sparsity in zip(params_to_prune, sparsities): map_fn_sparsity = functools.partial(map_fn, sparsity=sparsity) - mask = jax.tree_util.tree_map(map_fn_sparsity, tree) + mask = jax.tree.map(map_fn_sparsity, tree) masks.append(mask) return masks @@ -163,10 +162,9 @@ def map_fn(x: jax.Array, sparsity: float) -> jax.Array: @jax.jit def get_sparsity(params: hk.Params): """Calculate the total sparsity and tensor-wise sparsity of params.""" - total_params = sum(jnp.size(x) for x in jax.tree_util.tree_leaves(params)) - total_nnz = sum(jnp.sum(x != 0.) 
for x in jax.tree_util.tree_leaves(params)) - leaf_sparsity = jax.tree_util.tree_map( - lambda x: jnp.sum(x == 0) / jnp.size(x), params) + total_params = sum(jnp.size(x) for x in jax.tree.leaves(params)) + total_nnz = sum(jnp.sum(x != 0.0) for x in jax.tree.leaves(params)) + leaf_sparsity = jax.tree.map(lambda x: jnp.sum(x == 0) / jnp.size(x), params) return total_params, total_nnz, leaf_sparsity @@ -221,8 +219,7 @@ def loss(params: hk.Params, batch: Batch) -> jax.Array: logits = net.apply(params, batch) labels = jax.nn.one_hot(batch["label"], 10) - l2_loss = 0.5 * sum( - jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(params)) + l2_loss = 0.5 * sum(jnp.sum(jnp.square(p)) for p in jax.tree.leaves(params)) softmax_xent = -jnp.sum(labels * jax.nn.log_softmax(logits)) softmax_xent /= labels.shape[0] diff --git a/examples/vqvae_example.ipynb b/examples/vqvae_example.ipynb index c93a017a4..ff465bf76 100644 --- a/examples/vqvae_example.ipynb +++ b/examples/vqvae_example.ipynb @@ -77,7 +77,7 @@ "source": [ "cifar10 = tfds.as_numpy(tfds.load(\"cifar10\", split=\"train+test\", batch_size=-1))\n", "del cifar10[\"id\"], cifar10[\"label\"]\n", - "jax.tree_util.tree_map(lambda x: f'{x.dtype.name}{list(x.shape)}', cifar10)" + "jax.tree.map(lambda x: f'{x.dtype.name}{list(x.shape)}', cifar10)" ] }, { @@ -101,9 +101,9 @@ }, "outputs": [], "source": [ - "train_data_dict = jax.tree_util.tree_map(lambda x: x[:40000], cifar10)\n", - "valid_data_dict = jax.tree_util.tree_map(lambda x: x[40000:50000], cifar10)\n", - "test_data_dict = jax.tree_util.tree_map(lambda x: x[50000:], cifar10)" + "train_data_dict = jax.tree.map(lambda x: x[:40000], cifar10)\n", + "valid_data_dict = jax.tree.map(lambda x: x[40000:50000], cifar10)\n", + "test_data_dict = jax.tree.map(lambda x: x[50000:], cifar10)" ] }, { diff --git a/haiku/_src/base.py b/haiku/_src/base.py index c23579540..953d1775c 100644 --- a/haiku/_src/base.py +++ b/haiku/_src/base.py @@ -227,7 +227,7 @@ def new_context( ... mod = hk.nets.MLP([300, 100, 10]) ... y1 = mod(jnp.ones([1, 1])) - >>> assert len(jax.tree_util.tree_leaves(ctx.collect_params())) == 6 + >>> assert len(jax.tree.leaves(ctx.collect_params())) == 6 >>> with ctx: ... y2 = mod(jnp.ones([1, 1])) @@ -379,7 +379,7 @@ def get_params() -> Params: """Returns the parameters for the current :func:`transform`. >>> def report(when): - ... shapes = jax.tree_util.tree_map(jnp.shape, hk.get_params()) + ... shapes = jax.tree.map(jnp.shape, hk.get_params()) ... print(f'{when}: {shapes}') >>> def f(x): ... report('Before call') @@ -425,7 +425,7 @@ def get_initial_state() -> State: Example: >>> def report(when): - ... state = jax.tree_util.tree_map(int, hk.get_initial_state()) + ... state = jax.tree.map(int, hk.get_initial_state()) ... print(f'{when}: {state}') >>> def f(): ... report('Before get_state') @@ -471,7 +471,7 @@ def get_current_state() -> State: Example: >>> def report(when): - ... state = jax.tree_util.tree_map(int, hk.get_current_state()) + ... state = jax.tree.map(int, hk.get_current_state()) ... print(f'{when}: {state}') >>> def f(): ... 
report('Before get_state') @@ -1374,8 +1374,8 @@ def set_state(name: str, value): state = frame.state[bundle_name] if state_setter_stack: - shape = jax.tree_util.tree_map(maybe_shape, value) - dtype = jax.tree_util.tree_map(maybe_dtype, value) + shape = jax.tree.map(maybe_shape, value) + dtype = jax.tree.map(maybe_dtype, value) fq_name = bundle_name + "/" + name context = SetterContext(full_name=fq_name, module=current_module(), original_dtype=dtype, diff --git a/haiku/_src/base_test.py b/haiku/_src/base_test.py index 7ffb54138..f23336ff1 100644 --- a/haiku/_src/base_test.py +++ b/haiku/_src/base_test.py @@ -372,7 +372,7 @@ def dtype_recast_getter(next_getter, value, context): with custom_create_x(dtype_cast_creator), \ custom_get_x(dtype_recast_getter): value = get_x("w", [], jnp.bfloat16, jnp.ones) - orig_value = jax.tree_util.tree_leaves(getattr(ctx, collect_x)())[0] + orig_value = jax.tree.leaves(getattr(ctx, collect_x)())[0] assert value.dtype == jnp.bfloat16 assert orig_value.dtype == jnp.float32 @@ -536,7 +536,7 @@ def my_setter(next_setter, value, context): def test_setter_tree(self): witness = [] x = {"a": jnp.ones([]), "b": jnp.zeros([123])} - y = jax.tree_util.tree_map(lambda x: x + 1, x) + y = jax.tree.map(lambda x: x + 1, x) def my_setter(next_setter, value, ctx): self.assertIs(value, x) @@ -646,8 +646,7 @@ def test_prng_reserve(self): s.reserve(10) hk_keys = tuple(next(s) for _ in range(10)) jax_keys = tuple(jax.random.split(test_utils.clone(k), num=11)[1:]) - jax.tree_util.tree_map( - np.testing.assert_array_equal, hk_keys, jax_keys) + jax.tree.map(np.testing.assert_array_equal, hk_keys, jax_keys) def test_prng_reserve_twice(self): k = jax.random.PRNGKey(42) @@ -658,16 +657,14 @@ def test_prng_reserve_twice(self): k, subkey1, subkey2 = tuple(jax.random.split(test_utils.clone(k), num=3)) _, subkey3, subkey4 = tuple(jax.random.split(k, num=3)) jax_keys = (subkey1, subkey2, subkey3, subkey4) - jax.tree_util.tree_map( - np.testing.assert_array_equal, hk_keys, jax_keys) + jax.tree.map(np.testing.assert_array_equal, hk_keys, jax_keys) def test_prng_sequence_split(self): k = jax.random.PRNGKey(42) s = base.PRNGSequence(k) hk_keys = s.take(10) jax_keys = tuple(jax.random.split(test_utils.clone(k), num=11)[1:]) - jax.tree_util.tree_map( - np.testing.assert_array_equal, hk_keys, jax_keys) + jax.tree.map(np.testing.assert_array_equal, hk_keys, jax_keys) @parameterized.parameters(42, 28) def test_with_rng(self, seed): diff --git a/haiku/_src/basic.py b/haiku/_src/basic.py index 6643ab3b5..1573da0b1 100644 --- a/haiku/_src/basic.py +++ b/haiku/_src/basic.py @@ -195,10 +195,10 @@ def ndim_at_least(x, num_dims): def arbitrary_mergeable_leaf(min_num_dims, args, kwargs): - for a in jax.tree_util.tree_leaves(args): + for a in jax.tree.leaves(args): if ndim_at_least(a, min_num_dims): return a - for k in jax.tree_util.tree_leaves(kwargs): + for k in jax.tree.leaves(kwargs): if ndim_at_least(k, min_num_dims): return k # Couldn't find a satisfactory leaf. 
@@ -258,10 +258,10 @@ def __call__(self, *args, **kwargs): merge = lambda x: merge_leading_dims(x, self.num_dims) split = lambda x: split_leading_dim(x, example.shape[:self.num_dims]) - args = jax.tree_util.tree_map(merge, args) - kwargs = jax.tree_util.tree_map(merge, kwargs) + args = jax.tree.map(merge, args) + kwargs = jax.tree.map(merge, kwargs) outputs = self._f(*args, **kwargs) - return jax.tree_util.tree_map(split, outputs) + return jax.tree.map(split, outputs) def expand_apply(f, axis=0): @@ -269,9 +269,9 @@ def expand_apply(f, axis=0): Syntactic sugar for:: - ins = jax.tree_util.tree_map(lambda t: np.expand_dims(t, axis=axis), ins) + ins = jax.tree.map(lambda t: np.expand_dims(t, axis=axis), ins) out = f(ins) - out = jax.tree_util.tree_map(lambda t: np.squeeze(t, axis=axis), out) + out = jax.tree.map(lambda t: np.squeeze(t, axis=axis), out) This may be useful for applying a function built for ``[Time, Batch, ...]`` arrays to a single timestep. @@ -289,10 +289,10 @@ def expand_apply(f, axis=0): @functools.wraps(f) def wrapper(*args, **kwargs): expand = lambda t: jnp.expand_dims(t, axis=axis) - args = jax.tree_util.tree_map(expand, args) - kwargs = jax.tree_util.tree_map(expand, kwargs) + args = jax.tree.map(expand, args) + kwargs = jax.tree.map(expand, kwargs) outputs = f(*args, **kwargs) - return jax.tree_util.tree_map(lambda t: jnp.squeeze(t, axis=axis), outputs) + return jax.tree.map(lambda t: jnp.squeeze(t, axis=axis), outputs) return wrapper diff --git a/haiku/_src/data_structures.py b/haiku/_src/data_structures.py index aef7cfdbf..1c3db70bc 100644 --- a/haiku/_src/data_structures.py +++ b/haiku/_src/data_structures.py @@ -140,8 +140,8 @@ def to_haiku_dict(structure: Mapping[K, V]) -> MutableMapping[K, V]: def _copy_structure(tree): """Returns a copy of the given structure.""" - leaves, treedef = jax.tree_util.tree_flatten(tree) - return jax.tree_util.tree_unflatten(treedef, leaves) + leaves, treedef = jax.tree.flatten(tree) + return jax.tree.unflatten(treedef, leaves) def _to_dict_recurse(value: Any): @@ -196,14 +196,14 @@ def __init__(self, *args, **kwargs): mapping = None # When unflattening we cannot assume that the leaves are not pytrees (for - # example: `jax.tree_util.tree_map(list, my_map)` would pass a list of + # example: `jax.tree.map(list, my_map)` would pass a list of # lists in as leaves). 
if not jax.tree_util.all_leaves(leaves): - mapping = jax.tree_util.tree_unflatten(structure, leaves) - leaves, structure = jax.tree_util.tree_flatten(mapping) + mapping = jax.tree.unflatten(structure, leaves) + leaves, structure = jax.tree.flatten(mapping) else: mapping = dict(*args, **kwargs) - leaves, structure = jax.tree_util.tree_flatten(mapping) + leaves, structure = jax.tree.flatten(mapping) self._structure = structure self._leaves = tuple(leaves) @@ -211,8 +211,7 @@ def __init__(self, *args, **kwargs): def _to_mapping(self) -> Mapping[K, V]: if self._mapping is None: - self._mapping = jax.tree_util.tree_unflatten( - self._structure, self._leaves) + self._mapping = jax.tree.unflatten(self._structure, self._leaves) return self._mapping def keys(self): diff --git a/haiku/_src/data_structures_test.py b/haiku/_src/data_structures_test.py index 8d1e92592..01fe89ddc 100644 --- a/haiku/_src/data_structures_test.py +++ b/haiku/_src/data_structures_test.py @@ -202,7 +202,7 @@ def test_get(self): self.assertIsNone(f.get("c")) self.assertEqual(f.get("d", f), f) - @parameterized.parameters(jax.tree_util.tree_map, tree.map_structure) + @parameterized.parameters(jax.tree.map, tree.map_structure) def test_tree_map(self, tree_map): f = FlatMap(dict(a=1, b=dict(c=2))) p = tree_map("v: {}".format, f) @@ -229,7 +229,7 @@ def test_copy(self, clone): self.assertEqual(before, after) self.assertEqual(after, {"a": {"b": 1, "c": 2}}) before_dict = data_structures.to_haiku_dict(before) - jax.tree_util.tree_map(self.assertEqual, before_dict, after) + jax.tree.map(self.assertEqual, before_dict, after) @all_picklers def test_pickle_roundtrip(self, pickler): @@ -299,7 +299,7 @@ def test_init(self): self.assertEqual(outer, nested_flatmapping) # Init from flat structures - values, treedef = jax.tree_util.tree_flatten(f) + values, treedef = jax.tree.flatten(f) self.assertEqual( FlatMap(data_structures.FlatComponents(values, treedef)), f) @@ -323,17 +323,17 @@ def test_tree_functions(self): f = FlatMap( {"foo": {"b": {"c": 1}, "d": 2}, "bar": {"c": 1}}) - m = jax.tree_util.tree_map(lambda x: x + 1, f) + m = jax.tree.map(lambda x: x + 1, f) self.assertEqual(type(m), FlatMap) self.assertEqual(m, {"foo": {"b": {"c": 2}, "d": 3}, "bar": {"c": 2}}) - mm = jax.tree_util.tree_map(lambda x, y: x + y, f, f) + mm = jax.tree.map(lambda x, y: x + y, f, f) self.assertEqual(type(mm), FlatMap) self.assertEqual(mm, {"foo": {"b": {"c": 2}, "d": 4}, "bar": {"c": 2}}) - leaves, treedef = jax.tree_util.tree_flatten(f) + leaves, treedef = jax.tree.flatten(f) self.assertEqual(leaves, [1, 1, 2]) - uf = jax.tree_util.tree_unflatten(treedef, leaves) + uf = jax.tree.unflatten(treedef, leaves) self.assertEqual(type(f), FlatMap) self.assertEqual(f, uf) @@ -342,16 +342,16 @@ def test_flatten_nested_struct(self): "baz": {"bat": [4, 5, 6], "qux": [7, [8, 9]]}} f = FlatMap(d) - leaves, treedef = jax.tree_util.tree_flatten(f) + leaves, treedef = jax.tree.flatten(f) self.assertEqual([4, 5, 6, 7, 8, 9, 1, 2, 3], leaves) - g = jax.tree_util.tree_unflatten(treedef, leaves) + g = jax.tree.unflatten(treedef, leaves) self.assertEqual(g, f) self.assertEqual(g, d) def test_nested_sequence(self): f_map = FlatMap( {"foo": [1, 2], "bar": [{"a": 1}, 2]}) - leaves, _ = jax.tree_util.tree_flatten(f_map) + leaves, _ = jax.tree.flatten(f_map) self.assertEqual(leaves, [1, 2, 1, 2]) self.assertEqual(f_map["foo"][0], 1) @@ -361,7 +361,7 @@ def test_different_sequence_types(self, type_of_sequence): f_map = FlatMap( {"foo": type_of_sequence((1, 2)), "bar": 
type_of_sequence((3, {"b": 4}))}) - leaves, _ = jax.tree_util.tree_flatten(f_map) + leaves, _ = jax.tree.flatten(f_map) self.assertEqual(leaves, [3, 4, 1, 2]) self.assertEqual(f_map["foo"][0], 1) @@ -370,26 +370,28 @@ def test_different_sequence_types(self, type_of_sequence): def test_replace_leaves_with_nodes_in_map(self): f = FlatMap({"foo": 1, "bar": 2}) - f_nested = jax.tree_util.tree_map(lambda x: {"a": (x, x)}, f) - leaves, _ = jax.tree_util.tree_flatten(f_nested) + f_nested = jax.tree.map(lambda x: {"a": (x, x)}, f) + leaves, _ = jax.tree.flatten(f_nested) self.assertEqual(leaves, [2, 2, 1, 1]) def test_frozen_builtins_jax_compatibility(self): f = FlatMap({"foo": [3, 2], "bar": {"a": 3}}) - mapped_frozen_list = jax.tree_util.tree_map(lambda x: x+1, f["foo"]) + mapped_frozen_list = jax.tree.map(lambda x: x + 1, f["foo"]) self.assertEqual(mapped_frozen_list[0], 4) - mapped_frozen_dict = jax.tree_util.tree_map(lambda x: x+1, f["bar"]) + mapped_frozen_dict = jax.tree.map(lambda x: x + 1, f["bar"]) self.assertEqual(mapped_frozen_dict["a"], 4) def test_tree_transpose(self): - outerdef = jax.tree_util.tree_structure(FlatMap({"a": 1, "b": 2})) - innerdef = jax.tree_util.tree_structure([1, 2]) + outerdef = jax.tree.structure(FlatMap({"a": 1, "b": 2})) + innerdef = jax.tree.structure([1, 2]) self.assertEqual( [FlatMap({"a": 3, "b": 5}), FlatMap({"a": 4, "b": 6})], - jax.tree_util.tree_transpose( - outerdef, innerdef, FlatMap({"a": [3, 4], "b": [5, 6]}))) + jax.tree.transpose( + outerdef, innerdef, FlatMap({"a": [3, 4], "b": [5, 6]}) + ), + ) class DataStructuresTest(parameterized.TestCase): diff --git a/haiku/_src/dot.py b/haiku/_src/dot.py index cefa1738a..6e03cd757 100644 --- a/haiku/_src/dot.py +++ b/haiku/_src/dot.py @@ -136,7 +136,7 @@ def to_graph(fun): def wrapped_fun(*args): """See `fun`.""" f = lu.wrap_init(fun) - args_flat, in_tree = jax.tree_util.tree_flatten((args, {})) + args_flat, in_tree = jax.tree.flatten((args, {})) flat_fun, out_tree = jax.api_util.flatten_fun(f, in_tree) graph = Graph.create(title=name_or_str(fun)) @@ -154,7 +154,7 @@ def method_hook(mod: module.Module, method_name: str): module.hook_methods(method_hook), \ jax.core.new_main(DotTrace) as main: out_flat = _interpret_subtrace(flat_fun, main).call_wrapped(*args_flat) - out = jax.tree_util.tree_unflatten(out_tree(), out_flat) + out = jax.tree.unflatten(out_tree(), out_flat) return graph, args, out @@ -207,14 +207,14 @@ def process_primitive(self, primitive, tracers, params): return self.process_call(primitive, fun, tracers, params) inputs = [t.val for t in tracers] - outputs = list(jax.tree_util.tree_leaves(val_out)) + outputs = list(jax.tree.leaves(val_out)) graph = graph_stack.peek() node = Node(id=outputs[0], title=str(primitive), outputs=outputs) graph.nodes.append(node) graph.edges.extend([(i, outputs[0]) for i in inputs]) - return jax.tree_util.tree_map(lambda v: DotTracer(self, v), val_out) + return jax.tree.map(lambda v: DotTracer(self, v), val_out) def process_call(self, call_primitive, f, tracers, params): assert call_primitive.multiple_results @@ -291,8 +291,8 @@ def format_path(path): argid_usecount = collections.Counter() op_outids = set() captures = [] - argids = {id(v) for v in jax.tree_util.tree_leaves(args)} - outids = {id(v) for v in jax.tree_util.tree_leaves(outputs)} + argids = {id(v) for v in jax.tree.leaves(args)} + outids = {id(v) for v in jax.tree.leaves(outputs)} outname = {id(v): format_path(p) for p, v in tree.flatten_with_path(outputs)} def render_graph(g: Graph, parent: 
Optional[Graph] = None, depth: int = 0): diff --git a/haiku/_src/filtering_test.py b/haiku/_src/filtering_test.py index 4114d119a..e4d86f3c3 100644 --- a/haiku/_src/filtering_test.py +++ b/haiku/_src/filtering_test.py @@ -294,7 +294,7 @@ def map_fn(module_name, name, v): return 2. * v new_params = filtering.map(map_fn, params) - self.assertLen(jax.tree_util.tree_leaves(new_params), 4) + self.assertLen(jax.tree.leaves(new_params), 4) first_layer_params, second_layer_params = filtering.partition( lambda module_name, *_: module_name == "first_layer", diff --git a/haiku/_src/integration/checkpoint_utils.py b/haiku/_src/integration/checkpoint_utils.py index 706283e7f..cc9875f4b 100644 --- a/haiku/_src/integration/checkpoint_utils.py +++ b/haiku/_src/integration/checkpoint_utils.py @@ -47,9 +47,9 @@ def summarize(d: descriptors.ModuleDescriptor) -> Mapping[str, Any]: if params: out["param_size"] = int(hk.data_structures.tree_size(params)) out["param_bytes"] = int(hk.data_structures.tree_bytes(params)) - out["params"] = jax.tree_util.tree_map(format_tensor, params) + out["params"] = jax.tree.map(format_tensor, params) if state: out["state_size"] = int(hk.data_structures.tree_size(state)) out["state_bytes"] = int(hk.data_structures.tree_bytes(state)) - out["state"] = jax.tree_util.tree_map(format_tensor, state) + out["state"] = jax.tree.map(format_tensor, state) return out diff --git a/haiku/_src/integration/common.py b/haiku/_src/integration/common.py index 34dd7f23b..abb3417b4 100644 --- a/haiku/_src/integration/common.py +++ b/haiku/_src/integration/common.py @@ -62,7 +62,7 @@ def cast_if_floating(x): def init_fn(rng, x): params, state = g.init(rng, x) - state = jax.tree_util.tree_map(cast_if_floating, state) + state = jax.tree.map(cast_if_floating, state) return params, state x = np.ones(shape, test_dtype) diff --git a/haiku/_src/integration/descriptors.py b/haiku/_src/integration/descriptors.py index 42583b377..6b0b4ff31 100644 --- a/haiku/_src/integration/descriptors.py +++ b/haiku/_src/integration/descriptors.py @@ -57,9 +57,10 @@ def __init__(self, module: hk.RNNCore, unroller=None): self.unroller = unroller def __call__(self, x: jax.Array): - initial_state = jax.tree_util.tree_map( + initial_state = jax.tree.map( lambda v: v.astype(x.dtype), - self.wrapped.initial_state(batch_size=x.shape[0])) + self.wrapped.initial_state(batch_size=x.shape[0]), + ) x = jnp.expand_dims(x, axis=0) return self.unroller(self.wrapped, x, initial_state) diff --git a/haiku/_src/integration/hk_transforms_test.py b/haiku/_src/integration/hk_transforms_test.py index 49b771718..8444a0f1f 100644 --- a/haiku/_src/integration/hk_transforms_test.py +++ b/haiku/_src/integration/hk_transforms_test.py @@ -56,15 +56,17 @@ def g(x, jit=False): # NOTE: We shard init/apply tests since some modules are expensive to jit # (e.g. ResNet50 takes ~60s to compile and we compile it twice per test). 
if init: - jax.tree_util.tree_map( - assert_allclose, jax.jit(f.init)(rng, x), f.init(rng, x, jit=True)) + jax.tree.map( + assert_allclose, jax.jit(f.init)(rng, x), f.init(rng, x, jit=True) + ) else: params, state = f.init(rng, x) - jax.tree_util.tree_map( + jax.tree.map( assert_allclose, jax.jit(f.apply)(params, state, rng, x), - f.apply(params, state, rng, x, jit=True)) + f.apply(params, state, rng, x, jit=True), + ) @test_utils.combined_named_parameters(descriptors.ALL_MODULES, test_utils.named_bools('init')) @@ -96,8 +98,8 @@ def s(carry, x): if init: u_params, u_state = u_f.init(rng, xs) - jax.tree_util.tree_map(assert_allclose, u_params, params) - jax.tree_util.tree_map(assert_allclose, u_state, state) + jax.tree.map(assert_allclose, u_params, params) + jax.tree.map(assert_allclose, u_state, state) return def fun(state, x): @@ -106,8 +108,8 @@ def fun(state, x): s_state, s_ys = jax.lax.scan(fun, state, xs) u_ys, u_state = u_f.apply(params, state, rng, xs) - jax.tree_util.tree_map(assert_allclose, u_ys, s_ys) - jax.tree_util.tree_map(assert_allclose, u_state, s_state) + jax.tree.map(assert_allclose, u_ys, s_ys) + jax.tree.map(assert_allclose, u_state, s_state) @test_utils.combined_named_parameters( # TODO(tomhennigan) Enable once grad for _scan_transpose implemented. @@ -137,9 +139,11 @@ def g(x, remat=False): has_aux=True) params, state = f.init(rng, x) - jax.tree_util.tree_map( - assert_allclose, grad_jax_remat(params, state, rng, x), - grad_hk_remat(params, state, rng, x)) + jax.tree.map( + assert_allclose, + grad_jax_remat(params, state, rng, x), + grad_hk_remat(params, state, rng, x), + ) @test_utils.combined_named_parameters(descriptors.ALL_MODULES) def test_optimize_rng_use_under_jit(self, module_fn: ModuleFn, shape, dtype): @@ -159,17 +163,18 @@ def g(x): assert_allclose = functools.partial(np.testing.assert_allclose, atol=atol) params, state = jax.jit(f.init)(rng, x) - jax.tree_util.tree_map(assert_allclose, (params, state), f.init(rng, x)) + jax.tree.map(assert_allclose, (params, state), f.init(rng, x)) if module_type in (hk.nets.VectorQuantizer, hk.nets.VectorQuantizerEMA): # For stochastic modules just test apply runs. 
jax.device_get(jax.jit(f.apply)(params, state, rng, x)) else: - jax.tree_util.tree_map( + jax.tree.map( assert_allclose, jax.jit(f.apply)(params, state, rng, x), - f.apply(params, state, rng, x)) + f.apply(params, state, rng, x), + ) @test_utils.combined_named_parameters(descriptors.OPTIONAL_BATCH_MODULES) def test_vmap(self, module_fn: ModuleFn, shape, dtype): @@ -198,9 +203,11 @@ def test_vmap(self, module_fn: ModuleFn, shape, dtype): module_type = descriptors.module_type(module_fn) atol = CUSTOM_ATOL.get(module_type, DEFAULT_ATOL) assert_allclose = functools.partial(np.testing.assert_allclose, atol=atol) - jax.tree_util.tree_map( - assert_allclose, f_mapped.apply(params, state, rng, x), - v_apply(params, state, rng, x)) + jax.tree.map( + assert_allclose, + f_mapped.apply(params, state, rng, x), + v_apply(params, state, rng, x), + ) @test_utils.combined_named_parameters(descriptors.ALL_MODULES) def test_fast_eval_shape(self, module_fn: ModuleFn, shape, dtype): diff --git a/haiku/_src/integration/jax2tf_test.py b/haiku/_src/integration/jax2tf_test.py index bec23a406..fa8d9792b 100644 --- a/haiku/_src/integration/jax2tf_test.py +++ b/haiku/_src/integration/jax2tf_test.py @@ -65,19 +65,18 @@ def g(x): atol = CUSTOM_ATOL.get(descriptors.module_type(module_fn), DEFAULT_ATOL) assert_allclose = functools.partial(np.testing.assert_allclose, atol=atol) - get = lambda t: jax.tree_util.tree_map(lambda x: x.numpy(), t) + get = lambda t: jax.tree.map(lambda x: x.numpy(), t) if init: init_jax = jax_transform(f.init) init_tf = tf_transform(jax2tf.convert(f.init)) - jax.tree_util.tree_map( - assert_allclose, init_jax(rng, x), get(init_tf(rng, x))) + jax.tree.map(assert_allclose, init_jax(rng, x), get(init_tf(rng, x))) else: params, state = f.init(rng, x) apply_jax = jax_transform(f.apply) apply_tf = tf_transform(jax2tf.convert(f.apply)) - jax.tree_util.tree_map( + jax.tree.map( assert_allclose, apply_jax(params, state, rng, x), get(apply_tf(params, state, rng, x)), diff --git a/haiku/_src/integration/jax_transforms_test.py b/haiku/_src/integration/jax_transforms_test.py index 27f74e854..acdbf7e4d 100644 --- a/haiku/_src/integration/jax_transforms_test.py +++ b/haiku/_src/integration/jax_transforms_test.py @@ -53,12 +53,11 @@ def g(x): assert_allclose = functools.partial(np.testing.assert_allclose, atol=atol) # Ensure initialization under jit is the same. - jax.tree_util.tree_map( - assert_allclose, f.init(rng, x), jax.jit(f.init)(rng, x)) + jax.tree.map(assert_allclose, f.init(rng, x), jax.jit(f.init)(rng, x)) # Ensure application under jit is the same. params, state = f.init(rng, x) - jax.tree_util.tree_map( + jax.tree.map( assert_allclose, f.apply(params, state, rng, x), jax.jit(f.apply)(params, state, rng, x), @@ -82,9 +81,11 @@ def g(x): # Ensure application under vmap is the same. 
params, state = f.init(rng, sample) v_apply = jax.vmap(f.apply, in_axes=(None, None, None, 0)) - jax.tree_util.tree_map( + jax.tree.map( lambda a, b: np.testing.assert_allclose(a, b, atol=DEFAULT_ATOL), - f.apply(params, state, rng, batch), v_apply(params, state, rng, batch)) + f.apply(params, state, rng, batch), + v_apply(params, state, rng, batch), + ) if __name__ == '__main__': absltest.main() diff --git a/haiku/_src/integration/numpy_inputs_test.py b/haiku/_src/integration/numpy_inputs_test.py index 5dcc06378..95526f438 100644 --- a/haiku/_src/integration/numpy_inputs_test.py +++ b/haiku/_src/integration/numpy_inputs_test.py @@ -30,8 +30,7 @@ def tree_assert_allclose(a, b, *, atol=1e-6): - jax.tree_util.tree_map( - functools.partial(np.testing.assert_allclose, atol=atol), a, b) + jax.tree.map(functools.partial(np.testing.assert_allclose, atol=atol), a, b) def get_module_cls(module_fn: ModuleFn) -> type[hk.Module]: diff --git a/haiku/_src/layer_norm_test.py b/haiku/_src/layer_norm_test.py index 3d6716f0d..3d2d74d69 100644 --- a/haiku/_src/layer_norm_test.py +++ b/haiku/_src/layer_norm_test.py @@ -72,8 +72,7 @@ def f(x): fwd = transform.transform(f) data = jnp.zeros([2, 3, 4, 5], dtype=jnp.bfloat16) params = fwd.init(jax.random.PRNGKey(428), data) - bf16_params = jax.tree_util.tree_map( - lambda t: t.astype(jnp.bfloat16), params) + bf16_params = jax.tree.map(lambda t: t.astype(jnp.bfloat16), params) self.assertEqual(fwd.apply(bf16_params, None, data).dtype, jnp.bfloat16) @parameterized.parameters(True, False) diff --git a/haiku/_src/layer_stack.py b/haiku/_src/layer_stack.py index 5882bdb4f..14b22b88a 100644 --- a/haiku/_src/layer_stack.py +++ b/haiku/_src/layer_stack.py @@ -377,7 +377,7 @@ def iterate(f): if with_per_layer_inputs: @functools.wraps(f) def wrapped(x, *args, **kwargs): - for ys in jax.tree_util.tree_leaves(args): + for ys in jax.tree.leaves(args): assert ys.shape[0] == num_layers, f"{ys.shape[0]} != {num_layers}" mod = _LayerStackWithPerLayer( f, diff --git a/haiku/_src/layer_stack_test.py b/haiku/_src/layer_stack_test.py index a4b917aa8..21f123e9d 100644 --- a/haiku/_src/layer_stack_test.py +++ b/haiku/_src/layer_stack_test.py @@ -236,8 +236,9 @@ def outer_fn_layer_stack(x): assert_fn = functools.partial( np.testing.assert_allclose, atol=1e-4, rtol=1e-4) - jax.tree_util.tree_map( - assert_fn, unrolled_grad, _slice_layers_params(layer_stack_grad)) + jax.tree.map( + assert_fn, unrolled_grad, _slice_layers_params(layer_stack_grad) + ) def test_random(self): """Random numbers should be handled correctly.""" @@ -415,8 +416,7 @@ def mul_by_m(x): m_x = m_x[..., None] return x * m_x * alpha - params = jax.tree_util.tree_map( - mul_by_m, hk_fn.init(next(key_seq), init_value)) + params = jax.tree.map(mul_by_m, hk_fn.init(next(key_seq), init_value)) a, b = forward[-1] x_n = hk_fn.apply(params, init_value) @@ -591,8 +591,8 @@ def stacked(x: jax.Array) -> jax.Array: stacked_params = stacked.init(rng, x) self.assertEqual( - jax.tree_util.tree_structure(looped_params), - jax.tree_util.tree_structure(stacked_params), + jax.tree.structure(looped_params), + jax.tree.structure(stacked_params), ) # Use same set of params for both calls since stacked_params have different @@ -668,8 +668,8 @@ def stacked(x: CustomParam) -> CustomParam: stacked_params = stacked.init(rng, x) self.assertEqual( - jax.tree_util.tree_structure(looped_params), - jax.tree_util.tree_structure(stacked_params), + jax.tree.structure(looped_params), + jax.tree.structure(stacked_params), ) # Use same set of params for both 
calls since stacked_params have different diff --git a/haiku/_src/lift.py b/haiku/_src/lift.py index d31fbe12d..7f97e3f83 100644 --- a/haiku/_src/lift.py +++ b/haiku/_src/lift.py @@ -214,7 +214,7 @@ def lift( >>> rng = jax.random.PRNGKey(777) >>> x = jnp.ones([32, 128]) >>> params = f.init(rng, x) - >>> jax.tree_util.tree_map(lambda x: x.shape, params) + >>> jax.tree.map(lambda x: x.shape, params) {'ensemble/mlp/~/linear_0': {'b': (4, 300), 'w': (4, 128, 300)}, 'ensemble/mlp/~/linear_1': {'b': (4, 100), 'w': (4, 300, 100)}, 'ensemble/mlp/~/linear_2': {'b': (4, 10), 'w': (4, 100, 10)}} @@ -228,8 +228,8 @@ def lift( Args: init_fn: The ``init`` function from an :class:`Transformed`\ . - allow_reuse: Allows lifted parameters and state to be reused from the - outer :func:`transform`. This can be desirable when using ``lift`` within + allow_reuse: Allows lifted parameters and state to be reused from the outer + :func:`transform`. This can be desirable when using ``lift`` within control flow (e.g. ``hk.scan``). name: A string name to prefix parameters with. diff --git a/haiku/_src/lift_test.py b/haiku/_src/lift_test.py index 8056b7f4c..7ce0254b6 100644 --- a/haiku/_src/lift_test.py +++ b/haiku/_src/lift_test.py @@ -234,12 +234,11 @@ def outer(): outer = transform.transform_with_state(outer) params, state = outer.init(None) self.assertEmpty(params) - self.assertEqual(jax.tree_util.tree_map(int, state), {"lifted/~": {"w": 0}}) + self.assertEqual(jax.tree.map(int, state), {"lifted/~": {"w": 0}}) for expected in (1, 2, 3): (w, inner_state), state = outer.apply(params, state, None) - self.assertEqual( - jax.tree_util.tree_map(int, inner_state), {"~": {"w": expected}}) + self.assertEqual(jax.tree.map(int, inner_state), {"~": {"w": expected}}) self.assertEqual(w, expected) self.assertEmpty(params) self.assertEqual(state, {"lifted/~": {"w": expected}}) @@ -264,13 +263,11 @@ def __call__(self): outer = transform.transform_with_state(lambda: Outer()()) # pylint: disable=unnecessary-lambda params, state = outer.init(None) self.assertEmpty(params) - self.assertEqual( - jax.tree_util.tree_map(int, state), {"outer/lifted/~": {"w": 0}}) + self.assertEqual(jax.tree.map(int, state), {"outer/lifted/~": {"w": 0}}) for expected in (1, 2, 3): (w, inner_state), state = outer.apply(params, state, None) - self.assertEqual( - jax.tree_util.tree_map(int, inner_state), {"~": {"w": expected}}) + self.assertEqual(jax.tree.map(int, inner_state), {"~": {"w": expected}}) self.assertEqual(w, expected) self.assertEmpty(params) self.assertEqual(state, {"outer/lifted/~": {"w": expected}}) @@ -306,12 +303,11 @@ def outer(): params, state = outer.init(None) self.assertEmpty(params) - self.assertEqual(jax.tree_util.tree_map(int, state), {"~": {"w": 0}}) + self.assertEqual(jax.tree.map(int, state), {"~": {"w": 0}}) for expected in (1, 2, 3): (w, inner_state), state = outer.apply(params, state, None) - self.assertEqual( - jax.tree_util.tree_map(int, inner_state), {"~": {"w": expected}}) + self.assertEqual(jax.tree.map(int, inner_state), {"~": {"w": expected}}) self.assertEqual(w, expected) self.assertEmpty(params) self.assertEqual(state, inner_state) @@ -336,12 +332,11 @@ def __call__(self): outer = transform.transform_with_state(lambda: Outer()()) # pylint: disable=unnecessary-lambda params, state = outer.init(None) self.assertEmpty(params) - self.assertEqual(jax.tree_util.tree_map(int, state), {"outer/~": {"w": 0}}) + self.assertEqual(jax.tree.map(int, state), {"outer/~": {"w": 0}}) for expected in (1, 2, 3): (w, inner_state), state = 
outer.apply(params, state, None) - self.assertEqual( - jax.tree_util.tree_map(int, inner_state), {"~": {"w": expected}}) + self.assertEqual(jax.tree.map(int, inner_state), {"~": {"w": expected}}) self.assertEqual(w, expected) self.assertEmpty(params) self.assertEqual(state, {"outer/~": {"w": expected}}) @@ -432,8 +427,7 @@ def fn(x): x = jnp.ones([10, 10]) params_with_lift = fn.init(None, x) params_without_lift = transform.transform(inner_module).init(None, x) - jax.tree_util.tree_map( - self.assertAlmostEqual, params_with_lift, params_without_lift) + jax.tree.map(self.assertAlmostEqual, params_with_lift, params_without_lift) fn.apply(params_with_lift, None, x) @@ -488,7 +482,7 @@ def test_same_name_across_transforms_no_closed_error(self): params1 = init1(None, 1.) params2 = init2(None, 1.) # does not fail - jax.tree_util.tree_map(self.assertAlmostEqual, params1, params2) + jax.tree.map(self.assertAlmostEqual, params1, params2) def test_closed_over_within_transparent_lift_no_closed_error(self): # You can close over modules within the boundary of the transparent_lift. diff --git a/haiku/_src/mixed_precision_test.py b/haiku/_src/mixed_precision_test.py index c51c7b90c..4196247e4 100644 --- a/haiku/_src/mixed_precision_test.py +++ b/haiku/_src/mixed_precision_test.py @@ -67,8 +67,7 @@ def g(*args, **kwargs): params = f.init(rng, *args, **kwargs) out = f.apply(params, None, *args, **kwargs) return params, out - return jax.tree_util.tree_map( - lambda x: x.dtype, jax.eval_shape(g, *args, **kwargs)) + return jax.tree.map(lambda x: x.dtype, jax.eval_shape(g, *args, **kwargs)) class MixedPrecisionTest(absltest.TestCase): @@ -240,7 +239,7 @@ def test_policy_for_reloaded_class(self): params, y = transform_and_run_once( lambda: conv_local.ConvND(2, 1, 1)(jnp.ones([1, 1, 1, 1]))) - jax.tree_util.tree_map(lambda p: self.assertEqual(p, jnp.float16), params) + jax.tree.map(lambda p: self.assertEqual(p, jnp.float16), params) self.assertEqual(y, jnp.float16) @test_utils.transform_and_run diff --git a/haiku/_src/module.py b/haiku/_src/module.py index 438feed5d..25c279e11 100644 --- a/haiku/_src/module.py +++ b/haiku/_src/module.py @@ -263,7 +263,7 @@ def intercept_methods(interceptor: MethodGetter): ... x = x.astype(jnp.float32) ... return x ... - ... args, kwargs = jax.tree_util.tree_map(cast_if_array, (args, kwargs)) + ... args, kwargs = jax.tree.map(cast_if_array, (args, kwargs)) ... out = next_f(*args, **kwargs) ... return out @@ -388,7 +388,7 @@ def name_scope( >>> f = hk.transform(lambda x: MyModule()(x)) >>> params = f.init(jax.random.PRNGKey(42), jnp.ones([1, 1])) - >>> jax.tree_util.tree_map(jnp.shape, params) + >>> jax.tree.map(jnp.shape, params) {'my_module/my_name_scope': {'w': ()}, 'my_module/my_name_scope/submodule': {'b': (1,), 'w': (1, 1)}} diff --git a/haiku/_src/multi_transform.py b/haiku/_src/multi_transform.py index 6a8761db8..7aa446cc1 100644 --- a/haiku/_src/multi_transform.py +++ b/haiku/_src/multi_transform.py @@ -119,7 +119,7 @@ def multi_transform_with_state( >>> rng = jax.random.PRNGKey(42) >>> x = jnp.ones([1, 1]) >>> params, state = f.init(rng, x) - >>> jax.tree_util.tree_map(jnp.shape, params) + >>> jax.tree.map(jnp.shape, params) {'decoder': {'b': (1,), 'w': (1, 1)}, 'encoder': {'b': (1,), 'w': (1, 1)}} @@ -128,8 +128,8 @@ def multi_transform_with_state( >>> y, state = decode(params, state, None, z) Args: - f: Function returning a "template" function and an arbitrary - tree of functions using modules connected in the template function. 
+ f: Function returning a "template" function and an arbitrary tree of + functions using modules connected in the template function. Returns: An ``init`` function and a tree of pure ``apply`` functions. @@ -149,7 +149,7 @@ def init_fn(*args, **kwargs) -> tuple[hk.MutableParams, hk.MutableState]: def apply_fn_i(i): def apply_fn(*args, **kwargs): """Applies the transformed function at the given inputs.""" - return jax.tree_util.tree_leaves(f()[1])[i](*args, **kwargs) + return jax.tree.leaves(f()[1])[i](*args, **kwargs) return apply_fn # We need to find out the structure of f()[1], including how many @@ -163,7 +163,7 @@ def get_output_treedef() -> Box: rng = jax.random.PRNGKey(42) # This is fine, see above fns = hk.transform_with_state(lambda: f()[1]) apply_fns, _ = fns.apply(*fns.init(rng), rng) - return Box(jax.tree_util.tree_structure(apply_fns)) + return Box(jax.tree.structure(apply_fns)) output_treedef = jax.eval_shape(get_output_treedef).python_value apply_fns = make_tree(lambda i: hk.transform_with_state(apply_fn_i(i)).apply, @@ -208,7 +208,7 @@ def multi_transform( >>> rng = jax.random.PRNGKey(42) >>> x = jnp.ones([1, 1]) >>> params = f.init(rng, x) - >>> jax.tree_util.tree_map(jnp.shape, params) + >>> jax.tree.map(jnp.shape, params) {'decoder': {'b': (1,), 'w': (1, 1)}, 'encoder': {'b': (1,), 'w': (1, 1)}} @@ -256,7 +256,7 @@ def apply_fn(params: hk.Params, rng, *args, **kwargs): return out return apply_fn - apply_fns = jax.tree_util.tree_map(apply_without_state, f.apply) + apply_fns = jax.tree.map(apply_without_state, f.apply) return MultiTransformed(init_fn, apply_fns) @@ -326,16 +326,18 @@ def apply_fn(params, *args, **kwargs): def make_new_apply_fn(apply_fn, params, state, *args, **kwargs): check_rng_kwarg(kwargs) return apply_fn(params, state, None, *args, **kwargs) - apply_fn = jax.tree_util.tree_map( - lambda fn: functools.partial(make_new_apply_fn, fn), f.apply) + apply_fn = jax.tree.map( + lambda fn: functools.partial(make_new_apply_fn, fn), f.apply + ) f_new = MultiTransformedWithState(init=f.init, apply=apply_fn) elif isinstance(f, MultiTransformed): def make_new_apply_fn(apply_fn, params, *args, **kwargs): check_rng_kwarg(kwargs) return apply_fn(params, None, *args, **kwargs) - apply_fn = jax.tree_util.tree_map( - lambda fn: functools.partial(make_new_apply_fn, fn), f.apply) + apply_fn = jax.tree.map( + lambda fn: functools.partial(make_new_apply_fn, fn), f.apply + ) f_new = MultiTransformed(init=f.init, apply=apply_fn) else: @@ -349,4 +351,4 @@ def make_new_apply_fn(apply_fn, params, *args, **kwargs): def make_tree(f: Callable[[int], Any], treedef: jax.tree_util.PyTreeDef): leaves = list(map(f, range(treedef.num_leaves))) - return jax.tree_util.tree_unflatten(treedef, leaves) + return jax.tree.unflatten(treedef, leaves) diff --git a/haiku/_src/nets/resnet_test.py b/haiku/_src/nets/resnet_test.py index 77587f9c6..4f937ec88 100644 --- a/haiku/_src/nets/resnet_test.py +++ b/haiku/_src/nets/resnet_test.py @@ -97,8 +97,7 @@ def model_func(img): image = jnp.ones([2, 64, 64, 3]) rng = jax.random.PRNGKey(0) params, _ = model.init(rng, image) - num_params = sum( - np.prod(p.shape).item() for p in jax.tree_util.tree_leaves(params)) + num_params = sum(np.prod(p.shape).item() for p in jax.tree.leaves(params)) self.assertGreater(num_params, int(0.998 * expected_num_params)) self.assertLess(num_params, int(1.002 * expected_num_params)) diff --git a/haiku/_src/nets/vqvae_test.py b/haiku/_src/nets/vqvae_test.py index d24b121f1..25195dbbd 100644 --- a/haiku/_src/nets/vqvae_test.py +++ 
b/haiku/_src/nets/vqvae_test.py
@@ -55,7 +55,7 @@ def testConstruct(self, constructor, kwargs):
 
     # Output shape is correct
     self.assertEqual(vq_output['quantize'].shape, inputs.shape)
-    vq_output_np = jax.tree_util.tree_map(lambda t: t, vq_output)
+    vq_output_np = jax.tree.map(lambda t: t, vq_output)
     embeddings_np = vqvae_module.embeddings
     self.assertEqual(embeddings_np.shape,
diff --git a/haiku/_src/random_test.py b/haiku/_src/random_test.py
index d3f2025f9..278f7101c 100644
--- a/haiku/_src/random_test.py
+++ b/haiku/_src/random_test.py
@@ -41,7 +41,7 @@ def f():
 
     # With optimize_rng_use the keys returned should be equal to split(n).
     f_opt = transform.transform(random.optimize_rng_use(f))
-    jax.tree_util.tree_map(
+    jax.tree.map(
         assert_allclose,
         f_opt.apply({}, key),
         tuple(jax.random.split(key, 3))[1:],
@@ -50,9 +50,7 @@ def f():
     # Without optimize_rng_use the keys should be equivalent to splitting in a
     # loop.
     f = transform.transform(f)
-    jax.tree_util.tree_map(
-        assert_allclose, f.apply({}, key), tuple(split_for_n(key, 2))
-    )
+    jax.tree.map(assert_allclose, f.apply({}, key), tuple(split_for_n(key, 2)))
 
   def test_rbg_default_impl(self):
     with jax.default_prng_impl("rbg"):
diff --git a/haiku/_src/recurrent.py b/haiku/_src/recurrent.py
index 14521700a..27f17b889 100644
--- a/haiku/_src/recurrent.py
+++ b/haiku/_src/recurrent.py
@@ -122,25 +122,26 @@ def static_unroll(core, input_sequence, initial_state, time_major=True):
   """
   output_sequence = []
   time_axis = 0 if time_major else 1
-  num_steps = jax.tree_util.tree_leaves(input_sequence)[0].shape[time_axis]
+  num_steps = jax.tree.leaves(input_sequence)[0].shape[time_axis]
   state = initial_state
   for t in range(num_steps):
     if time_major:
-      inputs = jax.tree_util.tree_map(lambda x, _t=t: x[_t], input_sequence)
+      inputs = jax.tree.map(lambda x, _t=t: x[_t], input_sequence)
     else:
-      inputs = jax.tree_util.tree_map(lambda x, _t=t: x[:, _t], input_sequence)
+      inputs = jax.tree.map(lambda x, _t=t: x[:, _t], input_sequence)
     outputs, state = core(inputs, state)
     output_sequence.append(outputs)
 
   # Stack outputs along the time axis.
-  output_sequence = jax.tree_util.tree_map(
-      lambda *args: jnp.stack(args, axis=time_axis), *output_sequence)
+  output_sequence = jax.tree.map(
+      lambda *args: jnp.stack(args, axis=time_axis), *output_sequence
+  )
   return output_sequence, state
 
 
 def _swap_batch_time(inputs):
   """Swaps batch and time axes, assumed to be the first two axes."""
-  return jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 0, 1), inputs)
+  return jax.tree.map(lambda x: jnp.swapaxes(x, 0, 1), inputs)
 
 
 def dynamic_unroll(core,
@@ -220,7 +221,7 @@ def scan_f(prev_state, inputs):
 def add_batch(nest, batch_size: Optional[int]):
   """Adds a batch dimension at axis 0 to the leaves of a nested structure."""
   broadcast = lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape)
-  return jax.tree_util.tree_map(broadcast, nest)
+  return jax.tree.map(broadcast, nest)
 
 
 class VanillaRNN(RNNCore):
@@ -646,11 +647,10 @@ def __call__(self, inputs, state):
       Tuple of the wrapped core's ``output, next_state``.
     """
     inputs, should_reset = inputs
-    if jax.tree_util.treedef_is_leaf(
-        jax.tree_util.tree_structure(should_reset)):
+    if jax.tree_util.treedef_is_leaf(jax.tree.structure(should_reset)):
      # Equivalent to not tree.is_nested, but with support for Jax extensible
      # pytrees.
-      should_reset = jax.tree_util.tree_map(lambda _: should_reset, state)
+      should_reset = jax.tree.map(lambda _: should_reset, state)
 
     # We now need to manually pad 'on the right' to ensure broadcasting operates
     # correctly.
@@ -698,26 +698,25 @@ def __call__(self, inputs, state):
     # >> batch_entry 1:
     # >> [[0. 0.]
     # >> [0. 0.]]
-    should_reset = jax.tree_util.tree_map(
-        _validate_and_conform, should_reset, state)
+    should_reset = jax.tree.map(_validate_and_conform, should_reset, state)
     if self._is_batched(state):
-      batch_size = jax.tree_util.tree_leaves(inputs)[0].shape[0]
+      batch_size = jax.tree.leaves(inputs)[0].shape[0]
     else:
       batch_size = None
-    initial_state = jax.tree_util.tree_map(
-        lambda s, i: i.astype(s.dtype), state, self.initial_state(batch_size))
-    state = jax.tree_util.tree_map(
-        jnp.where, should_reset, initial_state, state)
+    initial_state = jax.tree.map(
+        lambda s, i: i.astype(s.dtype), state, self.initial_state(batch_size)
+    )
+    state = jax.tree.map(jnp.where, should_reset, initial_state, state)
     return self.core(inputs, state)
 
   def initial_state(self, batch_size: Optional[int]):
     return self.core.initial_state(batch_size)
 
   def _is_batched(self, state):
-    state = jax.tree_util.tree_leaves(state)
+    state = jax.tree.leaves(state)
     if not state:  # Empty state is treated as unbatched.
       return False
-    batched = jax.tree_util.tree_leaves(self.initial_state(batch_size=1))
+    batched = jax.tree.leaves(self.initial_state(batch_size=1))
     return all(b.shape[1:] == s.shape[1:] for b, s in zip(batched, state))
 
 
@@ -747,7 +746,7 @@ def __call__(self, inputs, state):
     state_idx = 0
     for idx, layer in enumerate(self.layers):
       if self.skip_connections and idx > 0:
-        current_inputs = jax.tree_util.tree_map(
+        current_inputs = jax.tree.map(
             lambda x, *args: (
                 None if x is None else jnp.concatenate((x,) + args, axis=-1)
             ),
@@ -765,8 +764,7 @@ def __call__(self, inputs, state):
       current_inputs = layer(current_inputs)
 
     if self.skip_connections:
-      out = jax.tree_util.tree_map(lambda *args: jnp.concatenate(args, axis=-1),
-                                   *outputs)
+      out = jax.tree.map(lambda *args: jnp.concatenate(args, axis=-1), *outputs)
     else:
       out = current_inputs
 
diff --git a/haiku/_src/recurrent_test.py b/haiku/_src/recurrent_test.py
index aa96a72ae..02e8508bd 100644
--- a/haiku/_src/recurrent_test.py
+++ b/haiku/_src/recurrent_test.py
@@ -140,11 +140,11 @@ class VanillaRNNTest(absltest.TestCase):
   def test_double_bias_length_parameters(self):
     double_bias = recurrent.VanillaRNN(1, double_bias=True)
     double_bias(jnp.zeros([1]), double_bias.initial_state(None))
-    double_bias_params = jax.tree_util.tree_leaves(double_bias.params_dict())
+    double_bias_params = jax.tree.leaves(double_bias.params_dict())
 
     vanilla = recurrent.VanillaRNN(1, double_bias=False)
     vanilla(jnp.zeros([1]), vanilla.initial_state(None))
-    vanilla_params = jax.tree_util.tree_leaves(vanilla.params_dict())
+    vanilla_params = jax.tree.leaves(vanilla.params_dict())
 
     self.assertLen(double_bias_params, len(vanilla_params) + 1)
 
@@ -202,7 +202,7 @@ def __call__(self, inputs, prev_state):
     return inputs, prev_state
 
   def initial_state(self, batch_size):
-    return jax.tree_util.tree_map(jnp.zeros_like, self._state)
+    return jax.tree.map(jnp.zeros_like, self._state)
 
 
 class _IncrementByOneCore(recurrent.RNNCore):
@@ -513,12 +513,16 @@ def test_batch_major(self, unroll):
     batch_major_outputs, batch_major_unroll_state_out = unroll(
         core, batch_major_inputs, initial_state, time_major=False)
 
-    jax.tree_util.tree_map(
+    jax.tree.map(
         np.testing.assert_array_equal,
-        time_major_unroll_state_out, batch_major_unroll_state_out)
-    jax.tree_util.tree_map(
+        time_major_unroll_state_out,
+        batch_major_unroll_state_out,
+    )
+    jax.tree.map(
         lambda x, y: np.testing.assert_array_equal(x, jnp.swapaxes(y, 0, 1)),
-        time_major_outputs, batch_major_outputs)
+        time_major_outputs,
+        batch_major_outputs,
+    )
 
 
 if __name__ == "__main__":
diff --git a/haiku/_src/rms_norm_test.py b/haiku/_src/rms_norm_test.py
index 0847d305d..3c3d1a53e 100644
--- a/haiku/_src/rms_norm_test.py
+++ b/haiku/_src/rms_norm_test.py
@@ -53,8 +53,7 @@ def f(x):
     fwd = transform.transform(f)
     data = jnp.zeros([2, 3, 4, 5], dtype=jnp.bfloat16)
     params = fwd.init(jax.random.PRNGKey(428), data)
-    bf16_params = jax.tree_util.tree_map(
-        lambda t: t.astype(jnp.bfloat16), params)
+    bf16_params = jax.tree.map(lambda t: t.astype(jnp.bfloat16), params)
     self.assertEqual(fwd.apply(bf16_params, None, data).dtype, jnp.bfloat16)
 
   @test_utils.transform_and_run
diff --git a/haiku/_src/stateful.py b/haiku/_src/stateful.py
index 681736c9d..f85c7ab20 100644
--- a/haiku/_src/stateful.py
+++ b/haiku/_src/stateful.py
@@ -32,7 +32,7 @@
 
 
 def copy_structure(bundle: T) -> T:
-  return jax.tree_util.tree_map(lambda x: x, bundle)
+  return jax.tree.map(lambda x: x, bundle)
 
 
 def internal_state(*, params=True) -> InternalState:
@@ -314,16 +314,18 @@ def if_changed(is_new, box_a, box_b):
   is_new_param = lambda a, b: a is not b
   params_before, params_after = box_and_fill_missing(before.params,
                                                      after.params)
-  params_after = jax.tree_util.tree_map(
-      functools.partial(if_changed, is_new_param), params_before, params_after)
+  params_after = jax.tree.map(
+      functools.partial(if_changed, is_new_param), params_before, params_after
+  )
 
   # state
   def is_new_state(a: base.StatePair, b: base.StatePair):
     return a.initial is not b.initial or a.current is not b.current
 
   state_before, state_after = box_and_fill_missing(before.state, after.state)
-  state_after = jax.tree_util.tree_map(
-      functools.partial(if_changed, is_new_state), state_before, state_after)
+  state_after = jax.tree.map(
+      functools.partial(if_changed, is_new_state), state_before, state_after
+  )
 
   # rng
   def is_new_rng(a: Optional[base.PRNGSequenceState],
@@ -590,7 +592,7 @@ def scan(f, init, xs, length=None, reverse=False, unroll=1):
                      "Use jax.lax.scan() instead.")
 
   if length is None:
-    length = jax.tree_util.tree_leaves(xs)[0].shape[0]
+    length = jax.tree.leaves(xs)[0].shape[0]
 
   running_init_fn = not base.params_frozen()
 
@@ -599,20 +601,19 @@ def scan(f, init, xs, length=None, reverse=False, unroll=1):
   # carry contains the Haiku state and during `init` this may change structure
   # (e.g. as state is created).
   if not length:
-    x0 = jax.tree_util.tree_map(lambda x: jnp.zeros(x.shape[1:], x.dtype), xs)
+    x0 = jax.tree.map(lambda x: jnp.zeros(x.shape[1:], x.dtype), xs)
     _, y0 = f(init, x0)
-    y0 = jax.tree_util.tree_map(
-        lambda y: jnp.zeros((0,) + y.shape, y.dtype), y0)
+    y0 = jax.tree.map(lambda y: jnp.zeros((0,) + y.shape, y.dtype), y0)
     return init, y0
 
   if reverse:
-    x0 = jax.tree_util.tree_map(lambda x: x[-1], xs)
-    xs = jax.tree_util.tree_map(lambda x: x[:-1], xs)
+    x0 = jax.tree.map(lambda x: x[-1], xs)
+    xs = jax.tree.map(lambda x: x[:-1], xs)
   else:
-    x0 = jax.tree_util.tree_map(lambda x: x[0], xs)
-    xs = jax.tree_util.tree_map(lambda x: x[1:], xs)
+    x0 = jax.tree.map(lambda x: x[0], xs)
+    xs = jax.tree.map(lambda x: x[1:], xs)
   init, y0 = f(init, x0)
-  y0 = jax.tree_util.tree_map(lambda y: jnp.expand_dims(y, 0), y0)
+  y0 = jax.tree.map(lambda y: jnp.expand_dims(y, 0), y0)
   length -= 1
   if not length:
     return init, y0
@@ -646,11 +647,9 @@ def stateful_fun(carry, x):
 
   if running_init_fn:
     if reverse:
-      ys = jax.tree_util.tree_map(
-          lambda y0, ys: jnp.concatenate([ys, y0]), y0, ys)
+      ys = jax.tree.map(lambda y0, ys: jnp.concatenate([ys, y0]), y0, ys)
     else:
-      ys = jax.tree_util.tree_map(
-          lambda y0, ys: jnp.concatenate([y0, ys]), y0, ys)
+      ys = jax.tree.map(lambda y0, ys: jnp.concatenate([y0, ys]), y0, ys)
 
   return carry, ys
 
@@ -702,7 +701,7 @@ def pure_body_fun(i, val):
 def maybe_get_axis(axis: Optional[int], arrays: Any) -> Optional[int]:
   """Returns `array.shape[axis]` for one of the arrays in the input."""
   if axis is None: return None
-  shapes = [a.shape for a in jax.tree_util.tree_leaves(arrays)]
+  shapes = [a.shape for a in jax.tree.leaves(arrays)]
   sizes = {s[axis] for s in shapes}
   if len(sizes) != 1:
     raise ValueError("Arrays must have the same mapped axis size, found "
@@ -715,9 +714,13 @@ def maybe_get_axis(axis: Optional[int], arrays: Any) -> Optional[int]:
 
 
 def get_mapped_axis_size(args: tuple[Any], in_axes: Any) -> int:
-  sizes = uniq(jax.tree_util.tree_leaves(
-      jax.tree_util.tree_map(maybe_get_axis, in_axes, args,
-                             is_leaf=lambda x: x is None)))
+  sizes = uniq(
+      jax.tree.leaves(
+          jax.tree.map(
+              maybe_get_axis, in_axes, args, is_leaf=lambda x: x is None
+          )
+      )
+  )
   assert sizes, "hk.vmap should guarantee non-empty in_axes"
   # NOTE: We use the first in_axes regardless of how many non-unique values
   # there are to allow JAX to handle multiple conflicting sizes.
@@ -803,7 +806,7 @@ def vmap(
 
     See :func:`jax.vmap`.
   """
-  if not jax.tree_util.tree_leaves(in_axes):
+  if not jax.tree.leaves(in_axes):
     raise ValueError(
         f"{fun.__name__} must have at least one non-None value in in_axes "
        "to use with `hk.vmap`.")
diff --git a/haiku/_src/stateful_test.py b/haiku/_src/stateful_test.py
index 2cd8595e6..c4e3f03c2 100644
--- a/haiku/_src/stateful_test.py
+++ b/haiku/_src/stateful_test.py
@@ -412,8 +412,7 @@ def test_switch_no_transform(self):
   def test_difference_empty(self):
     before = stateful.internal_state()
     after = stateful.internal_state()
-    self.assertEmpty(
-        jax.tree_util.tree_leaves(stateful.difference(before, after)))
+    self.assertEmpty(jax.tree.leaves(stateful.difference(before, after)))
 
   @parameterized.parameters(base.get_parameter, base.get_state)
   @test_utils.transform_and_run(run_apply=False)
@@ -600,7 +599,7 @@ def f(x):
 
     # State should not be mapped.
     self.assertEmpty(params)
-    cnt, = jax.tree_util.tree_leaves(state)
+    (cnt,) = jax.tree.leaves(state)
     self.assertEqual(cnt.ndim, 0)
     self.assertEqual(cnt, 0)
 
@@ -608,7 +607,7 @@ def f(x):
     y, state = f.apply(params, state, None, x)
     self.assertEqual(y.shape, (4,))
     np.testing.assert_allclose(y, x ** 2)
-    cnt, = jax.tree_util.tree_leaves(state)
+    (cnt,) = jax.tree.leaves(state)
     self.assertEqual(cnt.ndim, 0)
     self.assertEqual(cnt, 1)
 
diff --git a/haiku/_src/summarise.py b/haiku/_src/summarise.py
index 808ec66f5..266985a47 100644
--- a/haiku/_src/summarise.py
+++ b/haiku/_src/summarise.py
@@ -87,8 +87,8 @@ class ModuleDetails:
 
   @classmethod
   def of(cls, module: hk.Module, method_name: str) -> "ModuleDetails":
-    params = jax.tree_util.tree_map(ArraySpec.from_array, module.params_dict())
-    state = jax.tree_util.tree_map(ArraySpec.from_array, module.state_dict())
+    params = jax.tree.map(ArraySpec.from_array, module.params_dict())
+    state = jax.tree.map(ArraySpec.from_array, module.state_dict())
     return ModuleDetails(module=module, method_name=method_name, params=params,
                          state=state)
 
@@ -128,9 +128,9 @@ def get_call_stack() -> Sequence[ModuleDetails]:
 
 
 def to_spec(tree):
-  return jax.tree_util.tree_map(
-      lambda x: ArraySpec.from_array(x) if isinstance(x, jax.Array) else x,
-      tree)
+  return jax.tree.map(
+      lambda x: ArraySpec.from_array(x) if isinstance(x, jax.Array) else x, tree
+  )
 
 
 IGNORED_METHODS = ("__init__", "params_dict", "state_dict")
diff --git a/haiku/_src/transform_test.py b/haiku/_src/transform_test.py
index 383bb562f..6e1dff92f 100644
--- a/haiku/_src/transform_test.py
+++ b/haiku/_src/transform_test.py
@@ -346,7 +346,7 @@ def test_method(self):
     obj_out, y = obj.forward.apply(params, None, x)
     self.assertEqual(y, 1)
     self.assertIs(obj, obj_out)
-    params = jax.tree_util.tree_map(lambda v: v + 1, params)
+    params = jax.tree.map(lambda v: v + 1, params)
     obj_out, y = obj.forward.apply(params, None, x)
     self.assertEqual(y, 2)
     self.assertIs(obj, obj_out)
diff --git a/haiku/_src/utils.py b/haiku/_src/utils.py
index 02fcb51d5..6fbf9ad92 100644
--- a/haiku/_src/utils.py
+++ b/haiku/_src/utils.py
@@ -202,7 +202,7 @@ def tree_size(tree) -> int:
 
   And compare that with casting our parameters to bf16:
 
-  >>> params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
+  >>> params = jax.tree.map(lambda x: x.astype(jnp.bfloat16), params)
   >>> num_params = hk.data_structures.tree_size(params)
   >>> byte_size = hk.data_structures.tree_bytes(params)
   >>> print(f'{num_params} params, size: {byte_size / 1e6:.2f}MB')
@@ -214,7 +214,7 @@ def tree_size(tree) -> int:
   Returns:
     The total size (number of elements) of the array(s) in the input.
   """
-  return sum(x.size for x in jax.tree_util.tree_leaves(tree))
+  return sum(x.size for x in jax.tree.leaves(tree))
 
 
 def tree_bytes(tree) -> int:
@@ -240,7 +240,7 @@ def tree_bytes(tree) -> int:
 
   And compare that with casting our parameters to bf16:
 
-  >>> params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
+  >>> params = jax.tree.map(lambda x: x.astype(jnp.bfloat16), params)
   >>> num_params = hk.data_structures.tree_size(params)
   >>> byte_size = hk.data_structures.tree_bytes(params)
   >>> print(f'{num_params} params, size: {byte_size / 1e6:.2f}MB')
@@ -252,7 +252,7 @@ def tree_bytes(tree) -> int:
   Returns:
     The total size in bytes of the array(s) in the input.
   """
-  return sum(x.size * x.dtype.itemsize for x in jax.tree_util.tree_leaves(tree))
+  return sum(x.size * x.dtype.itemsize for x in jax.tree.leaves(tree))
 
 _CAMEL_TO_SNAKE_R = re.compile(r"((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))")
 camel_to_snake = lambda value: _CAMEL_TO_SNAKE_R.sub(r"_\1", value).lower()
diff --git a/haiku/benchmarks/init.py b/haiku/benchmarks/init.py
index 6c30e0ec0..c87fa4ca8 100644
--- a/haiku/benchmarks/init.py
+++ b/haiku/benchmarks/init.py
@@ -58,7 +58,7 @@ def run_bench(state):
     while state:
       params, _ = jitted_init(k, x)
      # block on computation to finish
-      jax.tree_util.tree_map(lambda x: x.block_until_ready(), params)
+      jax.tree.map(lambda x: x.block_until_ready(), params)
 
  return trace_bench, compile_bench, run_bench