Single environment performance on GPU worse than on CPU #280
nico-bohlinger started this conversation in General
-
Hi @nico-bohlinger, I think the lower steps-per-second rate is somewhat expected for a single environment, but you should be able to batch the environments on the GPU and get a higher overall steps per second. You will want to ask on the gym repo how to batch the envs, but in Brax it would look something like:
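Here is a minimal pure-JAX sketch of that batching idea. It uses a hypothetical `toy_step` function in place of a real Brax environment, so it is not Brax's API. `jax.vmap` turns a single-environment step into one that advances a whole batch, and `jax.jit` compiles the batched version so each call is a single device launch.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a single-environment step (NOT a Brax env):
# the state is a vector and the action nudges it.
def toy_step(state, action):
    next_state = state + 0.1 * action
    reward = -jnp.sum(next_state ** 2)
    return next_state, reward

# vmap maps toy_step over a leading batch axis; jit compiles the batched
# function once, so one call advances every environment together.
batched_step = jax.jit(jax.vmap(toy_step))

num_envs, obs_dim = 2048, 8
states = jnp.zeros((num_envs, obs_dim))
actions = jnp.ones((num_envs, obs_dim))
states, rewards = batched_step(states, actions)
print(states.shape, rewards.shape)  # (2048, 8) (2048,)
```

With thousands of environments per call, the fixed per-call overhead is shared across the whole batch. That is likely why the GPU pays off in the batched case but not for a single environment.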
You likely want to keep the RL algorithm running on the same device as `env.step` to avoid data transfers between the host and the GPU/TPU device. The agents in https://github.com/google/brax/tree/main/brax/training/agents are a great reference.
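The following sketch illustrates keeping the algorithm and the environment step on one device. It fuses a policy and an environment step into a single jitted rollout with `jax.lax.scan`. Both `toy_step` and the linear `policy` are hypothetical placeholders, not code from Brax's agents. Because the whole 100-step loop is one compiled program, nothing goes back to the host between steps. For a single environment this may also cut per-step dispatch overhead, though that is an assumption worth measuring.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy environment step (a placeholder, not a Brax environment).
def toy_step(state, action):
    next_state = state + 0.1 * action
    reward = -jnp.sum(next_state ** 2)
    return next_state, reward

# Hypothetical linear policy; a real agent would use network parameters.
def policy(params, obs):
    return jnp.tanh(obs @ params)

@jax.jit
def rollout(params, init_state):
    # The scan body runs policy + env step; the whole loop is compiled into
    # one program, so there are no host round trips between steps.
    def body(state, _):
        action = policy(params, state)
        next_state, reward = toy_step(state, action)
        return next_state, reward

    final_state, rewards = jax.lax.scan(body, init_state, xs=None, length=100)
    return final_state, rewards

obs_dim = 8
params = jnp.eye(obs_dim)
final_state, rewards = rollout(params, jnp.ones(obs_dim))
print(rewards.shape)  # (100,)
```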
-
I want to use a single (!) Ant environment for my project and measured the steps per second with this script:
Using my CPU for jitting, I get around 2800 steps per second, which is better than what I would get with a MuJoCo Ant environment. But if I comment out the line that puts JAX in CPU mode and use my GPU, I only get 700 steps per second.
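Here is one way such a steps-per-second measurement could be written, as a sketch under assumptions. `toy_step` is a hypothetical stand-in for the environment. The commented-out `jax.config.update("jax_platform_name", "cpu")` line is assumed to play the role of the "CPU mode" line mentioned above. Two details matter for fair numbers:

- The warm-up call keeps compilation time out of the measurement.
- `block_until_ready()` waits for JAX's asynchronous dispatch to finish before the clock stops.

```python
import time

import jax
import jax.numpy as jnp

# Uncomment to force JAX onto the CPU; it must run before any computation.
# jax.config.update("jax_platform_name", "cpu")

# Hypothetical single-environment step used only for timing.
def toy_step(state, action):
    next_state = state + 0.1 * action
    reward = -jnp.sum(next_state ** 2)
    return next_state, reward

step = jax.jit(toy_step)

def steps_per_second(num_steps=1000, obs_dim=8):
    state = jnp.zeros(obs_dim)
    action = jnp.ones(obs_dim)
    # Warm-up call so compilation is not included in the timing.
    state, _ = step(state, action)
    state.block_until_ready()
    start = time.perf_counter()
    for _ in range(num_steps):
        state, _ = step(state, action)
    # JAX dispatches asynchronously; wait for the last result before timing.
    state.block_until_ready()
    return num_steps / (time.perf_counter() - start)

print(f"{jax.default_backend()}: {steps_per_second():.0f} steps/s")
```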
Jitting on the GPU is amazing when there are many environments, but it doesn't seem to help much with only one. On the other hand, the GPU is far better for the RL algorithm I want to use later.
Is there a way to improve single-environment performance on the GPU?
Or is there a way to efficiently jit the environment on the CPU while still using the GPU for my RL algorithm in JAX?
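Here is a sketch of the split setup asked about here, using hypothetical `toy_step` and `policy` functions. Arrays committed to a device with `jax.device_put` make jitted functions run on that device ("computation follows data"). That lets the environment stay on the CPU while the policy runs on the default accelerator. The accelerator is `jax.devices()[0]`, which is simply the CPU again on a machine without a GPU.

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]
# Default device: a GPU if JAX sees one, otherwise the CPU as well.
accel = jax.devices()[0]

# Hypothetical environment step and policy (placeholders, not Brax code).
def toy_step(state, action):
    next_state = state + 0.1 * action
    reward = -jnp.sum(next_state ** 2)
    return next_state, reward

def policy(params, obs):
    return jnp.tanh(obs @ params)

env_step = jax.jit(toy_step)
act = jax.jit(policy)

obs_dim = 8
# Committing inputs with device_put decides where each jitted call runs.
state = jax.device_put(jnp.zeros(obs_dim), cpu)
params = jax.device_put(jnp.eye(obs_dim), accel)

for _ in range(10):
    # Copy the observation to the accelerator and run the policy there,
    # then copy the action back so the env step runs on the CPU.
    action = act(params, jax.device_put(state, accel))
    state, reward = env_step(state, jax.device_put(action, cpu))
```

Each step now pays for two small host-device copies. Whether this beats running everything on the GPU depends on your hardware, so timing both variants is the safest way to decide.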