Is there any way to extract the carry values of every iteration when using nn.scan? For example, if I pass an entire sequence of values to the RNN, scan will iterate over that array and concatenate the output of the GRU cell at every entry of the sequence. Is there a way to also get the carry values in a similar fashion?
Currently, I'm simply passing one entry of the sequence at a time and returning the carry at every step. However, this approach is slow.
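For reference, here is a minimal sketch of the stacking behaviour described above, written against plain `jax.lax.scan` rather than `nn.scan`. The `step` function is a made-up toy recurrence (not a GRU), and all names in it are illustrative assumptions. The point is only that whatever the step function returns as its second value is stacked along the scanned axis, so the carry can be returned there as well:

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    # Toy recurrence standing in for an RNN cell (assumption, not a GRU).
    new_carry = jnp.tanh(0.5 * carry + x)
    out = 2.0 * new_carry
    # The second return value is stacked over time by scan,
    # so the carry is emitted there alongside the regular output.
    return new_carry, (new_carry, out)


xs = jnp.linspace(-1.0, 1.0, 12).reshape(4, 3)  # (time, features)
init = jnp.zeros(3)

final_carry, (carries, outs) = jax.lax.scan(step, init, xs)
# carries[t] is the carry after step t; carries[-1] equals final_carry.
```

This keeps the whole sequence in a single scan instead of calling the cell once per time step from Python.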
I have the following RNN using a GRU cell:
import jax
import flax.linen as nn


class SimpleGRU(nn.Module):
    """A simple unidirectional RNN."""

    hidden_size: int
    out_dim: int

    @nn.compact
    def __call__(self, carry, x, inspect=False):
        # Scan the GRU cell over axis 1 (time) of a batch-major input.
        cell = nn.scan(
            nn.GRUCell,
            variable_broadcast='params',
            in_axes=1,
            out_axes=1,
            split_rngs={'params': False},
        )
        new_carry, cell_out = cell(self.hidden_size)(carry, x)
        dense_out = nn.Dense(self.out_dim, use_bias=True)(cell_out)
        if inspect:
            return new_carry, dense_out
        return dense_out

    def initialize_carry(self, input_shape):
        # Use fixed random key since default state init fn is just zeros.
        return nn.GRUCell(self.hidden_size, parent=None).initialize_carry(
            jax.random.key(0), input_shape
        )
And this is my current solution:
Any help would be greatly appreciated.