Add export of Flux-dev transformer and uploading to Azure (#717)
There is a slight difference between the Flux schnell and dev variants:
dev has a guidance layer and schnell does not.

Fixed the element types of some tensor arguments: they were always passed
as f32, while some of them should use the model's dtype.

Slightly refactored the Flux transformer export boilerplate.

Added a script that uploads models to Azure. Right now it uploads the
Flux transformer models; this could become part of the CI jobs at some point.
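The upload script itself is not among the file diffs rendered below, but the new dev dependencies (azure-identity and azure-storage-blob) suggest roughly the following shape. This is a minimal sketch, not the committed script; the function name, account URL, and container are illustrative assumptions.

# Minimal sketch of an Azure blob upload using azure-identity and azure-storage-blob.
# The function name, account URL, and container are illustrative assumptions.
from pathlib import Path

from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient


def upload_model_file(local_path: Path, account_url: str, container: str, blob_name: str):
    # DefaultAzureCredential resolves environment or managed-identity credentials,
    # which is what would make this usable from a CI job later on.
    service = BlobServiceClient(account_url=account_url, credential=DefaultAzureCredential())
    blob = service.get_blob_client(container=container, blob=blob_name)
    with open(local_path, "rb") as f:
        blob.upload_blob(f, overwrite=True)


# Example (placeholder account and container):
# upload_model_file(Path("model.irpa"), "https://<account>.blob.core.windows.net", "models", "flux/model.irpa")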
sogartar authored Jan 2, 2025
1 parent a1e632a commit 56f3d21
Showing 10 changed files with 580 additions and 47 deletions.
1 change: 1 addition & 0 deletions sharktank/pyproject.toml
@@ -34,6 +34,7 @@ sharktank = ["py.typed", "kernels/templates/*.mlir"]
file = ["requirements.txt"]

[tool.setuptools.dynamic.optional-dependencies]
dev = {file = ["requirements-dev.txt"]}
testing = {file = ["requirements-tests.txt"]}

[tool.pytest.ini_options]
4 changes: 4 additions & 0 deletions sharktank/requirements-dev.txt
@@ -0,0 +1,4 @@
# Dependencies only required during development.

azure-identity>=1.19
azure-storage-blob>=12.24
66 changes: 65 additions & 1 deletion sharktank/sharktank/models/flux/export.py
@@ -5,14 +5,17 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from os import PathLike
import os
from pathlib import Path
import torch

from ...export import export_static_model_mlir
from ...tools.import_hf_dataset import import_hf_dataset
from .flux import FluxModelV1, FluxParams
from ...types import Dataset
from ...utils.hf_datasets import get_dataset

-flux_transformer_default_batch_sizes = [4]
+flux_transformer_default_batch_sizes = [1]


def export_flux_transformer_model_mlir(
@@ -23,6 +26,31 @@ def export_flux_transformer_model_mlir(
export_static_model_mlir(model, output_path=output_path, batch_sizes=batch_sizes)


def export_flux_transformer_iree_parameters(
model: FluxModelV1, parameters_output_path: PathLike
):
model.theta.rename_tensors_to_paths()
# TODO: export properties
dataset = Dataset(root_theta=model.theta, properties={})
dataset.save(parameters_output_path)


def export_flux_transformer(
model: FluxModelV1,
mlir_output_path: PathLike,
parameters_output_path: PathLike,
batch_sizes: list[int] = flux_transformer_default_batch_sizes,
):
export_flux_transformer_iree_parameters(model, parameters_output_path)

dataset = Dataset.load(parameters_output_path)
model_with_frozen_theta = FluxModelV1(theta=dataset.root_theta, params=model.params)
model_with_frozen_theta.theta = dataset.root_theta
export_flux_transformer_model_mlir(
model_with_frozen_theta, output_path=mlir_output_path, batch_sizes=batch_sizes
)


def export_flux_transformer_from_hugging_face(
repo_id: str,
mlir_output_path: PathLike,
@@ -47,3 +75,39 @@ def export_flux_transformer_from_hugging_face(
export_flux_transformer_model_mlir(
model, output_path=mlir_output_path, batch_sizes=batch_sizes
)


def export_flux_transformer_models(dir: Path):
from .testing import export_dev_random_single_layer

base_dir = dir / "flux" / "transformer"
os.makedirs(base_dir)

file_name_base = "black-forest-labs--FLUX.1-dev--black-forest-labs-transformer-bf16"
mlir_path = base_dir / f"{file_name_base}.mlir"
parameters_output_path = base_dir / f"{file_name_base}.irpa"
export_flux_transformer_from_hugging_face(
"black-forest-labs/FLUX.1-dev/black-forest-labs-transformer",
mlir_output_path=mlir_path,
parameters_output_path=parameters_output_path,
)

file_name_base = (
"black-forest-labs--FLUX.1-schnell--black-forest-labs-transformer-bf16"
)
mlir_path = base_dir / f"{file_name_base}.mlir"
parameters_output_path = base_dir / f"{file_name_base}.irpa"
export_flux_transformer_from_hugging_face(
"black-forest-labs/FLUX.1-schnell/black-forest-labs-transformer",
mlir_output_path=mlir_path,
parameters_output_path=parameters_output_path,
)

file_name_base = "black-forest-labs--FLUX.1-dev--transformer-single-layer-b16"
mlir_path = base_dir / f"{file_name_base}.mlir"
parameters_output_path = base_dir / f"{file_name_base}.irpa"
export_dev_random_single_layer(
dtype=torch.bfloat16,
mlir_output_path=mlir_path,
parameters_output_path=parameters_output_path,
)
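As a rough usage sketch (the output directory is only an example), the new entry point writes an .mlir/.irpa pair per variant under <dir>/flux/transformer:

from pathlib import Path

from sharktank.models.flux.export import export_flux_transformer_models

# Exports the dev and schnell bf16 transformers plus a single-layer dev variant,
# producing e.g. <dir>/flux/transformer/<name>.mlir and <name>.irpa.
export_flux_transformer_models(Path("/tmp/sharktank-models"))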
27 changes: 21 additions & 6 deletions sharktank/sharktank/models/flux/flux.py
@@ -11,6 +11,7 @@

from typing import Any, Optional
from collections import OrderedDict
from copy import copy
import math
from dataclasses import dataclass
import torch
@@ -96,6 +97,7 @@ def __init__(self, theta: Theta, params: FluxParams):
theta,
)

self.params = copy(params)
self.in_channels = params.in_channels
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
@@ -146,6 +148,8 @@ def __init__(self, theta: Theta, params: FluxParams):
LastLayer(theta("final_layer")),
)

self.dtype = self._deduce_dtype()

def forward(
self,
img: AnyTensor,
@@ -193,12 +197,12 @@ def sample_inputs(
raise ValueError(f'Only function "forward" is supported. Got "{function}"')

# TODO: do not hardcode these but derive the required shapes from the config.
-img = torch.rand([batch_size, 1024, 64])
-img_ids = torch.rand([batch_size, 1024, 3])
-txt = torch.rand([batch_size, 512, 4096])
-txt_ids = torch.rand([batch_size, 512, 3])
-timesteps = torch.rand([batch_size])
-y = torch.rand([batch_size, 768])
+img = torch.rand([batch_size, 1024, 64], dtype=self.dtype)
+img_ids = torch.rand([batch_size, 1024, 3], dtype=torch.float32)
+txt = torch.rand([batch_size, 512, 4096], dtype=self.dtype)
+txt_ids = torch.rand([batch_size, 512, 3], dtype=torch.float32)
+timesteps = torch.rand([batch_size], dtype=self.dtype)
+y = torch.rand([batch_size, 768], dtype=self.dtype)

args = tuple()
kwargs = OrderedDict(
@@ -211,8 +215,19 @@ def sample_inputs(
("y", y),
)
)

if self.guidance:
kwargs["guidance"] = torch.rand([batch_size], dtype=self.dtype)

return args, kwargs

def _deduce_dtype(self) -> torch.dtype:
dtype = self.theta("img_in.weight").dtype
assert (
dtype == self.theta("time_in.in_layer.weight").dtype
), "Inconsistent dtype"
return dtype


################################################################################
# Layers
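A rough illustration of the behavior changed above, assuming an already constructed FluxModelV1 (the helper and its assertions are only for exposition):

import torch

from sharktank.models.flux.flux import FluxModelV1


def check_sample_inputs(model: FluxModelV1) -> None:
    args, kwargs = model.sample_inputs(batch_size=1)
    # Most inputs follow the dtype deduced from the weights; id tensors stay float32.
    assert kwargs["img"].dtype == model.dtype
    assert kwargs["img_ids"].dtype == torch.float32
    # Only the dev variant has a guidance layer, so only it receives a guidance input.
    assert ("guidance" in kwargs) == bool(model.guidance)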
