-
Notifications
You must be signed in to change notification settings - Fork 31
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
torch.compile ae.decode #25
base: main
Are you sure you want to change the base?
Conversation
461db42
to
99cecf1
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is great! you can push to an internal H100 model (just don't leave it running 😄) on Replicate to test perf in prod, good to have solid metrics on that before we merge
predict.py
Outdated
@@ -166,12 +167,65 @@ def base_setup( | |||
shared_models=shared_models, | |||
) | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit - since these flags are just simple little flags we set setup
for dev/schnell predictor, I don't mind adding a separate compile_ae
flag
# the order is important: | ||
# torch.compile has to recompile if it makes invalid assumptions | ||
# about the input sizes. Having higher input sizes first makes | ||
# for fewer recompiles. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
any way we can compile once with craftier use of dynamo.mark_dynamic
- add a max=192
on dims 2 & 3? I assume you've tried this, curious how it breaks
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried max=192
, but it didn't have any effect. Setting torch.compile(dynamic=True)
makes for one fewer recompile, but I should check the runtime performance of that.
Did some H100 benchmarks. flux-schnell 1 image, VAE not compiled
flux-schnell 4 images, VAE not compiled
flux-schnell 4 images, VAE compiled
The VAE speed seems reproducible, where the uncompiled VAE spends a lot of time in nchwToNhwcKernel while the compiled version manages to avoid it. At the same time, I had a cog bug saying |
99cecf1
to
0039a42
Compare
Did you figure out what the |
@jonluca as I understand it, it was a regression in cog and should be fixed when building with 0.9.25 and later. |
It takes about 80 seconds on my machine to compile this. Makes the encoding step about 50% faster on A5000 (0.3 -> 0.2s), haven't tried H100.