Using Flax `sow` inside a function with scanning #3727
TalKachman asked this question in Q&A
-
You'll need to list `'intermediates'` in the `variable_axes` argument of `nn.scan`. For example:

```python
import flax.linen as nn
import jax
import jax.numpy as jnp


class Block(nn.Module):
    width: int

    @nn.compact
    def __call__(self, carry, unused_inputs):
        carry = nn.Dense(self.width)(carry)
        carry = nn.relu(carry)
        self.sow('intermediates', 'carry', carry)
        return carry, None


class MLP(nn.Module):
    width: int
    depth: int

    @nn.compact
    def __call__(self, x):
        self.sow('intermediates', 'x', x)
        carry, unused_outputs = nn.scan(
            Block,
            # If 'intermediates' is not listed in `variable_axes` below, then
            # `self.sow('intermediates', ...)` will not work inside the `nn.scan()`.
            variable_axes={'params': 0, 'intermediates': 0},
            split_rngs={'params': True},
            length=self.depth,
        )(
            width=self.width,
        )(
            x, None,
        )
        return carry


model = MLP(width=2, depth=3)
x = jnp.zeros([1, 2])
variables = model.init(jax.random.PRNGKey(0), x)
out, state = model.apply(variables, x, mutable=['intermediates'])
# Print the shape of every array in the output, the sown state, and the parameters.
print(jax.tree.map(lambda x: x.shape, dict(out=out, state=state, variables=variables)))
```

This prints the shape of every array in `out`, `state` (the sown intermediates) and `variables`. The values sown inside the scan are stacked along a leading axis of length `depth`.
-
Hi Everyone!
Is there a way to use the `sow` method inside a function that is scanned? Ideally, I would like to save some metadata and be able to inspect the intermediate states.
Here is a small example of what I want to try:
This gives an empty state. How, for example, can I access the output computed within the scan?
Many thanks!