fake_pmap_and_jit has a confusing interface #144

marcvanzee opened this issue Feb 11, 2022 · 0 comments

I spent quite some time figuring out why code in a large codebase was so slow, only to find out that jit was disabled throughout the entire project. This was because the main function was called as follows:

with chex.fake_pmap_and_jit(FLAGS.debug):
  main()

At first sight it appears that this disables both pmap and jit when the debug flag is set, but in fact it only disables pmap when the flag is set, and it always disables jit!

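The reason is that fake_pmap_and_jit takes two arguments, which can be passed positionally and which respectively disable pmap and jit; both are True by default. FLAGS.debug therefore binds only to the first argument, and the second one keeps its default of True. The names of these arguments are also somewhat cryptic to me: enable_pmap_patching and enable_jit_patching, where enabling the patching actually disables the corresponding JAX transformation.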
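The pitfall can be reproduced without JAX. The sketch below uses a hypothetical stand-in, fake_pmap_and_jit_standin, that copies only the parameter names, order and defaults quoted above and reports which transformations would be patched instead of patching anything. It is not chex's implementation.

```python
import contextlib
from typing import Iterator, NamedTuple


class Patched(NamedTuple):
    """Which transformations the (hypothetical) fake context would replace."""
    pmap: bool
    jit: bool


@contextlib.contextmanager
def fake_pmap_and_jit_standin(enable_pmap_patching: bool = True,
                              enable_jit_patching: bool = True) -> Iterator[Patched]:
    # Hypothetical stand-in: same parameter order and defaults as the
    # signature described in this issue, but it only reports the settings.
    yield Patched(pmap=enable_pmap_patching, jit=enable_jit_patching)


for debug in (False, True):
    # A single positional value binds to the *first* parameter only;
    # enable_jit_patching silently keeps its default of True.
    with fake_pmap_and_jit_standin(debug) as patched:
        print(f"debug={debug}: {patched}")
```

The loop prints `Patched(pmap=False, jit=True)` for debug=False and `Patched(pmap=True, jit=True)` for debug=True, so jit is patched out in both cases.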

Given these observations, I think the situation would improve if the signature were:

def fake_pmap_and_jit(*, disable_pmap: bool = True, disable_jit: bool = True)

My code above would then look like this:

with chex.fake_pmap_and_jit(disable_pmap=FLAGS.debug):
  main()

This makes it clear that we are not setting disable_jit, so we would rewrite it as:

with chex.fake_pmap_and_jit(disable_pmap=FLAGS.debug, disable_jit=FLAGS.debug):
  main()
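Until the signature changes, a user-side wrapper can already enforce the proposed keyword-only interface. This is a minimal sketch that assumes chex.fake_pmap_and_jit accepts the enable_pmap_patching and enable_jit_patching keywords named above; the wrapper name fake_pmap_and_jit_kw is hypothetical and not part of chex.

```python
def fake_pmap_and_jit_kw(*, disable_pmap: bool = True,
                         disable_jit: bool = True):
    """Keyword-only adapter over chex.fake_pmap_and_jit (hypothetical name).

    The bare `*` forces callers to name each flag, so a single positional
    value can no longer silently configure only pmap.
    """
    # Imported lazily; Python rejects a positional call before the body runs,
    # so the check below works even where chex is not installed.
    import chex

    # "Enable ... patching" means "replace the real transformation",
    # i.e. disable it, so the flags map one-to-one.
    return chex.fake_pmap_and_jit(enable_pmap_patching=disable_pmap,
                                  enable_jit_patching=disable_jit)


try:
    fake_pmap_and_jit_kw(True)  # the ambiguous call pattern from this issue
except TypeError as err:
    print("rejected:", err)
```

With this wrapper the call becomes `with fake_pmap_and_jit_kw(disable_pmap=FLAGS.debug, disable_jit=FLAGS.debug): main()`, and the positional form fails loudly with a TypeError instead of silently disabling jit.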