I observed performance regressions for BERT and BART model inference with jax-v0.4.33 compared to jax-v0.4.31 on both x86 and arm64 CPU platforms. Inference was almost 2x slower. I have root-caused it to the following PR (and commit), which added an async dispatcher for expensive computations. My observation is that this async dispatcher spawns more threads (almost double the number of available vCPUs on an instance), which causes thread oversubscription on these platforms.
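One way to check the oversubscription observation is to compare the number of OS threads in the process with `os.cpu_count()` after a JAX computation has run. The following is a minimal Linux-only sketch. It assumes that each entry under `/proc/self/task` corresponds to one thread of the current process. The helper names `count_os_threads` and `threads_after_jax_work` are hypothetical and only used for this illustration.

```python
import os

import jax
import jax.numpy as jnp


def count_os_threads():
    """Number of OS-level threads in this process (Linux only)."""
    # Linux lists every thread of the current process under /proc/self/task.
    return len(os.listdir("/proc/self/task"))


def threads_after_jax_work(size=1024):
    """Run a jitted matmul on the default backend, then return (threads, vCPUs)."""
    x = jnp.ones((size, size), dtype=jnp.float32)
    # Wait for the result so any worker threads have actually been started.
    jax.block_until_ready(jax.jit(lambda a: a @ a)(x))
    return count_os_threads(), os.cpu_count()


if __name__ == "__main__":
    threads, vcpus = threads_after_jax_work()
    print(f"OS threads: {threads}, vCPUs: {vcpus}")
```

The count covers every native thread in the process, including XLA's own compute pools, not only the async dispatcher. It is therefore most informative when compared across two runs, with one of them calling `jax.config.update('jax_cpu_enable_async_dispatch', False)` at the top of the script.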
I tested on two AWS EC2 instances:
- c7i.xlarge (x86)
- c8g.xlarge (arm64)

Both showed a roughly 2x performance drop. I was able to restore the performance by disabling the async dispatcher with the following setting:
`jax.config.update('jax_cpu_enable_async_dispatch', False)`
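To compare the two settings on a given instance, one could time the same jitted computation in two separate processes: one with the default and one with async dispatch disabled. The sketch below assumes the flag must be set before the first computation runs, so it is taken from the command line at startup. It uses a hypothetical feed-forward layer (`mlp_block`, with BERT-base-like widths) rather than the actual bert/bart models. The `--no-async` argument is also made up for this script. The commit applies async dispatch only to computations it treats as expensive, so a small stand-in may not exercise the same path. Timing the real model the same way is the more reliable check.

```python
import sys
import time

import jax
import jax.numpy as jnp

# Assumption: the flag is read when the CPU backend is first created, so it
# is set here, before any JAX computation runs in this process.
# "--no-async" is a made-up argument for this script only.
ASYNC_DISPATCH = "--no-async" not in sys.argv
jax.config.update("jax_cpu_enable_async_dispatch", ASYNC_DISPATCH)


@jax.jit
def mlp_block(x, w1, w2):
    # Hypothetical stand-in for one transformer feed-forward layer,
    # not the bert/bart models from this report.
    return jnp.tanh(x @ w1) @ w2


def make_inputs(rows=256, hidden=768, ffn=3072, seed=0):
    """Random activations and weights with BERT-base-like widths."""
    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(seed), 3)
    x = jax.random.normal(k1, (rows, hidden), dtype=jnp.float32)
    w1 = jax.random.normal(k2, (hidden, ffn), dtype=jnp.float32) * 0.02
    w2 = jax.random.normal(k3, (ffn, hidden), dtype=jnp.float32) * 0.02
    return x, w1, w2


def benchmark(fn, *args, warmup=3, iters=20):
    """Mean wall-clock seconds per call of fn(*args)."""
    for _ in range(warmup):
        jax.block_until_ready(fn(*args))  # compile and warm up
    start = time.perf_counter()
    for _ in range(iters):
        # Wait for the result so the timing covers the computation itself,
        # not just the (possibly asynchronous) dispatch.
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / iters


if __name__ == "__main__":
    mean_s = benchmark(mlp_block, *make_inputs())
    print(f"async_dispatch={ASYNC_DISPATCH}: {mean_s * 1e3:.2f} ms/iter")
```

Run it once as `python bench.py` and once as `python bench.py --no-async`, where `bench.py` is a placeholder filename. The `jax.block_until_ready` calls matter here. With async dispatch on, a call can return before the work has finished, so timing without them would understate the cost.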
So my question is: on which platforms and in which scenarios did this async dispatcher show a performance improvement?
Also, please let me know if any other configuration has to go along with it.
Otherwise, I'm curious why it was enabled by default.
commit f255fb700af8a8d2455c42ea1863cb1420ea6da3 (HEAD)
Author: Yue Sheng <[email protected]>
Date: Mon Aug 5 17:47:34 2024 -0700
Async dispatch expensive computations on the JAX CPU backend. By setting `jax.config.update('jax_cpu_enable_async_dispatch', False)`, one could opt out of the change and recover the old behavior.
PR: #15740