Commit

Now use temp dirs for artifacts for cleaner file management
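This change switches both test classes from hardcoded output folders (sdxl_vae/ and flux_vae/) to a per-test temporary directory provided by TempDirTestBase (imported from sharktank.utils.testing), so downloaded weights, IRPA parameter files, exported MLIR, and compiled VMFBs land in a throwaway location instead of cluttering the working directory. As a rough sketch of the pattern the tests rely on (not the actual sharktank implementation, whose details may differ), such a base class can be built on tempfile:

import tempfile
import unittest
from pathlib import Path


class TempDirTestBase(unittest.TestCase):
    # Illustrative sketch only; the real TempDirTestBase lives in
    # sharktank.utils.testing and may differ in detail.

    def setUp(self):
        super().setUp()
        # Per-test scratch directory; tests build all artifact paths under it.
        self._temp_dir_obj = tempfile.TemporaryDirectory()
        self._temp_dir = Path(self._temp_dir_obj.name)
        # Remove the directory and everything written into it when the test
        # finishes, even on failure.
        self.addCleanup(self._temp_dir_obj.cleanup)

Each test then builds its artifact paths as f-strings under self._temp_dir, as the diff below shows.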
IanNod committed Jan 4, 2025
1 parent aa9f1d6 commit db73500
Showing 1 changed file with 53 additions and 49 deletions.
102 changes: 53 additions & 49 deletions sharktank/tests/models/vae/vae_test.py
@@ -36,52 +36,55 @@
)
import iree.compiler
from collections import OrderedDict
from sharktank.utils.testing import TempDirTestBase


with_vae_data = pytest.mark.skipif("not config.getoption('with_vae_data')")


@with_vae_data
class VaeSDXLDecoderTest(unittest.TestCase):
class VaeSDXLDecoderTest(TempDirTestBase):
def setUp(self):
super().setUp()
hf_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
hf_hub_download(
repo_id=hf_model_id,
local_dir="sdxl_vae",
local_dir="{self._temp_dir}",
local_dir_use_symlinks=False,
revision="main",
filename="vae/config.json",
)
hf_hub_download(
repo_id=hf_model_id,
local_dir="sdxl_vae",
local_dir="{self._temp_dir}",
local_dir_use_symlinks=False,
revision="main",
filename="vae/diffusion_pytorch_model.safetensors",
)
hf_hub_download(
repo_id="amd-shark/sdxl-quant-models",
local_dir="sdxl_vae",
local_dir="{self._temp_dir}",
local_dir_use_symlinks=False,
revision="main",
filename="vae/vae.safetensors",
)
torch.manual_seed(12345)
f32_dataset = import_hf_dataset(
"sdxl_vae/vae/config.json",
["sdxl_vae/vae/diffusion_pytorch_model.safetensors"],
"{self._temp_dir}/vae/config.json",
["{self._temp_dir}/vae/diffusion_pytorch_model.safetensors"],
)
f32_dataset.save("sdxl_vae/vae_f32.irpa", io_report_callback=print)
f32_dataset.save("{self._temp_dir}/vae_f32.irpa", io_report_callback=print)
f16_dataset = import_hf_dataset(
"sdxl_vae/vae/config.json", ["sdxl_vae/vae/vae.safetensors"]
"{self._temp_dir}/vae/config.json", ["sdxl_vae/vae/vae.safetensors"]
)
f16_dataset.save("sdxl_vae/vae_f16.irpa", io_report_callback=print)
f16_dataset.save("{self._temp_dir}/vae_f16.irpa", io_report_callback=print)

def testCompareF32EagerVsHuggingface(self):
dtype = getattr(torch, "float32")
inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1)
ref_results = run_torch_vae("sdxl_vae", inputs)
ref_results = run_torch_vae("{self._temp_dir}", inputs)

ds = Dataset.load("sdxl_vae/vae_f32.irpa", file_type="irpa")
ds = Dataset.load("{self._temp_dir}/vae_f32.irpa", file_type="irpa")
model = VaeDecoderModel.from_dataset(ds).to(device="cpu")

results = model.forward(inputs)
@@ -92,9 +95,9 @@ def testCompareF32EagerVsHuggingface(self):
def testCompareF16EagerVsHuggingface(self):
dtype = getattr(torch, "float32")
inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1)
ref_results = run_torch_vae("sdxl_vae", inputs)
ref_results = run_torch_vae("{self._temp_dir}", inputs)

ds = Dataset.load("sdxl_vae/vae_f16.irpa", file_type="irpa")
ds = Dataset.load("{self._temp_dir}/vae_f16.irpa", file_type="irpa")
model = VaeDecoderModel.from_dataset(ds).to(device="cpu")

results = model.forward(inputs.to(torch.float16))
@@ -104,10 +107,10 @@ def testCompareF16EagerVsHuggingface(self):
def testVaeIreeVsHuggingFace(self):
dtype = getattr(torch, "float32")
inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1)
ref_results = run_torch_vae("sdxl_vae", inputs)
ref_results = run_torch_vae("{self._temp_dir}", inputs)

ds_f16 = Dataset.load("sdxl_vae/vae_f16.irpa", file_type="irpa")
ds_f32 = Dataset.load("sdxl_vae/vae_f32.irpa", file_type="irpa")
ds_f16 = Dataset.load("{self._temp_dir}/vae_f16.irpa", file_type="irpa")
ds_f32 = Dataset.load("{self._temp_dir}/vae_f32.irpa", file_type="irpa")

model_f16 = VaeDecoderModel.from_dataset(ds_f16).to(device="cpu")
model_f32 = VaeDecoderModel.from_dataset(ds_f32).to(device="cpu")
@@ -116,8 +119,8 @@ def testVaeIreeVsHuggingFace(self):
module_f16 = export_vae(model_f16, inputs.to(torch.float16), True)
module_f32 = export_vae(model_f32, inputs, True)

module_f16.save_mlir("sdxl_vae/vae_f16.mlir")
module_f32.save_mlir("sdxl_vae/vae_f32.mlir")
module_f16.save_mlir("{self._temp_dir}/vae_f16.mlir")
module_f32.save_mlir("{self._temp_dir}/vae_f32.mlir")
extra_args = [
"--iree-hal-target-backends=rocm",
"--iree-hip-target=gfx942",
@@ -134,22 +137,22 @@ def testVaeIreeVsHuggingFace(self):
]

iree.compiler.compile_file(
"sdxl_vae/vae_f16.mlir",
output_file="sdxl_vae/vae_f16.vmfb",
"{self._temp_dir}/vae_f16.mlir",
output_file="{self._temp_dir}/vae_f16.vmfb",
extra_args=extra_args,
)
iree.compiler.compile_file(
"sdxl_vae/vae_f32.mlir",
output_file="sdxl_vae/vae_f32.vmfb",
"{self._temp_dir}/vae_f32.mlir",
output_file="{self._temp_dir}/vae_f32.vmfb",
extra_args=extra_args,
)

iree_devices = get_iree_devices(driver="hip", device_count=1)

iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
module_path="sdxl_vae/vae_f16.vmfb",
module_path="{self._temp_dir}/vae_f16.vmfb",
devices=iree_devices,
parameters_path="sdxl_vae/vae_f16.irpa",
parameters_path="{self._temp_dir}/vae_f16.irpa",
)

input_args = OrderedDict([("inputs", inputs.to(torch.float16))])
@@ -175,9 +178,9 @@ def testVaeIreeVsHuggingFace(self):
)

iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
module_path="sdxl_vae/vae_f32.vmfb",
module_path="{self._temp_dir}/vae_f32.vmfb",
devices=iree_devices,
parameters_path="sdxl_vae/vae_f32.irpa",
parameters_path="{self._temp_dir}/vae_f32.irpa",
)

input_args = OrderedDict([("inputs", inputs)])
@@ -200,42 +203,43 @@ def testVaeIreeVsHuggingFace(self):


@with_vae_data
class VaeFluxDecoderTest(unittest.TestCase):
class VaeFluxDecoderTest(TempDirTestBase):
def setUp(self):
super().setUp()
hf_model_id = "black-forest-labs/FLUX.1-dev"
hf_hub_download(
repo_id=hf_model_id,
local_dir="flux_vae",
local_dir="{self._temp_dir}/flux_vae/",
local_dir_use_symlinks=False,
revision="main",
filename="vae/config.json",
)
hf_hub_download(
repo_id=hf_model_id,
local_dir="flux_vae",
local_dir="{self._temp_dir}/flux_vae/",
local_dir_use_symlinks=False,
revision="main",
filename="vae/diffusion_pytorch_model.safetensors",
)
torch.manual_seed(12345)
dataset = import_hf_dataset(
"flux_vae/vae/config.json",
["flux_vae/vae/diffusion_pytorch_model.safetensors"],
"{self._temp_dir}/flux_vae/vae/config.json",
["{self._temp_dir}/flux_vae/vae/diffusion_pytorch_model.safetensors"],
)
dataset.save("flux_vae/vae_bf16.irpa", io_report_callback=print)
dataset.save("{self._temp_dir}/flux_vae_bf16.irpa", io_report_callback=print)
dataset_f32 = import_hf_dataset(
"flux_vae/vae/config.json",
["flux_vae/vae/diffusion_pytorch_model.safetensors"],
"{self._temp_dir}/flux_vae/vae/config.json",
["{self._temp_dir}/flux_vae/vae/diffusion_pytorch_model.safetensors"],
target_dtype=torch.float32,
)
dataset_f32.save("flux_vae/vae_f32.irpa", io_report_callback=print)
dataset_f32.save("{self._temp_dir}/flux_vae_f32.irpa", io_report_callback=print)

def testCompareBF16EagerVsHuggingface(self):
dtype = torch.bfloat16
inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1, config="flux")
ref_results = run_flux(inputs, dtype)

ds = Dataset.load("flux_vae/vae_bf16.irpa", file_type="irpa")
ds = Dataset.load("{self._temp_dir}/flux_vae_bf16.irpa", file_type="irpa")
model = VaeDecoderModel.from_dataset(ds).to(device="cpu")

results = model.forward(inputs)
@@ -247,7 +251,7 @@ def testCompareF32EagerVsHuggingface(self):
inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1, config="flux")
ref_results = run_flux(inputs, dtype)

ds = Dataset.load("flux_vae/vae_f32.irpa", file_type="irpa")
ds = Dataset.load("{self._temp_dir}/flux_vae_f32.irpa", file_type="irpa")
model = VaeDecoderModel.from_dataset(ds).to(device="cpu", dtype=dtype)

results = model.forward(inputs)
@@ -261,8 +265,8 @@ def testVaeIreeVsHuggingFace(self):
ref_results = run_flux(inputs.to(dtype), dtype)
ref_results_f32 = run_flux(inputs, torch.float32)

ds = Dataset.load("flux_vae/vae_bf16.irpa", file_type="irpa")
ds_f32 = Dataset.load("flux_vae/vae_f32.irpa", file_type="irpa")
ds = Dataset.load("{self._temp_dir}/flux_vae_bf16.irpa", file_type="irpa")
ds_f32 = Dataset.load("{self._temp_dir}/flux_vae_f32.irpa", file_type="irpa")

model = VaeDecoderModel.from_dataset(ds).to(device="cpu")
model_f32 = VaeDecoderModel.from_dataset(ds_f32).to(device="cpu")
@@ -271,8 +275,8 @@ def testVaeIreeVsHuggingFace(self):
module = export_vae(model, inputs, True)
module_f32 = export_vae(model_f32, inputs, True)

module.save_mlir("flux_vae/vae_bf16.mlir")
module_f32.save_mlir("flux_vae/vae_f32.mlir")
module.save_mlir("{self._temp_dir}/flux_vae_bf16.mlir")
module_f32.save_mlir("{self._temp_dir}/flux_vae_f32.mlir")

extra_args = [
"--iree-hal-target-backends=rocm",
@@ -290,22 +294,22 @@ def testVaeIreeVsHuggingFace(self):
]

iree.compiler.compile_file(
"flux_vae/vae_bf16.mlir",
output_file="flux_vae/vae_bf16.vmfb",
"{self._temp_dir}/flux_vae_bf16.mlir",
output_file="{self._temp_dir}/flux_vae_bf16.vmfb",
extra_args=extra_args,
)
iree.compiler.compile_file(
"flux_vae/vae_f32.mlir",
output_file="flux_vae/vae_f32.vmfb",
"{self._temp_dir}/flux_vae_f32.mlir",
output_file="{self._temp_dir}/flux_vae_f32.vmfb",
extra_args=extra_args,
)

iree_devices = get_iree_devices(driver="hip", device_count=1)

iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
module_path="flux_vae/vae_bf16.vmfb",
module_path="{self._temp_dir}/flux_vae_bf16.vmfb",
devices=iree_devices,
parameters_path="flux_vae/vae_bf16.irpa",
parameters_path="{self._temp_dir}/flux_vae_bf16.irpa",
)

input_args = OrderedDict([("inputs", inputs)])
@@ -328,9 +332,9 @@ def testVaeIreeVsHuggingFace(self):
torch.testing.assert_close(ref_results, iree_result, atol=3.3e-2, rtol=4e5)

iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
module_path="flux_vae/vae_f32.vmfb",
module_path="{self._temp_dir}/flux_vae_f32.vmfb",
devices=iree_devices,
parameters_path="flux_vae/vae_f32.irpa",
parameters_path="{self._temp_dir}/flux_vae_f32.irpa",
)

input_args = OrderedDict([("inputs", inputs)])
