Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Added support for static_keynames which lets you specify key names to ignore for jax transformations #64
, all tests seem to be passing I've attached the test log (bash test.sh > test_out.txt)
test_out.txt
I'd love to have a cleaner way to get the static_keynames in the flatten func if anyone has ideas, this way there's no indication to users that static_keynames is something placed in there by chex, but calling it _static_keynames set off the tests for private access.
To demonstrate a use case with this I can dataclassify this:
but when I try to JIT the noise function it will crash because the name in the dataclass is a str which is not a valid jax type.
If I do this instead:
then it won't crash since it will make the name field static
now I can jit it just fine (but if I try to use a static field in a jax transformed function in a non-static way then it will crash)