Get index in flax linen scan #3135
Hello, is there a way to access the index of the current iteration inside a scan that is not traced? I know I can iterate over several arrays that share the same leading dimension and pass `jnp.arange(loop_length)` as one of them. That gives me the index, but a traced one, so I cannot use it for things like indexing. Thanks for the help!
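The approach described in the question can be sketched with plain `jax.lax.scan`. Flax's `nn.scan` is a lifted version of `jax.lax.scan`, so the index behaves the same way there. This is a minimal sketch with a hypothetical step function and data. The index arrives as a traced array, not a Python int:

```python
import jax
import jax.numpy as jnp

def step(carry, inputs):
    # Each step receives one slice of every scanned array: the iteration
    # index i (from jnp.arange) and the matching element x.
    i, x = inputs
    # i is a traced int32 scalar: usable in arithmetic, but not a Python int.
    contribution = i * x
    return carry + contribution, contribution

xs = jnp.array([1.0, 2.0, 3.0, 4.0])  # hypothetical data
total, per_step = jax.lax.scan(
    step, jnp.zeros((), dtype=xs.dtype), (jnp.arange(xs.shape[0]), xs)
)
print(total)     # 0*1 + 1*2 + 2*3 + 3*4
print(per_step)  # per-step contributions i * x
```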
An untraced iteration count would essentially mean running different Python code on each loop iteration, which is not possible. You mention that you want to index based on the iteration count. That should be possible, but you may have run into issues when slicing with the index, e.g. `x[i:i + 3]`.
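A minimal reproduction of this kind of failure, assuming a hypothetical 1-D array `x` and a window of length 3, using plain `jax.lax.scan` (the function `nn.scan` lifts). The exact exception type may differ between JAX versions, so the sketch catches it broadly:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(10.0)  # hypothetical data to slice

def step(carry, i):
    # i comes from the scanned jnp.arange, so it is a tracer; i + 3 is too,
    # and JAX cannot work out at trace time that the slice has length 3.
    window = x[i:i + 3]
    return carry + window.sum(), None

try:
    jax.lax.scan(step, jnp.zeros(()), jnp.arange(5))
    failed = False
except Exception as err:  # typically an IndexError; the type may vary by version
    failed = True
    print(f"{type(err).__name__}: {err}")
```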
JAX needs to be able to infer the size of your slices statically. In this case `i + 3` is a tracer, and JAX is not clever enough to realize that the size is always 3. Instead you can do the following:
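One standard way to take a fixed-size window at a traced position is `jax.lax.dynamic_slice`, which accepts the start index as a (possibly traced) value and the slice size as a static Python int. A minimal sketch under that assumption, with a hypothetical array `x` and `window_size`:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(10.0)  # hypothetical data
window_size = 3       # static Python int: the slice length JAX needs at trace time
num_windows = x.shape[0] - window_size + 1

def step(carry, i):
    # The start index i may be a tracer; only the slice size must be static.
    window = jax.lax.dynamic_slice(x, (i,), (window_size,))
    return carry, window.sum()

# No carry is needed here, so None is passed as an empty pytree.
_, window_sums = jax.lax.scan(step, None, jnp.arange(num_windows))
print(window_sums)
```

`jax.lax.dynamic_slice` clamps the start index so the slice stays in bounds, so `num_windows` is chosen here so that no window gets clamped.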
I think this way you can write your loop even if `i` is a tracer.
I think you have a few options:
Which option is fastest depends on the hardware and on the specific ops and shapes you are using. On GPUs, the first and last options will almost certainly be the fastest.