From 56f3d217e871389c9cd11481b8a4316a82c90353 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Thu, 2 Jan 2025 09:38:43 -0800 Subject: [PATCH] Add export of Flux-dev transformer and uploading to Azure (#717) There was a slight difference in the Flux schnell and dev variants. Namely, dev has a guidance layer and schenll does not. Fixed some tensor argument element value types as they were always passed as f32 while some of them should use the model's dtype. Refactored a bit the Flux transformer export boilerplate. Added a script that uploads models to Azure. Right now it uploads the Flux transformer models. This can become a part of the CI jobs at some point. --- sharktank/pyproject.toml | 1 + sharktank/requirements-dev.txt | 4 + sharktank/sharktank/models/flux/export.py | 66 ++++- sharktank/sharktank/models/flux/flux.py | 27 +- sharktank/sharktank/models/flux/testing.py | 257 ++++++++++++++++++ .../tools/upload_all_models_to_azure.py | 57 ++++ sharktank/sharktank/types/theta.py | 20 +- sharktank/sharktank/utils/azure.py | 127 +++++++++ sharktank/sharktank/utils/hf_datasets.py | 15 + sharktank/tests/models/flux/flux_test.py | 53 +--- 10 files changed, 580 insertions(+), 47 deletions(-) create mode 100644 sharktank/requirements-dev.txt create mode 100644 sharktank/sharktank/models/flux/testing.py create mode 100644 sharktank/sharktank/tools/upload_all_models_to_azure.py create mode 100644 sharktank/sharktank/utils/azure.py diff --git a/sharktank/pyproject.toml b/sharktank/pyproject.toml index 01cad409b..09aca178b 100644 --- a/sharktank/pyproject.toml +++ b/sharktank/pyproject.toml @@ -34,6 +34,7 @@ sharktank = ["py.typed", "kernels/templates/*.mlir"] file = ["requirements.txt"] [tool.setuptools.dynamic.optional-dependencies] +dev = {file = ["requirements-dev.txt"]} testing = {file = ["requirements-tests.txt"]} [tool.pytest.ini_options] diff --git a/sharktank/requirements-dev.txt b/sharktank/requirements-dev.txt new file mode 100644 index 000000000..560aaa24c --- /dev/null +++ b/sharktank/requirements-dev.txt @@ -0,0 +1,4 @@ +# Dependencies only required during development. + +azure-identity>=1.19 +azure-storage-blob>=12.24 diff --git a/sharktank/sharktank/models/flux/export.py b/sharktank/sharktank/models/flux/export.py index fae3a5362..404f00413 100644 --- a/sharktank/sharktank/models/flux/export.py +++ b/sharktank/sharktank/models/flux/export.py @@ -5,6 +5,9 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from os import PathLike +import os +from pathlib import Path +import torch from ...export import export_static_model_mlir from ...tools.import_hf_dataset import import_hf_dataset @@ -12,7 +15,7 @@ from ...types import Dataset from ...utils.hf_datasets import get_dataset -flux_transformer_default_batch_sizes = [4] +flux_transformer_default_batch_sizes = [1] def export_flux_transformer_model_mlir( @@ -23,6 +26,31 @@ def export_flux_transformer_model_mlir( export_static_model_mlir(model, output_path=output_path, batch_sizes=batch_sizes) +def export_flux_transformer_iree_parameters( + model: FluxModelV1, parameters_output_path: PathLike +): + model.theta.rename_tensors_to_paths() + # TODO: export properties + dataset = Dataset(root_theta=model.theta, properties={}) + dataset.save(parameters_output_path) + + +def export_flux_transformer( + model: FluxModelV1, + mlir_output_path: PathLike, + parameters_output_path: PathLike, + batch_sizes: list[int] = flux_transformer_default_batch_sizes, +): + export_flux_transformer_iree_parameters(model, parameters_output_path) + + dataset = Dataset.load(parameters_output_path) + model_with_frozen_theta = FluxModelV1(theta=dataset.root_theta, params=model.params) + model_with_frozen_theta.theta = dataset.root_theta + export_flux_transformer_model_mlir( + model_with_frozen_theta, output_path=mlir_output_path, batch_sizes=batch_sizes + ) + + def export_flux_transformer_from_hugging_face( repo_id: str, mlir_output_path: PathLike, @@ -47,3 +75,39 @@ def export_flux_transformer_from_hugging_face( export_flux_transformer_model_mlir( model, output_path=mlir_output_path, batch_sizes=batch_sizes ) + + +def export_flux_transformer_models(dir: Path): + from .testing import export_dev_random_single_layer + + base_dir = dir / "flux" / "transformer" + os.makedirs(base_dir) + + file_name_base = "black-forest-labs--FLUX.1-dev--black-forest-labs-transformer-bf16" + mlir_path = base_dir / f"{file_name_base}.mlir" + parameters_output_path = base_dir / f"{file_name_base}.irpa" + export_flux_transformer_from_hugging_face( + "black-forest-labs/FLUX.1-dev/black-forest-labs-transformer", + mlir_output_path=mlir_path, + parameters_output_path=parameters_output_path, + ) + + file_name_base = ( + "black-forest-labs--FLUX.1-schnell--black-forest-labs-transformer-bf16" + ) + mlir_path = base_dir / f"{file_name_base}.mlir" + parameters_output_path = base_dir / f"{file_name_base}.irpa" + export_flux_transformer_from_hugging_face( + "black-forest-labs/FLUX.1-schnell/black-forest-labs-transformer", + mlir_output_path=mlir_path, + parameters_output_path=parameters_output_path, + ) + + file_name_base = "black-forest-labs--FLUX.1-dev--transformer-single-layer-b16" + mlir_path = base_dir / f"{file_name_base}.mlir" + parameters_output_path = base_dir / f"{file_name_base}.irpa" + export_dev_random_single_layer( + dtype=torch.bfloat16, + mlir_output_path=mlir_path, + parameters_output_path=parameters_output_path, + ) diff --git a/sharktank/sharktank/models/flux/flux.py b/sharktank/sharktank/models/flux/flux.py index d99b14ad4..531083ae1 100644 --- a/sharktank/sharktank/models/flux/flux.py +++ b/sharktank/sharktank/models/flux/flux.py @@ -11,6 +11,7 @@ from typing import Any, Optional from collections import OrderedDict +from copy import copy import math from dataclasses import dataclass import torch @@ -96,6 +97,7 @@ def __init__(self, theta: Theta, params: FluxParams): theta, ) + self.params = copy(params) self.in_channels = params.in_channels self.out_channels = self.in_channels if params.hidden_size % params.num_heads != 0: @@ -146,6 +148,8 @@ def __init__(self, theta: Theta, params: FluxParams): LastLayer(theta("final_layer")), ) + self.dtype = self._deduce_dtype() + def forward( self, img: AnyTensor, @@ -193,12 +197,12 @@ def sample_inputs( raise ValueError(f'Only function "forward" is supported. Got "{function}"') # TODO: do not hardcode these but derive the required shapes from the config. - img = torch.rand([batch_size, 1024, 64]) - img_ids = torch.rand([batch_size, 1024, 3]) - txt = torch.rand([batch_size, 512, 4096]) - txt_ids = torch.rand([batch_size, 512, 3]) - timesteps = torch.rand([batch_size]) - y = torch.rand([batch_size, 768]) + img = torch.rand([batch_size, 1024, 64], dtype=self.dtype) + img_ids = torch.rand([batch_size, 1024, 3], dtype=torch.float32) + txt = torch.rand([batch_size, 512, 4096], dtype=self.dtype) + txt_ids = torch.rand([batch_size, 512, 3], dtype=torch.float32) + timesteps = torch.rand([batch_size], dtype=self.dtype) + y = torch.rand([batch_size, 768], dtype=self.dtype) args = tuple() kwargs = OrderedDict( @@ -211,8 +215,19 @@ def sample_inputs( ("y", y), ) ) + + if self.guidance: + kwargs["guidance"] = torch.rand([batch_size], dtype=self.dtype) + return args, kwargs + def _deduce_dtype(self) -> torch.dtype: + dtype = self.theta("img_in.weight").dtype + assert ( + dtype == self.theta("time_in.in_layer.weight").dtype + ), "Inconsistent dtype" + return dtype + ################################################################################ # Layers diff --git a/sharktank/sharktank/models/flux/testing.py b/sharktank/sharktank/models/flux/testing.py new file mode 100644 index 000000000..e0354ff7b --- /dev/null +++ b/sharktank/sharktank/models/flux/testing.py @@ -0,0 +1,257 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import torch +from os import PathLike + +from .flux import FluxParams, FluxModelV1 +from .export import export_flux_transformer, flux_transformer_default_batch_sizes +from ...types import DefaultPrimitiveTensor, Theta, save_load_theta +from ...layers.testing import ( + make_rand_torch, +) + + +def make_random_theta(config: FluxParams, dtype: torch.dtype): + # TODO: do not hardcode values. + + in_channels = config.in_channels + in_channels2 = 128 + hidden_size = config.hidden_size + mlp_ratio = config.mlp_ratio + mlp_hidden_size = int((mlp_ratio - 1) * hidden_size) + mlp_hidden_size2 = int(mlp_ratio * hidden_size) + mlp_hidden_size3 = int(2 * (mlp_ratio - 1) * hidden_size) + mlp_hidden_size4 = int((mlp_ratio + 1) * hidden_size) + mlp_hidden_size5 = int((2 * mlp_ratio - 1) * hidden_size) + context_in_dim = config.context_in_dim + time_dim = 256 + vec_dim = config.vec_in_dim + patch_size = 1 + out_channels = config.out_channels + tensor_dict = { + "img_in.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size, in_channels), dtype=dtype) + ), + "img_in.bias": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size,), dtype=dtype) + ), + "txt_in.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size, context_in_dim), dtype=dtype) + ), + "txt_in.bias": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size,), dtype=dtype) + ), + "time_in.in_layer.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size, time_dim), dtype=dtype) + ), + "time_in.in_layer.bias": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size,), dtype=dtype) + ), + "time_in.out_layer.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size, hidden_size), dtype=dtype) + ), + "time_in.out_layer.bias": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size,), dtype=dtype) + ), + "vector_in.in_layer.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size, vec_dim), dtype=dtype) + ), + "vector_in.in_layer.bias": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size,), dtype=dtype) + ), + "vector_in.out_layer.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size, hidden_size), dtype=dtype) + ), + "vector_in.out_layer.bias": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size,), dtype=dtype) + ), + "double_blocks.0.img_attn.norm.key_norm.scale": DefaultPrimitiveTensor( # + data=make_rand_torch((in_channels2,), dtype=dtype) + ), + "double_blocks.0.img_attn.norm.query_norm.scale": DefaultPrimitiveTensor( # + data=make_rand_torch((in_channels2,), dtype=dtype) + ), + "double_blocks.0.img_attn.proj.bias": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size,), dtype=dtype) + ), + "double_blocks.0.img_attn.proj.weight": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size, hidden_size), dtype=dtype) + ), + "double_blocks.0.img_attn.qkv.bias": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size,), dtype=dtype) + ), + "double_blocks.0.img_attn.qkv.weight": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype) + ), + "double_blocks.0.img_mlp.0.bias": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size2), dtype=dtype) + ), + "double_blocks.0.img_mlp.0.weight": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size2, hidden_size), dtype=dtype) + ), + "double_blocks.0.img_mlp.2.bias": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size), dtype=dtype) + ), + "double_blocks.0.img_mlp.2.weight": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size, mlp_hidden_size2), dtype=dtype) + ), + "double_blocks.0.img_mod.lin.bias": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size3,), dtype=dtype) + ), + "double_blocks.0.img_mod.lin.weight": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype) + ), + "double_blocks.0.txt_attn.norm.key_norm.scale": DefaultPrimitiveTensor( # + data=make_rand_torch((in_channels2,), dtype=dtype) + ), + "double_blocks.0.txt_attn.norm.query_norm.scale": DefaultPrimitiveTensor( # + data=make_rand_torch((in_channels2,), dtype=dtype) + ), + "double_blocks.0.txt_attn.proj.bias": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size,), dtype=dtype) + ), + "double_blocks.0.txt_attn.proj.weight": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size, hidden_size), dtype=dtype) + ), + "double_blocks.0.txt_attn.qkv.bias": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size,), dtype=dtype) + ), + "double_blocks.0.txt_attn.qkv.weight": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype) + ), + "double_blocks.0.txt_mlp.0.bias": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size2), dtype=dtype) + ), + "double_blocks.0.txt_mlp.0.weight": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size2, hidden_size), dtype=dtype) + ), + "double_blocks.0.txt_mlp.2.bias": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size), dtype=dtype) + ), + "double_blocks.0.txt_mlp.2.weight": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size, mlp_hidden_size2), dtype=dtype) + ), + "double_blocks.0.txt_mod.lin.bias": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size3,), dtype=dtype) + ), + "double_blocks.0.txt_mod.lin.weight": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype) + ), + "single_blocks.0.norm.key_norm.scale": DefaultPrimitiveTensor( # + data=make_rand_torch((in_channels2,), dtype=dtype) + ), + "single_blocks.0.norm.query_norm.scale": DefaultPrimitiveTensor( # + data=make_rand_torch((in_channels2,), dtype=dtype) + ), + "single_blocks.0.attn.proj.bias": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size,), dtype=dtype) + ), + "single_blocks.0.attn.proj.weight": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size, hidden_size), dtype=dtype) + ), + "single_blocks.0.linear1.bias": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size5,), dtype=dtype) + ), + "single_blocks.0.linear1.weight": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size5, hidden_size), dtype=dtype) + ), + "single_blocks.0.linear2.bias": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size), dtype=dtype) + ), + "single_blocks.0.linear2.weight": DefaultPrimitiveTensor( + data=make_rand_torch((hidden_size, mlp_hidden_size4), dtype=dtype) + ), + "single_blocks.0.modulation.lin.bias": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size,), dtype=dtype) + ), + "single_blocks.0.modulation.lin.weight": DefaultPrimitiveTensor( + data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype) + ), + "final_layer.linear.weight": DefaultPrimitiveTensor( # + data=make_rand_torch( + (patch_size * patch_size * out_channels, hidden_size), dtype=dtype + ) + ), + "final_layer.linear.bias": DefaultPrimitiveTensor( # + data=make_rand_torch((patch_size * patch_size * out_channels,), dtype=dtype) + ), + "final_layer.adaLN_modulation.1.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size * 2, hidden_size), dtype=dtype) + ), + "final_layer.adaLN_modulation.1.bias": DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size * 2,), dtype=dtype) + ), + } + + if config.guidance_embed: + tensor_dict["guidance_in.in_layer.weight"] = DefaultPrimitiveTensor( # + data=make_rand_torch( + ( + hidden_size, + time_dim, + ), + dtype=dtype, + ) + ) + tensor_dict["guidance_in.in_layer.bias"] = DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size,), dtype=dtype) + ) + tensor_dict["guidance_in.out_layer.weight"] = DefaultPrimitiveTensor( # + data=make_rand_torch( + ( + hidden_size, + hidden_size, + ), + dtype=dtype, + ) + ) + tensor_dict["guidance_in.out_layer.bias"] = DefaultPrimitiveTensor( # + data=make_rand_torch((hidden_size,), dtype=dtype) + ) + + return Theta(tensor_dict) + + +def export_dev_random_single_layer( + dtype: torch.dtype, + mlir_output_path: PathLike, + parameters_output_path: PathLike, + batch_sizes: list[int] = flux_transformer_default_batch_sizes, +): + rng_state = torch.get_rng_state() + torch.random.manual_seed(12345) + + dtype = torch.bfloat16 + params = FluxParams( + in_channels=64, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=1, + depth_single_blocks=1, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ) + theta = make_random_theta(params, dtype) + flux = FluxModelV1( + theta=theta, + params=params, + ) + + export_flux_transformer( + flux, + mlir_output_path=mlir_output_path, + parameters_output_path=parameters_output_path, + batch_sizes=batch_sizes, + ) + + torch.set_rng_state(rng_state) diff --git a/sharktank/sharktank/tools/upload_all_models_to_azure.py b/sharktank/sharktank/tools/upload_all_models_to_azure.py new file mode 100644 index 000000000..05ce4e51a --- /dev/null +++ b/sharktank/sharktank/tools/upload_all_models_to_azure.py @@ -0,0 +1,57 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ..utils.azure import upload_all_models + +import logging +import argparse + + +def main(args: list[str] = None): + parser = argparse.ArgumentParser( + description=( + "Upload all models to Azure storage. Uploads only if files are different. " + "If they need updating a snapshot will be created before uploading." + ) + ) + parser.add_argument( + "--account-name", type=str, required=True, help="Storage account name." + ) + parser.add_argument("--container-name", type=str, required=True) + parser.add_argument( + "--account-key", + type=str, + default=None, + help=( + "Access key. If not provided, will use environment variable AZURE_STORAGE_KEY" + " as key. If this is not available, will use the default Azure credential." + ), + ) + parser.add_argument( + "--destination-name-prefix", + type=str, + required=True, + help="Name prefix of all blobs that will be uploaded.", + ) + parsed_args = parser.parse_args(args) + + upload_all_models( + account_name=parsed_args.account_name, + container_name=parsed_args.container_name, + destination_name_prefix=parsed_args.destination_name_prefix, + account_key=parsed_args.account_key, + ) + + +if __name__ == "__main__": + # Set the logging level for all azure-storage-* libraries + azure_logger = logging.getLogger("azure.storage") + azure_logger.setLevel(logging.INFO) + + upload_logger = logging.getLogger("sharktank.utils.azure") + upload_logger.setLevel(logging.INFO) + + main() diff --git a/sharktank/sharktank/types/theta.py b/sharktank/sharktank/types/theta.py index 021925169..29b870782 100644 --- a/sharktank/sharktank/types/theta.py +++ b/sharktank/sharktank/types/theta.py @@ -5,7 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from typing import Any, Callable, Optional, Union, Collection, Sequence, List - +from tempfile import TemporaryFile import json from pathlib import Path from types import NotImplementedType @@ -32,7 +32,13 @@ REGISTERED_INFERENCE_TENSOR_CLASSES, ) -__all__ = ["Dataset", "flat_to_nested_dict", "Theta", "torch_module_to_theta"] +__all__ = [ + "Dataset", + "flat_to_nested_dict", + "Theta", + "torch_module_to_theta", + "save_load_theta", +] IOReportCallback = Callable[[str], None] @@ -297,6 +303,16 @@ def _norm_name_path(name_parts) -> list[str]: return accum +def save_load_theta(theta: Theta) -> Theta: + """Roundtrip to disk to avoid treating parameters as constants that would appear + in the MLIR.""" + theta.rename_tensors_to_paths() + dataset = Dataset(root_theta=theta, properties={}) + with TemporaryFile(prefix="save_load_theta", suffix=".irpa") as file_path: + dataset.save(file_path) + return Dataset.load(file_path).root_theta + + ################################################################################ # Dataset objects # diff --git a/sharktank/sharktank/utils/azure.py b/sharktank/sharktank/utils/azure.py new file mode 100644 index 000000000..b696f2b0b --- /dev/null +++ b/sharktank/sharktank/utils/azure.py @@ -0,0 +1,127 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from azure.identity import DefaultAzureCredential +from azure.storage.blob import BlobServiceClient, ContentSettings +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Callable, Optional +import hashlib +import os +import logging + +logger = logging.getLogger(__name__) + + +def calculate_hash(file_path: str) -> str: + hasher = hashlib.md5() + with open(file_path, "rb") as file: + buf = file.read() + hasher.update(buf) + return hasher.digest() + + +def create_blob_service_client( + account_name: str, account_key: Optional[str] = None +) -> BlobServiceClient: + if account_key is None and "AZURE_STORAGE_KEY" in os.environ: + account_key = os.environ["AZURE_STORAGE_KEY"] + if account_key: + connection_string = ( + f"DefaultEndpointsProtocol=https;AccountName={account_name};" + f"AccountKey={account_key};" + "EndpointSuffix=core.windows.net" + ) + return BlobServiceClient.from_connection_string(connection_string) + + credential = DefaultAzureCredential() + account_url = f"https://{account_name}.blob.core.windows.net" + return BlobServiceClient(account_url, credential) + + +def snapshot_and_upload_blob_if_different( + blob_service_client: BlobServiceClient, + container_name: str, + blob_name: str, + file_path: str, +): + blob_client = blob_service_client.get_blob_client(container_name, blob_name) + local_hash = calculate_hash(file_path) + + blob_exists = False + try: + blob_properties = blob_client.get_blob_properties() + existing_hash = blob_properties.content_settings.content_md5 + blob_exists = True + except Exception: + existing_hash = None + + if local_hash == existing_hash: + logger.info(f'Skipping upload to blob "{blob_name}".') + return + + if blob_exists: + blob_client.create_snapshot() + + with open(file_path, "rb") as f: + logger.info(f'Uploading to blob "{blob_name}"...') + content_settings = ContentSettings(content_md5=local_hash) + blob_client.upload_blob(f, overwrite=True, content_settings=content_settings) + logger.info(f'Blob "{blob_name}" uploaded.') + + +def upload_directory( + blob_service_client: BlobServiceClient, + container_name: str, + source_dir: str, + destination_blob_name_prefix: str, +): + for root, dirs, files in os.walk(source_dir): + for file_name in files: + file_path = Path(root) / file_name + blob_name = f"{destination_blob_name_prefix}{os.path.relpath(file_path, source_dir)}" + snapshot_and_upload_blob_if_different( + blob_service_client, container_name, blob_name, file_path + ) + + +def upload_model( + export_fn: Callable[[Path], None], + blob_service_client: BlobServiceClient, + container_name: str, + destination_blob_name_prefix: str, +): + with TemporaryDirectory() as tmp_dir: + export_fn(Path(tmp_dir)) + upload_directory( + blob_service_client, + container_name, + source_dir=tmp_dir, + destination_blob_name_prefix=destination_blob_name_prefix, + ) + + +def upload_all_models( + account_name: str, + container_name: str, + destination_name_prefix: str, + account_key: Optional[str] = None, +): + """Upload all models to Azure. + Will generate temporary export artifacts. + If MD5 hashes match with the existing blobs nothing will be uploaded. + Creates snapshots if files need updating.""" + from ..models.flux.export import export_flux_transformer_models + + blob_service_client = create_blob_service_client(account_name, account_key) + + upload_model( + export_flux_transformer_models, + blob_service_client, + container_name, + destination_name_prefix, + ) + # TODO: add more models here diff --git a/sharktank/sharktank/utils/hf_datasets.py b/sharktank/sharktank/utils/hf_datasets.py index 6893b637a..bc3a08c50 100644 --- a/sharktank/sharktank/utils/hf_datasets.py +++ b/sharktank/sharktank/utils/hf_datasets.py @@ -419,6 +419,21 @@ def alias_dataset(from_name: str, to_name: str): ), ), ) +Dataset( + "black-forest-labs/FLUX.1-dev/black-forest-labs-transformer", + ( + RemoteFile( + "config", + "black-forest-labs/FLUX.1-dev", + "transformer/config.json", + ), + RemoteFile( + "parameters", + "black-forest-labs/FLUX.1-dev", + "flux1-dev.safetensors", + ), + ), +) ################################################################################ diff --git a/sharktank/tests/models/flux/flux_test.py b/sharktank/tests/models/flux/flux_test.py index fc4d23251..ee8e6d82e 100644 --- a/sharktank/tests/models/flux/flux_test.py +++ b/sharktank/tests/models/flux/flux_test.py @@ -13,9 +13,9 @@ FluxParams, ) from sharktank.models.flux.export import ( - export_flux_transformer_model_mlir, export_flux_transformer_from_hugging_face, ) +from sharktank.models.flux.testing import export_dev_random_single_layer import sharktank.ops as ops from sharktank.layers.testing import ( make_rand_torch, @@ -216,52 +216,29 @@ def setUp(self): self.num_heads = 24 self.batch_size = 5 - def testExportBfloat16SingleLayer(self): - dtype = torch.bfloat16 - params = FluxParams( - in_channels=64, - out_channels=64, - vec_in_dim=768, - context_in_dim=4096, - hidden_size=3072, - mlp_ratio=4.0, - num_heads=24, - depth=1, - depth_single_blocks=1, - axes_dim=[16, 56, 56], - theta=10_000, - qkv_bias=True, - guidance_embed=False, - ) - theta = make_random_theta(dtype) - theta = self.save_load_theta(theta) - flux = FluxModelV1( - theta=theta, - params=params, - ) - - export_flux_transformer_model_mlir( - flux, - output_path=self._temp_dir / "model.mlir", - batch_sizes=[self.batch_size], + def testExportDevRandomSingleLayerBf16(self): + export_dev_random_single_layer( + dtype=torch.bfloat16, + batch_sizes=[1], + mlir_output_path=self._temp_dir / "model.mlir", + parameters_output_path=self._temp_dir / "parameters.irpa", ) @with_flux_data - def testExportSchnellFromHuggingFace(self): + def testExportSchnellTransformerFromHuggingFace(self): export_flux_transformer_from_hugging_face( "black-forest-labs/FLUX.1-schnell/black-forest-labs-transformer", mlir_output_path=self._temp_dir / "model.mlir", parameters_output_path=self._temp_dir / "parameters.irpa", ) - def save_load_theta(self, theta: Theta): - # Roundtrip to disk to avoid treating parameters as constants that would appear - # in the MLIR. - theta.rename_tensors_to_paths() - dataset = Dataset(root_theta=theta, properties={}) - file_path = self._temp_dir / "parameters.irpa" - dataset.save(file_path) - return Dataset.load(file_path).root_theta + @with_flux_data + def testExportDevTransformerFromHuggingFace(self): + export_flux_transformer_from_hugging_face( + "black-forest-labs/FLUX.1-dev/black-forest-labs-transformer", + mlir_output_path=self._temp_dir / "model.mlir", + parameters_output_path=self._temp_dir / "parameters.irpa", + ) if __name__ == "__main__":