Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

torch.compile ae.decode #25

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 23 additions & 20 deletions fp8/flux_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,15 +615,16 @@ def generate(
)

# prepare inputs
img, img_ids, vec, txt, txt_ids = map(
lambda x: x, # x.contiguous(),
self.prepare(
img=img,
prompt=prompt,
target_device=self.device_flux,
target_dtype=self.dtype,
),
)
with torch.profiler.record_function("prepare"):
img, img_ids, vec, txt, txt_ids = map(
lambda x: x, # x.contiguous(),
self.prepare(
img=img,
prompt=prompt,
target_device=self.device_flux,
target_dtype=self.dtype,
),
)

# dispatch to gpu if offloaded
if self.offload_flow:
Expand All @@ -634,16 +635,17 @@ def generate(
output_imgs = []

for i in range(batch_size):
denoised_img = self.denoise_single_item(
img[i],
img_ids[i],
txt[i],
txt_ids[i],
vec[i],
timesteps,
guidance,
compiling
)
with torch.profiler.record_function("denoise-single-item"):
denoised_img = self.denoise_single_item(
img[i],
img_ids[i],
txt[i],
txt_ids[i],
vec[i],
timesteps,
guidance,
compiling
)
output_imgs.append(denoised_img)
compiling = False

Expand All @@ -655,7 +657,8 @@ def generate(
torch.cuda.empty_cache()

# decode latents to pixel space
img = self.vae_decode(img, height, width)
with torch.profiler.record_function("vae-decode"):
img = self.vae_decode(img, height, width)

return self.as_img_tensor(img)

Expand Down
59 changes: 57 additions & 2 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Dict, Optional

import torch
import torch._dynamo as dynamo

torch.set_float32_matmul_precision("high")
torch.backends.cuda.matmul.allow_tf32 = True
Expand Down Expand Up @@ -116,6 +117,7 @@ def base_setup(
compile_fp8: bool = False,
compile_bf16: bool = False,
disable_fp8: bool = False,
compile_ae: bool = False,
) -> None:
self.flow_model_name = flow_model_name
print(f"Booting model {self.flow_model_name}")
Expand Down Expand Up @@ -179,9 +181,62 @@ def base_setup(
if compile_fp8:
self.compile_fp8()

if compile_ae:
self.compile_ae()

if compile_bf16:
self.compile_bf16()

@torch.inference_mode()
def compile_ae(self):
# helpful: export TORCH_COMPILE_DEBUG=1 TORCH_LOGS=dynamic,dynamo

# the order is important:
# torch.compile has to recompile if it makes invalid assumptions
# about the input sizes. Having higher input sizes first makes
# for fewer recompiles.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any way we can compile once with craftier use of dynamo.mark_dynamic - add a max=192 on dims 2 & 3? I assume you've tried this, curious how it breaks

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried max=192, but it didn't have any effect. Setting torch.compile(dynamic=True) makes for one fewer recompile, but I should check the runtime performance of that.

vae_sizes = [
[1, 16, 192, 168],
[1, 16, 96, 96],
[1, 16, 96, 168],
[1, 16, 128, 128],
[1, 16, 96, 168],
[1, 16, 80, 192],
[1, 16, 104, 152],
[1, 16, 152, 104],
[1, 16, 136, 112],
[1, 16, 112, 136],
[1, 16, 144, 112],
[1, 16, 112, 144],
[1, 16, 168, 96],
[1, 16, 192, 80],
[4, 16, 128, 128],
]
print("compiling AE")
st = time.time()
device = torch.device("cuda")
if self.offload:
self.ae.decoder.to(device)

self.ae.decoder = torch.compile(self.ae.decoder)

# actual compilation happens when you give it inputs
for f in vae_sizes:
print("Compiling AE for size", f)
x = torch.rand(f, device=device)
dynamo.mark_dynamic(x, 0, min=1, max=4)
dynamo.mark_dynamic(x, 2, min=80)
dynamo.mark_dynamic(x, 3, min=80)
with torch.autocast(
device_type=device.type, dtype=torch.bfloat16, cache_enabled=False
):
self.ae.decode(x)

if self.offload:
self.ae.decoder.cpu()
torch.cuda.empty_cache()
print("compiled AE in ", time.time() - st)

def compile_fp8(self):
print("compiling fp8 model")
st = time.time()
Expand Down Expand Up @@ -463,7 +518,7 @@ def run_falcon_safety_checker(self, image):

class SchnellPredictor(Predictor):
def setup(self) -> None:
self.base_setup("flux-schnell", compile_fp8=True)
self.base_setup("flux-schnell", compile_fp8=True, compile_ae=True)

def predict(
self,
Expand Down Expand Up @@ -513,7 +568,7 @@ def predict(

class DevPredictor(Predictor):
def setup(self) -> None:
self.base_setup("flux-dev", compile_fp8=True)
self.base_setup("flux-dev", compile_fp8=True, compile_ae=True)

def predict(
self,
Expand Down
Loading