I spent quite some time figuring out why code in a large codebase was so slow, only to find out that `jit` was disabled throughout the entire project. This was because the `main` function was called as follows:
```python
import chex

with chex.fake_pmap_and_jit(FLAGS.debug):
  main()
```
At first sight, this appears to disable both `pmap` and `jit` when the `debug` flag is set. In fact, it disables `pmap` only when `debug` is set, and it *always* disables `jit`!
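The reason is that `fake_pmap_and_jit` takes two arguments, which can be passed positionally and which respectively disable `pmap` and `jit`. Both are `True` by default, so passing `FLAGS.debug` as the only positional argument sets just the first one, while the second keeps its default of `True`. The names of these arguments are also somewhat cryptic to me: `enable_pmap_patching` and `enable_jit_patching` actually *disable* these JAX transformations.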
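To make the argument binding concrete, here is a small pure-Python stand-in. `fake_pmap_and_jit_stub` is hypothetical: it only copies the parameter names and `True` defaults described above, and instead of patching JAX it reports which transformations would be faked for each value of the flag.

```python
# Hypothetical stand-in: copies only the parameter list of
# chex.fake_pmap_and_jit and reports which transformations would be faked,
# without importing chex or patching JAX.
def fake_pmap_and_jit_stub(enable_pmap_patching: bool = True,
                           enable_jit_patching: bool = True) -> dict:
  return {"pmap_faked": enable_pmap_patching,
          "jit_faked": enable_jit_patching}


for debug in (False, True):
  # Mirrors `chex.fake_pmap_and_jit(FLAGS.debug)`: the flag binds to the
  # first parameter only, so enable_jit_patching keeps its default of True.
  print(f"debug={debug}:", fake_pmap_and_jit_stub(debug))
# debug=False: {'pmap_faked': False, 'jit_faked': True}
# debug=True: {'pmap_faked': True, 'jit_faked': True}

# Passing the flag explicitly to both parameters leaves jit alone when
# debug is off.
print(fake_pmap_and_jit_stub(enable_pmap_patching=False,
                             enable_jit_patching=False))
# {'pmap_faked': False, 'jit_faked': False}
```

With the current signature, the intended behaviour should therefore be obtainable by passing the flag to both parameters, e.g. `chex.fake_pmap_and_jit(enable_pmap_patching=FLAGS.debug, enable_jit_patching=FLAGS.debug)`.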
Given these observations, I think the situation would improve if the signature were:
My code above would then look like this:
This clearly shows that we are not setting `disable_jit`, so we would rewrite it to:
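A minimal sketch of the proposal, assuming keyword-only parameters named `disable_pmap` and `disable_jit` that both default to `False`. Only `disable_jit` is named in this issue; `disable_pmap`, the defaults, and the stub name `fake_pmap_and_jit_proposed` are hypothetical, and the body only reports what would be faked rather than patching JAX.

```python
# Hypothetical sketch of the proposed signature: keyword-only flags whose
# names say what they do. `disable_jit` is the name used in the issue;
# `disable_pmap` and the False defaults are assumptions.
def fake_pmap_and_jit_proposed(*, disable_pmap: bool = False,
                               disable_jit: bool = False) -> dict:
  # Stand-in body: report what would be faked instead of patching JAX.
  return {"pmap_faked": disable_pmap, "jit_faked": disable_jit}


debug = True  # stands in for FLAGS.debug

# The original positional call is now rejected instead of silently
# binding to the first flag.
try:
  fake_pmap_and_jit_proposed(debug)
except TypeError as err:
  print("rejected:", err)

# Naming only disable_pmap makes it visible that jit is left untouched.
print(fake_pmap_and_jit_proposed(disable_pmap=debug))
# {'pmap_faked': True, 'jit_faked': False}

# The rewritten call disables both transformations only in debug mode.
print(fake_pmap_and_jit_proposed(disable_pmap=debug, disable_jit=debug))
# {'pmap_faked': True, 'jit_faked': True}
```

The bare `*` is what turns a call like `fake_pmap_and_jit(FLAGS.debug)` into a `TypeError` instead of a silent mismatch, and `disable_*` names read the same way as the effect they have.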