Thank you for the great work! 👍
I installed the environment following the instructions and am attempting to generate 2D compressible NS data. With `numbers=1` in `2D_multi_Rand.yaml`, I run `sh run_trainset_2D.sh` with the following content:

```sh
nn=1
key=2031
while [ $nn -le 1 ]; do
    python3 CFD_multi_Hydra.py +args=2D_Multi_Rand.yaml ++args.init_key=$key
    nn=$(expr $nn + 1)
    key=$(expr $key + 1)
    echo "$nn"
    echo "$key"
done
```
It seems that the CPU version of JAX is used by default, so I removed `CUDA_VISIBLE_DEVICES='0,1,2,3'`.
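As a sanity check on which backend gets picked up, this is the kind of standard JAX introspection I rely on (nothing here is specific to this repository):

```python
import jax

# Report the backend JAX selected and the devices it can see;
# pmap maps over jax.local_device_count() local devices.
print(jax.default_backend())     # e.g. "cpu"
print(jax.devices())             # e.g. [CpuDevice(id=0)]
print(jax.local_device_count())  # number of local devices
```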
The program completes successfully, with the `evolve` function taking 0.71 s and the total runtime being 278 s. Increasing nx and ny from 128 to 256 makes the program run for about 80 minutes (~20x the time for the smaller grid).

I have traced the latency to the return of the `pm_evolve` call, where t, DDD, VVx, VVy, VVz, PPP can be computed in under 1 s. Since `pm_evolve` is a jax-wrapped version of `evolve`, the problem must come from `pmap` or `vmap`.

I've confirmed that the backend is CPU. Do you have any advice on optimizing this, or on understanding why the return step takes so long?
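For reference, this is roughly how I would separate the dispatch of the call from the actual compute, since JAX dispatches work asynchronously and the cost can surface only when the outputs are materialized. This is just a sketch: the arguments to `pm_evolve` are elided, and I'm assuming its outputs are JAX arrays so that `block_until_ready()` applies.

```python
import time

# Sketch: time the (possibly asynchronous) dispatch of pm_evolve separately
# from the time spent waiting for its outputs to actually be computed.
start = time.time()
t, DDD, VVx, VVy, VVz, PPP = pm_evolve(...)  # arguments elided
print(f"dispatch: {time.time() - start:.2f} s")

start = time.time()
for out in (DDD, VVx, VVy, VVz, PPP):
    out.block_until_ready()  # blocks until the computation has finished
print(f"compute:  {time.time() - start:.2f} s")
```

If the second phase dominates, the slow "return" would really be the computation (or the host transfer) finishing, rather than the `pmap`/`vmap` wrapping itself.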
Thank you for your time and assistance!