
Super slow on Mac MPS #16

Open
ran-weii opened this issue Nov 12, 2024 · 2 comments

@ran-weii

Hi, a follow-up on #15: I compared CPU vs MPS and compile vs no compile on HalfCheetah for 100k steps using SAC. The results show that MPS is significantly slower than CPU, and the aot_eager backend makes compile slower (much more so on CPU), though the default inductor backend makes compile quite a bit faster on CPU but doesn't work for MPS.

[Screenshot: wall-clock comparison of CPU vs MPS, with and without torch.compile]

The code change is the following:

if args.compile:
    mode = None  # "reduce-overhead" if not args.cudagraphs else None
    # inductor does not support MPS, so fall back to the aot_eager backend there
    backend = "aot_eager" if device == torch.device("mps") else "inductor"
    update_main = torch.compile(update_main, mode=mode, backend=backend)
    update_pol = torch.compile(update_pol, mode=mode, backend=backend)
    policy = torch.compile(policy, mode=mode, backend=backend)
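
For reference, a minimal sketch of how the per-configuration speed could be measured in isolation (the `measure_sps` helper and its arguments are hypothetical, not part of the leanrl script; the numbers in the screenshot come from the full training run):

import time
import torch

def measure_sps(update_fn, batch, n_steps=1_000, device="cpu"):
    """Rough steps-per-second for a single update function.

    `update_fn` and `batch` are hypothetical stand-ins for the (possibly
    compiled) SAC update and a sampled replay batch.
    """
    # Warm up so compilation/tracing time is not counted.
    for _ in range(3):
        update_fn(batch)
    if device == "mps":
        torch.mps.synchronize()  # MPS kernels launch asynchronously
    start = time.perf_counter()
    for _ in range(n_steps):
        update_fn(batch)
    if device == "mps":
        torch.mps.synchronize()
    return n_steps / (time.perf_counter() - start)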
@vmoens
Contributor

vmoens commented Nov 14, 2024

I'm looking into this. There aren't any more graph breaks, but even collecting data is slower on MPS.
I ran this

python -m cProfile -o prof.prof leanrl/sac_continuous_action_torchcompile.py --compile --learning_starts=100 --total_timesteps=5000
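
One way to inspect the resulting prof.prof dump (a sketch using the standard-library pstats module; the sort key and row count are just examples):

import pstats

# Load the profile written by `python -m cProfile -o prof.prof ...`
stats = pstats.Stats("prof.prof")
# Sort by cumulative time and print the 20 most expensive entries;
# this is where calls like torch.randint and the torch.fx helpers surface.
stats.sort_stats("cumulative").print_stats(20)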

There are several interesting things in the profile:

  • 13% of runtime is spent in torch.randint
    [profile screenshot]
    This benchmark is also funny to look at (a reproduction sketch follows the list):
<torch.utils.benchmark.utils.common.Measurement object at 0x1147fde10>
torch.randint(1_000_000, (50,), device='cpu')
  Median: 1.87 us
  IQR:    0.08 us (1.83 to 1.92)
  5144 measurements, 1 runs per measurement, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x10022abc0>
torch.randint(1_000_000, (50,), device='cpu').to('mps')
  Median: 596.19 us
  IQR:    57.96 us (578.77 to 636.73)
  18 measurements, 1 runs per measurement, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x1147fee90>
torch.randint(1_000_000, (50,), device='mps')
  Median: 20.42 us
  IQR:    1.67 us (20.04 to 21.71)
  449 measurements, 1 runs per measurement, 1 thread

I raised an issue about this: pytorch/pytorch#140706

  • Another big chunk of time is spent in torch.fx-related functions, which is also quite weird:
    [profile screenshot]
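
The Measurement blocks quoted above are torch.utils.benchmark output; a minimal sketch that reproduces the same three torch.randint timings (on a machine with an MPS device):

import torch
from torch.utils import benchmark

# The three cases measured above: sampling on CPU, sampling on CPU then
# copying to MPS, and sampling on MPS directly.
stmts = [
    "torch.randint(1_000_000, (50,), device='cpu')",
    "torch.randint(1_000_000, (50,), device='cpu').to('mps')",
    "torch.randint(1_000_000, (50,), device='mps')",
]
for stmt in stmts:
    t = benchmark.Timer(stmt=stmt, globals={"torch": torch})
    print(t.blocked_autorange())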

I'll keep you posted, but working with an MPS backend may not be a suitable option for the time being!

@ran-weii
Author

Thanks!
