diff --git a/examples/impala/agent.py b/examples/impala/agent.py index 98a590f7e..a406fa7ab 100644 --- a/examples/impala/agent.py +++ b/examples/impala/agent.py @@ -56,7 +56,7 @@ def __init__(self, num_actions: int, obs_spec: Nest, lambda batch_size: net_factory().initial_state(batch_size))) self._init_fn, self._apply_fn = hk.without_apply_rng( - hk.transform(lambda obs, state: net_factory().unroll(obs, state))) + hk.transform(lambda obs, state: net_factory().unroll(obs, state))) # pytype: disable=attribute-error @functools.partial(jax.jit, static_argnums=0) def initial_params(self, rng_key):