How To Implement An RNN #51
I'm trying to implement an RNN using scan, but I'm not sure how to do scans correctly in Penzai. Here is my RNN cell and its associated RNN in Flax:

```python
class LSTMCell(Cell):
    features: int
    gate1_fn: Callable[..., Any] = nn.sigmoid
    gate2_fn: Callable[..., Any] = nn.tanh
    kernel_init: Initializer = nn.initializers.lecun_normal()
    bias_init: Initializer = nn.initializers.zeros_init()
    dtype: Optional[Dtype] = None
    param_dtype: Dtype = jnp.float32
    carry_init: Initializer = nn.initializers.zeros_init()

    def setup(self) -> None:
        self.dense_i = nn.Dense(
            features=4 * self.features,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            use_bias=True,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            name="i_layer",
        )
        self.dense_h = nn.Dense(
            features=4 * self.features,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            use_bias=True,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            name="h_layer",
        )

    def __call__(
        self,
        carry: Carry,
        x: Array,
    ):
        h_t, c_t = carry
        gates_i = self.dense_i(x)
        gates_h = self.dense_h(h_t)
        # get the gate outputs
        i_t, f_t, g_t, o_t = jnp.split(gates_i + gates_h, 4, axis=-1)
        i_t = self.gate1_fn(i_t)
        f_t = self.gate1_fn(f_t)
        o_t = self.gate1_fn(o_t)
        g_t = self.gate2_fn(g_t)
        c_t = f_t * c_t + i_t * g_t
        h_t = o_t * self.gate2_fn(c_t)
        return (h_t, c_t), h_t

    @nn.nowrap
    def initialize_carry(
        self, rng: PRNGKey, shape: Tuple[int, ...]
    ) -> Tuple[Array, Array]:
        key1, key2 = jr.split(rng)
        c = self.carry_init(key1, shape, self.param_dtype)
        h = self.carry_init(key2, shape, self.param_dtype)
        return (h, c)


class RNN(_RNN):
    features: int
    cell_args: Dict
    carry_shape: Tuple
    _cell: Cell = LSTMCell

    def setup(self) -> None:
        self.scan_cell = nn.scan(
            target=self._cell,
            in_axes=1,
            out_axes=1,
            variable_broadcast="params",
            split_rngs={"params": False},
        )(**self.cell_args)

    def __call__(
        self,
        x: Array,
    ) -> Tuple[Array, Array]:
        B, T, C = x.shape
        carry = self.scan_cell.initialize_carry(rng=jr.key(0), shape=self.carry_shape)
        carry, stacked = self.scan_cell(carry, x)
        return carry, stacked
```

I'm unsure how to do the same/similar thing in Penzai. I'm hoping I could get some help.
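For reference, the Flax version gets called roughly like this (a minimal sketch; the shapes are arbitrary illustrations):

```python
# Minimal usage sketch of the Flax RNN above; shapes are arbitrary examples.
model = RNN(
    features=32,
    cell_args={"features": 32},
    carry_shape=(8, 32),  # (batch, features)
)
x = jnp.ones((8, 16, 64))  # (batch, time, channels)
variables = model.init(jr.key(0), x)
(h_T, c_T), outputs = model.apply(variables, x)  # outputs: (batch, time, features)
```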
Cell implementation (attempt) in Penzai:

```python
@pz.pytree_dataclass
class LSTMCell(pz.nn.Layer):
    affine_i: pz.nn.Layer
    affine_h: pz.nn.Layer
    features: int
    gate1_fn: Callable[..., Any] = jax.nn.sigmoid
    gate2_fn: Callable[..., Any] = jax.nn.tanh

    @classmethod
    def from_config(
        cls,
        name: str,
        init_base_rng: jax.Array | None,
        features: int,
        input_size: int,
        gate1_fn: Callable[..., Any] = jax.nn.sigmoid,
        gate2_fn: Callable[..., Any] = jax.nn.tanh,
        kernel_init: Callable = jax.nn.initializers.lecun_normal(),
        bias_init: Callable = jax.nn.initializers.zeros,
        dtype: Any = jnp.float32,
    ) -> "LSTMCell":
        affine_i = pz.nn.Affine.from_config(
            name=f"{name}/affine_i",
            init_base_rng=init_base_rng,
            input_axes={"input": input_size},
            output_axes={"gates": 4 * features},
            linear_initializer=kernel_init,
            bias_initializer=bias_init,
            dtype=dtype,
        )
        affine_h = pz.nn.Affine.from_config(
            name=f"{name}/affine_h",
            init_base_rng=init_base_rng,
            input_axes={"hidden": features},
            output_axes={"gates": 4 * features},
            linear_initializer=kernel_init,
            bias_initializer=bias_init,
            dtype=dtype,
        )
        gate1 = pz.nn.Elementwise(gate1_fn)
        gate2 = pz.nn.Elementwise(gate2_fn)
        return cls(
            features=features,
            gate1_fn=gate1,
            gate2_fn=gate2,
            affine_i=affine_i,
            affine_h=affine_h,
        )

    def __call__(
        self,
        carry: tuple[pz.nx.NamedArray, pz.nx.NamedArray],
        x: pz.nx.NamedArray,
        **unused_side_inputs,
    ) -> tuple[tuple[pz.nx.NamedArray, pz.nx.NamedArray], pz.nx.NamedArray]:
        # I don't quite know how to tag and untag things correctly yet,
        # but it definitely needs to happen here.
        h_t, c_t = carry
        gates_i = self.affine_i(x)
        gates_h = self.affine_h(h_t)
        gates = gates_i + gates_h
        i_t, f_t, g_t, o_t = pz.nx.nmap(jnp.split)(gates, 4, axis=-1)
        i_t = pz.nx.nmap(self.gate1_fn)(i_t)
        f_t = pz.nx.nmap(self.gate1_fn)(f_t)
        o_t = pz.nx.nmap(self.gate1_fn)(o_t)
        g_t = pz.nx.nmap(self.gate2_fn)(g_t)
        c_t = f_t * c_t + i_t * g_t
        h_t = o_t * pz.nx.nmap(self.gate2_fn)(c_t)
        return (h_t, c_t), h_t

    def initialize_carry(
        self, rng: jax.Array, batch_shape: tuple[int, ...]
    ) -> tuple[pz.nx.NamedArray, pz.nx.NamedArray]:
        key1, key2 = jax.random.split(rng)
        shape = batch_shape + (self.features,)
        c = pz.nx.wrap(jax.nn.initializers.zeros(key1, shape))
        h = pz.nx.wrap(jax.nn.initializers.zeros(key2, shape))
        return h.tag("hidden"), c.tag("hidden")
```

This is definitely not correct, but I am pretty lost when it comes to implementing the scan part.
Thanks for the question! I haven't yet done much with RNNs in Penzai, so I'm not sure yet what the most idiomatic approach would be, but here are my first thoughts. An RNN cell seems different from most Penzai layers because it takes a carry in addition to its primary input and returns a new carry along with its output. So it might make sense for `LSTMCell` to be a different kind of `pz.Struct`, and not a subclass of `pz.nn.Layer` (because a `pz.nn.Layer` always takes one primary input, not two). Perhaps something like this:

```python
import abc
from typing import Any

from penzai.experimental.v2 import pz


class RNNCell(pz.Struct, abc.ABC):
  """Abstract base class for RNN cells."""

  @abc.abstractmethod
  def __call__(
      self, carry: Any, input_value: pz.nx.NamedArray, /, **side_inputs: Any
  ) -> tuple[Any, pz.nx.NamedArray]:
    """Abstract call method for an RNN cell.

    Args:
      carry: The carry input to the RNN cell.
      input_value: The value input (e.g. from previous layers) to the RNN cell.
      **side_inputs: Arbitrary side context available to the cell, forwarded
        from parent layers.

    Returns:
      Tuple (new_carry, output_value).
    """
    raise NotImplementedError(
        "__call__ must be overridden for RNNCell subclasses"
    )

  @abc.abstractmethod
  def initialize_carry(
      self, input_named_shape: dict[str, Any], **side_inputs: Any
  ) -> Any:
    """Initializes the carry input for the RNN cell.

    Args:
      input_named_shape: The named shape of the input to the cell.
      **side_inputs: Arbitrary side context available to the cell, forwarded
        from parent layers.

    Returns:
      Initial carry input.
    """
    raise NotImplementedError(
        "initialize_carry must be overridden for RNNCell subclasses"
    )
```

Then `LSTMCell` itself could look something like this:

```python
from typing import Callable, Any
import dataclasses
import functools

import jax
import jax.numpy as jnp

lecun_normal_initializer = functools.partial(
    pz.nn.variance_scaling_initializer,
    scale=1.0,
    mode="fan_in",
    distribution="normal",
)


@pz.pytree_dataclass
class LSTMCell(RNNCell):
  # Child layers (can have parameters)
  affine_i: pz.nn.Layer
  affine_h: pz.nn.Layer
  # Non-pytree attributes
  gate1_fn: Callable[..., Any] = dataclasses.field(metadata={"pytree_node": False})
  gate2_fn: Callable[..., Any] = dataclasses.field(metadata={"pytree_node": False})
  input_axes: dict[str, int] = dataclasses.field(metadata={"pytree_node": False})
  carry_axes: dict[str, int] = dataclasses.field(metadata={"pytree_node": False})

  @classmethod
  def from_config(
      cls,
      name: str,
      init_base_rng: jax.Array | None,
      input_axes: dict[str, int],
      carry_axes: dict[str, int],
      gate1_fn: Callable[..., Any] = jax.nn.sigmoid,
      gate2_fn: Callable[..., Any] = jax.nn.tanh,
      kernel_init: Callable = lecun_normal_initializer,
      bias_init: Callable = pz.nn.zero_initializer,
      dtype: Any = jnp.float32,
  ) -> "LSTMCell":
    affine_i = pz.nn.Affine.from_config(
        name=f"{name}/affine_i",
        init_base_rng=init_base_rng,
        input_axes=input_axes,
        output_axes={"gates": 4, **carry_axes},  # Use an explicit "gates" axis
        linear_initializer=kernel_init,
        bias_initializer=bias_init,
        dtype=dtype,
    )
    affine_h = pz.nn.Affine.from_config(
        name=f"{name}/affine_h",
        init_base_rng=init_base_rng,
        input_axes={**carry_axes},
        output_axes={"gates": 4, **carry_axes},
        linear_initializer=kernel_init,
        bias_initializer=bias_init,
        dtype=dtype,
    )
    return cls(
        affine_i=affine_i,
        affine_h=affine_h,
        # No need to wrap the gates in Elementwise if they are stored with
        # "pytree_node": False.
        gate1_fn=gate1_fn,
        gate2_fn=gate2_fn,
        input_axes=input_axes,
        carry_axes=carry_axes,
    )

  def initialize_carry(self, input_named_shape, **side_inputs):
    batch_shape = {
        name: size for name, size in input_named_shape.items()
        if name not in self.input_axes
    }
    carry_shape = {**batch_shape, **self.carry_axes}
    return (
        pz.nx.zeros(carry_shape), pz.nx.zeros(carry_shape)
    )

  def __call__(self, carry, input_value, /, **side_inputs):
    h_t, c_t = carry
    # Passing side inputs to the child layers is a good practice. In this case
    # it's probably not necessary since they are just affine layers, but maybe
    # they would eventually include dropout or some other more complex logic.
    gates_i = self.affine_i(input_value, **side_inputs)
    gates_h = self.affine_h(h_t, **side_inputs)
    gates = gates_i + gates_h
    i_t, f_t, g_t, o_t = gates.untag("gates")  # Split over the gates axis
    i_t = pz.nx.nmap(self.gate1_fn)(i_t)
    f_t = pz.nx.nmap(self.gate1_fn)(f_t)
    o_t = pz.nx.nmap(self.gate1_fn)(o_t)
    g_t = pz.nx.nmap(self.gate2_fn)(g_t)
    c_t = f_t * c_t + i_t * g_t
    h_t = o_t * pz.nx.nmap(self.gate2_fn)(c_t)
    return (h_t, c_t), h_t
```

A few comments about this:
Finally, the RNN scan layer could be something like this:

```python
@pz.pytree_dataclass
class ScanRNNCellOverAxis(pz.nn.Layer):
  """Scans an RNN cell over an axis."""

  cell: RNNCell
  axis: str = dataclasses.field(metadata={"pytree_node": False})

  def __call__(self, input_value: pz.nx.NamedArray, **side_inputs):
    # Freeze any state variables (and parameters) inside the cell.
    # A more fully-featured implementation would probably want to add
    # those state variables to the carry as well, but it's easier to assume
    # they are constant.
    # (With the V2 API, you can generally use JAX transformations freely
    # as long as you first either freeze or unbind any variable objects.)
    cell, side_inputs = pz.freeze_variables((self.cell, side_inputs))
    # Scan acts over the first axis, so move the axis to be scanned over
    # to the front.
    scan_input = input_value.untag(self.axis).with_positional_prefix()
    initial_carry = cell.initialize_carry(scan_input.named_shape, **side_inputs)
    # Use ordinary `jax.lax.scan`:
    _, scan_out = jax.lax.scan(
        lambda c, a: cell(c, a, **side_inputs),
        initial_carry,
        scan_input,
    )
    # Re-attach the scanned-over axis name.
    return scan_out.tag(self.axis)
```

This implementation could also probably be extended with more features, if needed.
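Wiring these pieces together would look roughly like this (an untested sketch; the axis names `"batch"`, `"seq"`, `"embedding"`, `"hidden"` and the sizes are just placeholders):

```python
# Untested sketch of how the cell and scan layer above might be combined;
# axis names and sizes are placeholders.
cell = LSTMCell.from_config(
    name="lstm",
    init_base_rng=jax.random.key(0),
    input_axes={"embedding": 16},
    carry_axes={"hidden": 32},
)
rnn = ScanRNNCellOverAxis(cell=cell, axis="seq")

x = pz.nx.wrap(jnp.ones((8, 20, 16))).tag("batch", "seq", "embedding")
outputs = rnn(x)  # expected named axes: {"batch": 8, "seq": 20, "hidden": 32}
```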