Commit 79b8f1f

Use short aliases for common tree operations in Haiku.
PiperOrigin-RevId: 686515246
tomhennigan authored and copybara-github committed Oct 17, 2024
1 parent 69f6eed commit 79b8f1f
Showing 51 changed files with 270 additions and 277 deletions.
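The change is mechanical: `jax.tree.map`, `jax.tree.leaves`, `jax.tree.flatten`, and `jax.tree.unflatten` are shorter aliases for the corresponding `jax.tree_util` functions in recent JAX releases. A minimal sketch of the equivalence (illustrative only; the toy `params` dict below is not from this repository):

```python
import jax
import jax.numpy as jnp

# Toy pytree, purely for illustration.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}

old_style = jax.tree_util.tree_map(lambda p: p * 2, params)  # long form
new_style = jax.tree.map(lambda p: p * 2, params)            # short alias

# Both produce the same pytree.
assert all(jax.tree.leaves(
    jax.tree.map(lambda a, b: bool((a == b).all()), old_style, new_style)))
```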
10 changes: 5 additions & 5 deletions README.md
@@ -154,7 +154,7 @@ def update_rule(param, update):

for images, labels in input_dataset:
grads = jax.grad(loss_fn_t.apply)(params, images, labels)
params = jax.tree_util.tree_map(update_rule, params, grads)
params = jax.tree.map(update_rule, params, grads)
```

The core of Haiku is `hk.transform`. The `transform` function allows you to
@@ -218,9 +218,9 @@ grads = jax.grad(loss_fn_t.apply)(params, images, labels)
### Training

The training loop in this example is very simple. One detail to note is the use
of `jax.tree_util.tree_map` to apply the `sgd` function across all matching
entries in `params` and `grads`. The result has the same structure as the
previous `params` and can again be used with `apply`.
of `jax.tree.map` to apply the `sgd` function across all matching entries in
`params` and `grads`. The result has the same structure as the previous `params`
and can again be used with `apply`.


## Installation<a id="installation"></a>
@@ -412,7 +412,7 @@ params = loss_fn_t.init(rng, sample_image, sample_label)

# Replicate params onto all devices.
num_devices = jax.local_device_count()
params = jax.tree_util.tree_map(lambda x: np.stack([x] * num_devices), params)
params = jax.tree.map(lambda x: np.stack([x] * num_devices), params)

def make_superbatch():
"""Constructs a superbatch, i.e. one batch of data per device."""
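As context for the README hunks above, a hypothetical end-to-end sketch of the pattern they describe: `jax.tree.map` applies the update rule across matching entries of `params` and `grads` and returns a tree with the same structure as `params`. The toy model and data below are assumptions, not code from the repository:

```python
import jax
import jax.numpy as jnp

def update_rule(param, update):
  return param - 0.01 * update  # plain SGD, as in the README

# Stand-ins for a real Haiku params tree and its gradients.
params = {"linear": {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}}
grads = jax.tree.map(jnp.ones_like, params)

params = jax.tree.map(update_rule, params, grads)
# `params` keeps the same structure, so it can be passed to `apply` again.
```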
2 changes: 1 addition & 1 deletion docs/notebooks/basics.ipynb
@@ -339,7 +339,7 @@
}
],
"source": [
"mutated_params = jax.tree_util.tree_map(lambda x: x+1., params)\n",
"mutated_params = jax.tree.map(lambda x: x+1., params)\n",
"print(f'Mutated params \\n : {mutated_params}')\n",
"mutated_output = forward_without_rng.apply(x=sample_x, params=mutated_params)\n",
"print(f'Output with mutated params \\n {mutated_output}')"
4 changes: 2 additions & 2 deletions docs/notebooks/build_your_own_haiku.ipynb
@@ -365,7 +365,7 @@
"source": [
"# Make printing parameters a little more readable\n",
"def parameter_shapes(params):\n",
" return jax.tree_util.tree_map(lambda p: p.shape, params)\n",
" return jax.tree.map(lambda p: p.shape, params)\n",
"\n",
"\n",
"class Linear:\n",
@@ -763,7 +763,7 @@
"@jax.jit\n",
"def update(params, x, y):\n",
" grads = jax.grad(loss_fn)(params, x, y)\n",
" return jax.tree_util.tree_map(\n",
" return jax.tree.map(\n",
" lambda p, g: p - LEARNING_RATE * g, params, grads\n",
" )"
]
2 changes: 1 addition & 1 deletion docs/notebooks/non_trainable.ipynb
@@ -268,7 +268,7 @@
" return -jnp.sum(labels * jax.nn.log_softmax(logits)) / labels.shape[0]\n",
"\n",
"def sgd_step(params, grads, *, lr):\n",
" return jax.tree_util.tree_map(lambda p, g: p - g * lr, params, grads)\n",
" return jax.tree.map(lambda p, g: p - g * lr, params, grads)\n",
"\n",
"def train_step(trainable_params, non_trainable_params, x, y):\n",
" # NOTE: We will only compute gradients wrt `trainable_params`.\n",
6 changes: 3 additions & 3 deletions docs/notebooks/parameter_sharing.ipynb
@@ -66,7 +66,7 @@
"\n",
"def parameter_shapes(params):\n",
" \"\"\"Make printing parameters a little more readable.\"\"\"\n",
" return jax.tree_util.tree_map(lambda p: p.shape, params)\n",
" return jax.tree.map(lambda p: p.shape, params)\n",
"\n",
"\n",
"def transform_and_print_shapes(fn, x_shape=(2, 3)):\n",
@@ -80,8 +80,8 @@
" print(parameter_shapes(params))\n",
"\n",
"def assert_all_equal(params_1, params_2):\n",
" assert all(jax.tree_util.tree_leaves(\n",
" jax.tree_util.tree_map(lambda a, b: (a == b).all(), params_1, params_2)))"
" assert all(jax.tree.leaves(\n",
" jax.tree.map(lambda a, b: (a == b).all(), params_1, params_2)))"
]
},
{
2 changes: 1 addition & 1 deletion docs/notebooks/transforms.ipynb
@@ -236,7 +236,7 @@
"init, apply = hk.transform(eval_shape_net) \n",
"params = init(jax.random.PRNGKey(777), jnp.ones((100, 100)))\n",
"apply(params, jax.random.PRNGKey(777), jnp.ones((100, 100)))\n",
"jax.tree_util.tree_map(lambda x: x.shape, params)"
"jax.tree.map(lambda x: x.shape, params)"
]
},
{
9 changes: 5 additions & 4 deletions examples/imagenet/dataset.py
@@ -154,11 +154,12 @@ def cast_fn(batch):


def _device_put_sharded(sharded_tree, devices):
leaves, treedef = jax.tree_util.tree_flatten(sharded_tree)
leaves, treedef = jax.tree.flatten(sharded_tree)
n = leaves[0].shape[0]
return jax.device_put_sharded([
jax.tree_util.tree_unflatten(treedef, [l[i] for l in leaves])
for i in range(n)], devices)
return jax.device_put_sharded(
[jax.tree.unflatten(treedef, [l[i] for l in leaves]) for i in range(n)],
devices,
)


def double_buffer(ds: Iterable[Batch]) -> Iterator[Batch]:
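A rough sketch of what the reworked `_device_put_sharded` helper does, under the assumption that every leaf carries a leading device axis (the helper name and toy batch below are illustrative):

```python
import jax
import numpy as np

def device_put_sharded_tree(sharded_tree, devices):
  """Splits a tree with a leading device axis into one shard per device."""
  leaves, treedef = jax.tree.flatten(sharded_tree)
  n = leaves[0].shape[0]  # leading axis = number of devices
  shards = [jax.tree.unflatten(treedef, [leaf[i] for leaf in leaves])
            for i in range(n)]
  return jax.device_put_sharded(shards, devices)

devices = jax.local_devices()
batch = {"images": np.zeros((len(devices), 8, 32, 32, 3), np.float32)}
on_device = device_put_sharded_tree(batch, devices)
```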
7 changes: 4 additions & 3 deletions examples/imagenet/train.py
@@ -227,7 +227,7 @@ def evaluate(

# Params/state are sharded per-device during training. We just need the copy
# from the first device (since we do not pmap evaluation at the moment).
params, state = jax.tree_util.tree_map(lambda x: x[0], (params, state))
params, state = jax.tree.map(lambda x: x[0], (params, state))
test_dataset = dataset.load(split,
is_training=False,
batch_dims=[FLAGS.eval_batch_size],
@@ -323,8 +323,9 @@ def main(argv):

# Log progress at fixed intervals.
if step_num and step_num % log_every == 0:
train_scalars = jax.tree_util.tree_map(
lambda v: np.mean(v).item(), jax.device_get(train_scalars))
train_scalars = jax.tree.map(
lambda v: np.mean(v).item(), jax.device_get(train_scalars)
)
logging.info('[Train %s/%s] %s',
step_num, num_train_steps, train_scalars)

2 changes: 1 addition & 1 deletion examples/impala/actor.py
@@ -84,7 +84,7 @@ def unroll(self, rng_key, frame_count: int, params: hk.Params,

# Pack the trajectory and reset parent state.
trajectory = jax.device_get(self._traj)
trajectory = jax.tree_util.tree_map(lambda *xs: np.stack(xs), *trajectory)
trajectory = jax.tree.map(lambda *xs: np.stack(xs), *trajectory)
self._timestep = timestep
self._agent_state = agent_state
# Keep the bootstrap timestep for next trajectory.
18 changes: 8 additions & 10 deletions examples/impala/agent.py
@@ -61,11 +61,11 @@ def __init__(self, num_actions: int, obs_spec: Nest,
@functools.partial(jax.jit, static_argnums=0)
def initial_params(self, rng_key):
"""Initializes the agent params given the RNG key."""
dummy_inputs = jax.tree_util.tree_map(
lambda t: np.zeros(t.shape, t.dtype), self._obs_spec)
dummy_inputs = jax.tree.map(
lambda t: np.zeros(t.shape, t.dtype), self._obs_spec
)
dummy_inputs = util.preprocess_step(dm_env.restart(dummy_inputs))
dummy_inputs = jax.tree_util.tree_map(
lambda t: t[None, None, ...], dummy_inputs)
dummy_inputs = jax.tree.map(lambda t: t[None, None, ...], dummy_inputs)
return self._init_fn(rng_key, dummy_inputs, self.initial_state(1))

@functools.partial(jax.jit, static_argnums=(0, 1))
@@ -84,15 +84,13 @@ def step(
) -> tuple[AgentOutput, Nest]:
"""For a given single-step, unbatched timestep, output the chosen action."""
# Pad timestep, state to be [T, B, ...] and [B, ...] respectively.
timestep = jax.tree_util.tree_map(lambda t: t[None, None, ...], timestep)
state = jax.tree_util.tree_map(lambda t: t[None, ...], state)
timestep = jax.tree.map(lambda t: t[None, None, ...], timestep)
state = jax.tree.map(lambda t: t[None, ...], state)

net_out, next_state = self._apply_fn(params, timestep, state)
# Remove the padding from above.
net_out = jax.tree_util.tree_map(
lambda t: jnp.squeeze(t, axis=(0, 1)), net_out)
next_state = jax.tree_util.tree_map(
lambda t: jnp.squeeze(t, axis=0), next_state)
net_out = jax.tree.map(lambda t: jnp.squeeze(t, axis=(0, 1)), net_out)
next_state = jax.tree.map(lambda t: jnp.squeeze(t, axis=0), next_state)
# Sample an action and return.
action = hk.multinomial(rng_key, net_out.policy_logits, num_samples=1)
action = jnp.squeeze(action, axis=-1)
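The `agent.py` hunks above repeatedly pad unbatched inputs to `[T, B, ...]` and strip the padding from the outputs. A small sketch of that pattern (the toy timestep dict is an assumption, not the real observation spec):

```python
import jax
import jax.numpy as jnp

timestep = {"observation": jnp.zeros((84, 84, 4)), "reward": jnp.array(0.0)}

# Pad to [T=1, B=1, ...] before calling a time- and batch-major network.
batched = jax.tree.map(lambda t: t[None, None, ...], timestep)

# ... apply the network here ...

# Remove the padding from the (equally shaped) outputs again.
unbatched = jax.tree.map(lambda t: jnp.squeeze(t, axis=(0, 1)), batched)
```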
2 changes: 1 addition & 1 deletion examples/impala/haiku_nets.py
@@ -144,7 +144,7 @@ def initial_state(self, batch_size):
return self._core.initial_state(batch_size)

def __call__(self, x: dm_env.TimeStep, state):
x = jax.tree_util.tree_map(lambda t: t[None, ...], x)
x = jax.tree.map(lambda t: t[None, ...], x)
return self.unroll(x, state)

def unroll(self, x, state):
12 changes: 5 additions & 7 deletions examples/impala/learner.py
@@ -93,18 +93,17 @@ def _loss(
trajectories: util.Transition,
) -> tuple[jax.Array, dict[str, jax.Array]]:
"""Compute vtrace-based actor-critic loss."""
initial_state = jax.tree_util.tree_map(
lambda t: t[0], trajectories.agent_state)
initial_state = jax.tree.map(lambda t: t[0], trajectories.agent_state)
learner_outputs = self._agent.unroll(theta, trajectories.timestep,
initial_state)
v_t = learner_outputs.values[1:]
# Remove bootstrap timestep from non-timesteps.
_, actor_out, _ = jax.tree_util.tree_map(lambda t: t[:-1], trajectories)
learner_outputs = jax.tree_util.tree_map(lambda t: t[:-1], learner_outputs)
_, actor_out, _ = jax.tree.map(lambda t: t[:-1], trajectories)
learner_outputs = jax.tree.map(lambda t: t[:-1], learner_outputs)
v_tm1 = learner_outputs.values

# Get the discount, reward, step_type from the *next* timestep.
timestep = jax.tree_util.tree_map(lambda t: t[1:], trajectories.timestep)
timestep = jax.tree.map(lambda t: t[1:], trajectories.timestep)
discounts = timestep.discount * self._discount_factor
rewards = timestep.reward
if self._max_abs_reward > 0:
@@ -179,8 +178,7 @@ def host_to_device_worker(self):

assert len(batch) == self._batch_size
# Prepare for consumption, then put batch onto device.
stacked_batch = jax.tree_util.tree_map(
lambda *xs: np.stack(xs, axis=1), *batch)
stacked_batch = jax.tree.map(lambda *xs: np.stack(xs, axis=1), *batch)
self._device_q.put(jax.device_put(stacked_batch))

# Clean out the built-up batch.
16 changes: 7 additions & 9 deletions examples/impala_lite.py
@@ -80,7 +80,7 @@ def step(
timestep: dm_env.TimeStep,
) -> tuple[jax.Array, jax.Array]:
"""Steps on a single observation."""
timestep = jax.tree_util.tree_map(lambda t: jnp.expand_dims(t, 0), timestep)
timestep = jax.tree.map(lambda t: jnp.expand_dims(t, 0), timestep)
logits, _ = self._net(params, timestep)
logits = jnp.squeeze(logits, axis=0)
action = hk.multinomial(rng, logits, num_samples=1)
@@ -102,13 +102,12 @@ def loss(self, params: hk.Params, trajs: Transition) -> jax.Array:
baseline_tp1 = baseline_with_bootstrap[1:]

# Remove bootstrap timestep from non-observations.
_, actions, behavior_logits = jax.tree_util.tree_map(
lambda t: t[:-1], trajs)
_, actions, behavior_logits = jax.tree.map(lambda t: t[:-1], trajs)
learner_logits = learner_logits[:-1]

# Shift step_type/reward/discount back by one, so that actions match the
# timesteps caused by the action.
timestep = jax.tree_util.tree_map(lambda t: t[1:], trajs.timestep)
timestep = jax.tree.map(lambda t: t[1:], trajs.timestep)
discount = timestep.discount * self._discount
# The step is uninteresting if we transitioned LAST -> FIRST.
mask = jnp.not_equal(timestep.step_type, int(dm_env.StepType.FIRST))
@@ -149,7 +148,7 @@ def preprocess_step(ts: dm_env.TimeStep) -> dm_env.TimeStep:
ts = ts._replace(reward=0.)
if ts.discount is None:
ts = ts._replace(discount=1.)
return jax.tree_util.tree_map(np.asarray, ts)
return jax.tree.map(np.asarray, ts)


def run_actor(
@@ -180,7 +179,7 @@ def run_actor(
state.reward)

# Stack and send the trajectory.
stacked_traj = jax.tree_util.tree_map(lambda *ts: np.stack(ts), *traj)
stacked_traj = jax.tree.map(lambda *ts: np.stack(ts), *traj)
enqueue_traj(stacked_traj)
# Reset the trajectory, keeping the last timestep.
traj = traj[-1:]
@@ -222,8 +221,7 @@ def run(*, trajectories_per_actor, num_actors, unroll_len):
# Initialize the optimizer state.
sample_ts = env.reset()
sample_ts = preprocess_step(sample_ts)
ts_with_batch = jax.tree_util.tree_map(
lambda t: np.expand_dims(t, 0), sample_ts)
ts_with_batch = jax.tree.map(lambda t: np.expand_dims(t, 0), sample_ts)
params = jax.jit(net.init)(jax.random.PRNGKey(428), ts_with_batch)
opt_state = opt.init(params)

@@ -236,7 +234,7 @@ def dequeue():
batch = []
for _ in range(batch_size):
batch.append(q.get())
batch = jax.tree_util.tree_map(lambda *ts: np.stack(ts, axis=1), *batch)
batch = jax.tree.map(lambda *ts: np.stack(ts, axis=1), *batch)
return jax.device_put(batch)

# Start the actors.
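Both `actor.py` and `impala_lite.py` stack a list of per-step pytrees into a single trajectory with a leading time axis. A minimal sketch, using a made-up transition dict:

```python
import jax
import numpy as np

traj = [{"obs": np.zeros(4), "reward": np.float32(r)} for r in range(3)]

# Star-unpack the list so jax.tree.map sees one tree per step and stacks
# matching leaves along a new leading (time) axis.
stacked = jax.tree.map(lambda *ts: np.stack(ts), *traj)
# stacked["obs"].shape == (3, 4), stacked["reward"].shape == (3,)
```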
3 changes: 2 additions & 1 deletion examples/mnist.py
@@ -80,7 +80,8 @@ def loss(params: hk.Params, batch: Batch) -> jax.Array:
labels = jax.nn.one_hot(batch.label, NUM_CLASSES)

l2_regulariser = 0.5 * sum(
jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(params))
jnp.sum(jnp.square(p)) for p in jax.tree.leaves(params)
)
log_likelihood = jnp.sum(labels * jax.nn.log_softmax(logits))

return -log_likelihood / batch_size + 1e-4 * l2_regulariser
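The `mnist.py` change uses `jax.tree.leaves` to iterate over every parameter array when computing the L2 penalty. A self-contained sketch with a stand-in params tree:

```python
import jax
import jax.numpy as jnp

params = {"linear": {"w": jnp.ones((4, 2)), "b": jnp.zeros((2,))}}

l2_regulariser = 0.5 * sum(
    jnp.sum(jnp.square(p)) for p in jax.tree.leaves(params))
```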
2 changes: 1 addition & 1 deletion examples/mnist_gan.ipynb
@@ -203,7 +203,7 @@
"\n",
"\n",
"def tree_shape(xs):\n",
" return jax.tree_util.tree_map(lambda x: x.shape, xs)\n",
" return jax.tree.map(lambda x: x.shape, xs)\n",
"\n",
"\n",
"def sparse_softmax_cross_entropy(logits, labels):\n",
15 changes: 6 additions & 9 deletions examples/mnist_pruning.py
@@ -137,8 +137,7 @@ def apply_mask(params: hk.Params, masks: Sequence[hk.Params],
module_sparsity, params)
pruned_params = []
for value, mask in zip(params_to_prune, masks):
pruned_params.append(
jax.tree_util.tree_map(lambda x, y: x * y, value, mask))
pruned_params.append(jax.tree.map(lambda x, y: x * y, value, mask))
params = hk.data_structures.merge(*pruned_params, params_no_prune)
return params

@@ -155,18 +154,17 @@ def map_fn(x: jax.Array, sparsity: float) -> jax.Array:

for tree, sparsity in zip(params_to_prune, sparsities):
map_fn_sparsity = functools.partial(map_fn, sparsity=sparsity)
mask = jax.tree_util.tree_map(map_fn_sparsity, tree)
mask = jax.tree.map(map_fn_sparsity, tree)
masks.append(mask)
return masks


@jax.jit
def get_sparsity(params: hk.Params):
"""Calculate the total sparsity and tensor-wise sparsity of params."""
total_params = sum(jnp.size(x) for x in jax.tree_util.tree_leaves(params))
total_nnz = sum(jnp.sum(x != 0.) for x in jax.tree_util.tree_leaves(params))
leaf_sparsity = jax.tree_util.tree_map(
lambda x: jnp.sum(x == 0) / jnp.size(x), params)
total_params = sum(jnp.size(x) for x in jax.tree.leaves(params))
total_nnz = sum(jnp.sum(x != 0.0) for x in jax.tree.leaves(params))
leaf_sparsity = jax.tree.map(lambda x: jnp.sum(x == 0) / jnp.size(x), params)
return total_params, total_nnz, leaf_sparsity


@@ -221,8 +219,7 @@ def loss(params: hk.Params, batch: Batch) -> jax.Array:
logits = net.apply(params, batch)
labels = jax.nn.one_hot(batch["label"], 10)

l2_loss = 0.5 * sum(
jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(params))
l2_loss = 0.5 * sum(jnp.sum(jnp.square(p)) for p in jax.tree.leaves(params))
softmax_xent = -jnp.sum(labels * jax.nn.log_softmax(logits))
softmax_xent /= labels.shape[0]

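The `get_sparsity` hunk combines `jax.tree.leaves` for global counts with `jax.tree.map` for per-leaf statistics. A small sketch on a toy tree (not the pruning example's real params):

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.array([[0.0, 1.0], [2.0, 0.0]]), "b": jnp.zeros((2,))}

total_params = sum(jnp.size(x) for x in jax.tree.leaves(params))
total_nnz = sum(jnp.sum(x != 0.0) for x in jax.tree.leaves(params))
leaf_sparsity = jax.tree.map(lambda x: jnp.sum(x == 0) / jnp.size(x), params)
```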
8 changes: 4 additions & 4 deletions examples/vqvae_example.ipynb
@@ -77,7 +77,7 @@
"source": [
"cifar10 = tfds.as_numpy(tfds.load(\"cifar10\", split=\"train+test\", batch_size=-1))\n",
"del cifar10[\"id\"], cifar10[\"label\"]\n",
"jax.tree_util.tree_map(lambda x: f'{x.dtype.name}{list(x.shape)}', cifar10)"
"jax.tree.map(lambda x: f'{x.dtype.name}{list(x.shape)}', cifar10)"
]
},
{
@@ -101,9 +101,9 @@
},
"outputs": [],
"source": [
"train_data_dict = jax.tree_util.tree_map(lambda x: x[:40000], cifar10)\n",
"valid_data_dict = jax.tree_util.tree_map(lambda x: x[40000:50000], cifar10)\n",
"test_data_dict = jax.tree_util.tree_map(lambda x: x[50000:], cifar10)"
"train_data_dict = jax.tree.map(lambda x: x[:40000], cifar10)\n",
"valid_data_dict = jax.tree.map(lambda x: x[40000:50000], cifar10)\n",
"test_data_dict = jax.tree.map(lambda x: x[50000:], cifar10)"
]
},
{