From 525b1ccf30c3f532ca90513f5a9bf442607348c6 Mon Sep 17 00:00:00 2001 From: Ian Date: Fri, 3 Jan 2025 19:52:03 -0600 Subject: [PATCH] Now use temp dirs for artifacts for cleaner file management --- sharktank/tests/models/vae/vae_test.py | 102 +++++++++++++------------ 1 file changed, 53 insertions(+), 49 deletions(-) diff --git a/sharktank/tests/models/vae/vae_test.py b/sharktank/tests/models/vae/vae_test.py index 8e15229ce..c6b0e2153 100644 --- a/sharktank/tests/models/vae/vae_test.py +++ b/sharktank/tests/models/vae/vae_test.py @@ -36,52 +36,55 @@ ) import iree.compiler from collections import OrderedDict +from sharktank.utils.testing import TempDirTestBase + with_vae_data = pytest.mark.skipif("not config.getoption('with_vae_data')") @with_vae_data -class VaeSDXLDecoderTest(unittest.TestCase): +class VaeSDXLDecoderTest(TempDirTestBase): def setUp(self): + super().setUp() hf_model_id = "stabilityai/stable-diffusion-xl-base-1.0" hf_hub_download( repo_id=hf_model_id, - local_dir="sdxl_vae", + local_dir=f"{self._temp_dir}", local_dir_use_symlinks=False, revision="main", filename="vae/config.json", ) hf_hub_download( repo_id=hf_model_id, - local_dir="sdxl_vae", + local_dir=f"{self._temp_dir}", local_dir_use_symlinks=False, revision="main", filename="vae/diffusion_pytorch_model.safetensors", ) hf_hub_download( repo_id="amd-shark/sdxl-quant-models", - local_dir="sdxl_vae", + local_dir=f"{self._temp_dir}", local_dir_use_symlinks=False, revision="main", filename="vae/vae.safetensors", ) torch.manual_seed(12345) f32_dataset = import_hf_dataset( - "sdxl_vae/vae/config.json", - ["sdxl_vae/vae/diffusion_pytorch_model.safetensors"], + f"{self._temp_dir}/vae/config.json", + [f"{self._temp_dir}/vae/diffusion_pytorch_model.safetensors"], ) - f32_dataset.save("sdxl_vae/vae_f32.irpa", io_report_callback=print) + f32_dataset.save(f"{self._temp_dir}/vae_f32.irpa", io_report_callback=print) f16_dataset = import_hf_dataset( - "sdxl_vae/vae/config.json", ["sdxl_vae/vae/vae.safetensors"] + 
f"{self._temp_dir}/vae/config.json", [f"{self._temp_dir}/vae/vae.safetensors"] ) - f16_dataset.save("sdxl_vae/vae_f16.irpa", io_report_callback=print) + f16_dataset.save(f"{self._temp_dir}/vae_f16.irpa", io_report_callback=print) def testCompareF32EagerVsHuggingface(self): dtype = getattr(torch, "float32") inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1) - ref_results = run_torch_vae("sdxl_vae", inputs) + ref_results = run_torch_vae(f"{self._temp_dir}", inputs) - ds = Dataset.load("sdxl_vae/vae_f32.irpa", file_type="irpa") + ds = Dataset.load(f"{self._temp_dir}/vae_f32.irpa", file_type="irpa") model = VaeDecoderModel.from_dataset(ds).to(device="cpu") results = model.forward(inputs) @@ -92,9 +95,9 @@ def testCompareF32EagerVsHuggingface(self): def testCompareF16EagerVsHuggingface(self): dtype = getattr(torch, "float32") inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1) - ref_results = run_torch_vae("sdxl_vae", inputs) + ref_results = run_torch_vae(f"{self._temp_dir}", inputs) - ds = Dataset.load("sdxl_vae/vae_f16.irpa", file_type="irpa") + ds = Dataset.load(f"{self._temp_dir}/vae_f16.irpa", file_type="irpa") model = VaeDecoderModel.from_dataset(ds).to(device="cpu") results = model.forward(inputs.to(torch.float16)) @@ -104,10 +107,10 @@ def testVaeIreeVsHuggingFace(self): dtype = getattr(torch, "float32") inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1) - ref_results = run_torch_vae("sdxl_vae", inputs) + ref_results = run_torch_vae(f"{self._temp_dir}", inputs) - ds_f16 = Dataset.load("sdxl_vae/vae_f16.irpa", file_type="irpa") - ds_f32 = Dataset.load("sdxl_vae/vae_f32.irpa", file_type="irpa") + ds_f16 = Dataset.load(f"{self._temp_dir}/vae_f16.irpa", file_type="irpa") + ds_f32 = Dataset.load(f"{self._temp_dir}/vae_f32.irpa", file_type="irpa") model_f16 = VaeDecoderModel.from_dataset(ds_f16).to(device="cpu") model_f32 = VaeDecoderModel.from_dataset(ds_f32).to(device="cpu") @@ -116,8 +119,8 @@ def 
testVaeIreeVsHuggingFace(self): module_f16 = export_vae(model_f16, inputs.to(torch.float16), True) module_f32 = export_vae(model_f32, inputs, True) - module_f16.save_mlir("sdxl_vae/vae_f16.mlir") - module_f32.save_mlir("sdxl_vae/vae_f32.mlir") + module_f16.save_mlir(f"{self._temp_dir}/vae_f16.mlir") + module_f32.save_mlir(f"{self._temp_dir}/vae_f32.mlir") extra_args = [ "--iree-hal-target-backends=rocm", "--iree-hip-target=gfx942", @@ -134,22 +137,22 @@ def testVaeIreeVsHuggingFace(self): ] iree.compiler.compile_file( - "sdxl_vae/vae_f16.mlir", - output_file="sdxl_vae/vae_f16.vmfb", + f"{self._temp_dir}/vae_f16.mlir", + output_file=f"{self._temp_dir}/vae_f16.vmfb", extra_args=extra_args, ) iree.compiler.compile_file( - "sdxl_vae/vae_f32.mlir", - output_file="sdxl_vae/vae_f32.vmfb", + f"{self._temp_dir}/vae_f32.mlir", + output_file=f"{self._temp_dir}/vae_f32.vmfb", extra_args=extra_args, ) iree_devices = get_iree_devices(driver="hip", device_count=1) iree_module, iree_vm_context, iree_vm_instance = load_iree_module( - module_path="sdxl_vae/vae_f16.vmfb", + module_path=f"{self._temp_dir}/vae_f16.vmfb", devices=iree_devices, - parameters_path="sdxl_vae/vae_f16.irpa", + parameters_path=f"{self._temp_dir}/vae_f16.irpa", ) input_args = OrderedDict([("inputs", inputs.to(torch.float16))]) @@ -175,9 +178,9 @@ def testVaeIreeVsHuggingFace(self): ) iree_module, iree_vm_context, iree_vm_instance = load_iree_module( - module_path="sdxl_vae/vae_f32.vmfb", + module_path=f"{self._temp_dir}/vae_f32.vmfb", devices=iree_devices, - parameters_path="sdxl_vae/vae_f32.irpa", + parameters_path=f"{self._temp_dir}/vae_f32.irpa", ) input_args = OrderedDict([("inputs", inputs)]) @@ -200,42 +203,43 @@ def testVaeIreeVsHuggingFace(self): @with_vae_data -class VaeFluxDecoderTest(unittest.TestCase): +class VaeFluxDecoderTest(TempDirTestBase): def setUp(self): + super().setUp() hf_model_id = "black-forest-labs/FLUX.1-dev" hf_hub_download( repo_id=hf_model_id, - local_dir="flux_vae", + 
local_dir=f"{self._temp_dir}/flux_vae/", local_dir_use_symlinks=False, revision="main", filename="vae/config.json", ) hf_hub_download( repo_id=hf_model_id, - local_dir="flux_vae", + local_dir=f"{self._temp_dir}/flux_vae/", local_dir_use_symlinks=False, revision="main", filename="vae/diffusion_pytorch_model.safetensors", ) torch.manual_seed(12345) dataset = import_hf_dataset( - "flux_vae/vae/config.json", - ["flux_vae/vae/diffusion_pytorch_model.safetensors"], + f"{self._temp_dir}/flux_vae/vae/config.json", + [f"{self._temp_dir}/flux_vae/vae/diffusion_pytorch_model.safetensors"], ) - dataset.save("flux_vae/vae_bf16.irpa", io_report_callback=print) + dataset.save(f"{self._temp_dir}/flux_vae_bf16.irpa", io_report_callback=print) dataset_f32 = import_hf_dataset( - "flux_vae/vae/config.json", - ["flux_vae/vae/diffusion_pytorch_model.safetensors"], + f"{self._temp_dir}/flux_vae/vae/config.json", + [f"{self._temp_dir}/flux_vae/vae/diffusion_pytorch_model.safetensors"], target_dtype=torch.float32, ) - dataset_f32.save("flux_vae/vae_f32.irpa", io_report_callback=print) + dataset_f32.save(f"{self._temp_dir}/flux_vae_f32.irpa", io_report_callback=print) def testCompareBF16EagerVsHuggingface(self): dtype = torch.bfloat16 inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1, config="flux") ref_results = run_flux(inputs, dtype) - ds = Dataset.load("flux_vae/vae_bf16.irpa", file_type="irpa") + ds = Dataset.load(f"{self._temp_dir}/flux_vae_bf16.irpa", file_type="irpa") model = VaeDecoderModel.from_dataset(ds).to(device="cpu") results = model.forward(inputs) @@ -247,7 +251,7 @@ def testCompareF32EagerVsHuggingface(self): inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1, config="flux") ref_results = run_flux(inputs, dtype) - ds = Dataset.load("flux_vae/vae_f32.irpa", file_type="irpa") + ds = Dataset.load(f"{self._temp_dir}/flux_vae_f32.irpa", file_type="irpa") model = VaeDecoderModel.from_dataset(ds).to(device="cpu", dtype=dtype) results = model.forward(inputs) @@ -261,8 
+265,8 @@ def testVaeIreeVsHuggingFace(self): ref_results = run_flux(inputs.to(dtype), dtype) ref_results_f32 = run_flux(inputs, torch.float32) - ds = Dataset.load("flux_vae/vae_bf16.irpa", file_type="irpa") - ds_f32 = Dataset.load("flux_vae/vae_f32.irpa", file_type="irpa") + ds = Dataset.load(f"{self._temp_dir}/flux_vae_bf16.irpa", file_type="irpa") + ds_f32 = Dataset.load(f"{self._temp_dir}/flux_vae_f32.irpa", file_type="irpa") model = VaeDecoderModel.from_dataset(ds).to(device="cpu") model_f32 = VaeDecoderModel.from_dataset(ds_f32).to(device="cpu") @@ -271,8 +275,8 @@ def testVaeIreeVsHuggingFace(self): module = export_vae(model, inputs, True) module_f32 = export_vae(model_f32, inputs, True) - module.save_mlir("flux_vae/vae_bf16.mlir") - module_f32.save_mlir("flux_vae/vae_f32.mlir") + module.save_mlir(f"{self._temp_dir}/flux_vae_bf16.mlir") + module_f32.save_mlir(f"{self._temp_dir}/flux_vae_f32.mlir") extra_args = [ "--iree-hal-target-backends=rocm", @@ -290,22 +294,22 @@ def testVaeIreeVsHuggingFace(self): ] iree.compiler.compile_file( - "flux_vae/vae_bf16.mlir", - output_file="flux_vae/vae_bf16.vmfb", + f"{self._temp_dir}/flux_vae_bf16.mlir", + output_file=f"{self._temp_dir}/flux_vae_bf16.vmfb", extra_args=extra_args, ) iree.compiler.compile_file( - "flux_vae/vae_f32.mlir", - output_file="flux_vae/vae_f32.vmfb", + f"{self._temp_dir}/flux_vae_f32.mlir", + output_file=f"{self._temp_dir}/flux_vae_f32.vmfb", extra_args=extra_args, ) iree_devices = get_iree_devices(driver="hip", device_count=1) iree_module, iree_vm_context, iree_vm_instance = load_iree_module( - module_path="flux_vae/vae_bf16.vmfb", + module_path=f"{self._temp_dir}/flux_vae_bf16.vmfb", devices=iree_devices, - parameters_path="flux_vae/vae_bf16.irpa", + parameters_path=f"{self._temp_dir}/flux_vae_bf16.irpa", ) input_args = OrderedDict([("inputs", inputs)]) @@ -328,9 +332,9 @@ def testVaeIreeVsHuggingFace(self): torch.testing.assert_close(ref_results, iree_result, atol=3.3e-2, rtol=4e5) iree_module, 
iree_vm_context, iree_vm_instance = load_iree_module( - module_path="flux_vae/vae_f32.vmfb", + module_path=f"{self._temp_dir}/flux_vae_f32.vmfb", devices=iree_devices, - parameters_path="flux_vae/vae_f32.irpa", + parameters_path=f"{self._temp_dir}/flux_vae_f32.irpa", ) input_args = OrderedDict([("inputs", inputs)])