Skip to content

Commit

Permalink
deprecate _img_open in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Laurent2916 committed Oct 15, 2024
1 parent 241abfa commit d48e1df
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 108 deletions.
126 changes: 61 additions & 65 deletions tests/e2e/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,6 @@
from ..weight_paths import get_path


def _img_open(path: Path) -> Image.Image:
return Image.open(path) # type: ignore


@pytest.fixture(autouse=True)
def ensure_gc():
# Avoid GPU OOMs
Expand All @@ -68,132 +64,132 @@ def ref_path(test_e2e_path: Path) -> Path:

@pytest.fixture(scope="module")
def cutecat_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "cutecat_init.png").convert("RGB")
return Image.open(ref_path / "cutecat_init.png").convert("RGB")


@pytest.fixture(scope="module")
def kitchen_dog(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "kitchen_dog.png").convert("RGB")
return Image.open(ref_path / "kitchen_dog.png").convert("RGB")


@pytest.fixture(scope="module")
def kitchen_dog_mask(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "kitchen_dog_mask.png").convert("RGB")
return Image.open(ref_path / "kitchen_dog_mask.png").convert("RGB")


@pytest.fixture(scope="module")
def woman_image(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "woman.png").convert("RGB")
return Image.open(ref_path / "woman.png").convert("RGB")


@pytest.fixture(scope="module")
def statue_image(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "statue.png").convert("RGB")
return Image.open(ref_path / "statue.png").convert("RGB")


@pytest.fixture
def expected_image_std_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_random_init.png").convert("RGB")
return Image.open(ref_path / "expected_std_random_init.png").convert("RGB")


@pytest.fixture
def expected_image_std_random_init_bfloat16(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_random_init_bfloat16.png").convert("RGB")
return Image.open(ref_path / "expected_std_random_init_bfloat16.png").convert("RGB")


@pytest.fixture
def expected_image_std_sde_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_sde_random_init.png").convert("RGB")
return Image.open(ref_path / "expected_std_sde_random_init.png").convert("RGB")


@pytest.fixture
def expected_image_std_sde_karras_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_sde_karras_random_init.png").convert("RGB")
return Image.open(ref_path / "expected_std_sde_karras_random_init.png").convert("RGB")


@pytest.fixture
def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_random_init_euler.png").convert("RGB")
return Image.open(ref_path / "expected_std_random_init_euler.png").convert("RGB")


@pytest.fixture
def expected_karras_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_karras_random_init.png").convert("RGB")
return Image.open(ref_path / "expected_karras_random_init.png").convert("RGB")


@pytest.fixture
def expected_image_std_random_init_sag(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_random_init_sag.png").convert("RGB")
return Image.open(ref_path / "expected_std_random_init_sag.png").convert("RGB")


@pytest.fixture
def expected_image_std_init_image(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_init_image.png").convert("RGB")
return Image.open(ref_path / "expected_std_init_image.png").convert("RGB")


@pytest.fixture
def expected_image_ella_adapter(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_image_ella_adapter.png").convert("RGB")
return Image.open(ref_path / "expected_image_ella_adapter.png").convert("RGB")


@pytest.fixture
def expected_image_std_inpainting(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_std_inpainting.png").convert("RGB")
return Image.open(ref_path / "expected_std_inpainting.png").convert("RGB")


@pytest.fixture
def expected_image_controlnet_stack(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_controlnet_stack.png").convert("RGB")
return Image.open(ref_path / "expected_controlnet_stack.png").convert("RGB")


@pytest.fixture
def expected_image_ip_adapter_woman(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")
return Image.open(ref_path / "expected_image_ip_adapter_woman.png").convert("RGB")


@pytest.fixture
def expected_image_ip_adapter_multi(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_image_ip_adapter_multi.png").convert("RGB")
return Image.open(ref_path / "expected_image_ip_adapter_multi.png").convert("RGB")


@pytest.fixture
def expected_image_ip_adapter_plus_statue(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_image_ip_adapter_plus_statue.png").convert("RGB")
return Image.open(ref_path / "expected_image_ip_adapter_plus_statue.png").convert("RGB")


@pytest.fixture
def expected_image_sdxl_ip_adapter_woman(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_image_sdxl_ip_adapter_woman.png").convert("RGB")
return Image.open(ref_path / "expected_image_sdxl_ip_adapter_woman.png").convert("RGB")


@pytest.fixture
def expected_image_sdxl_ip_adapter_plus_woman(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_image_sdxl_ip_adapter_plus_woman.png").convert("RGB")
return Image.open(ref_path / "expected_image_sdxl_ip_adapter_plus_woman.png").convert("RGB")


@pytest.fixture
def expected_image_ip_adapter_controlnet(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_ip_adapter_controlnet.png").convert("RGB")
return Image.open(ref_path / "expected_ip_adapter_controlnet.png").convert("RGB")


@pytest.fixture
def expected_sdxl_ddim_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert("RGB")
return Image.open(ref_path / "expected_cutecat_sdxl_ddim_random_init.png").convert("RGB")


@pytest.fixture
def expected_sdxl_ddim_random_init_sag(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_cutecat_sdxl_ddim_random_init_sag.png").convert("RGB")
return Image.open(ref_path / "expected_cutecat_sdxl_ddim_random_init_sag.png").convert("RGB")


@pytest.fixture
def expected_sdxl_euler_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_cutecat_sdxl_euler_random_init.png").convert("RGB")
return Image.open(ref_path / "expected_cutecat_sdxl_euler_random_init.png").convert("RGB")


@pytest.fixture
def expected_style_aligned(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_style_aligned.png").convert(mode="RGB")
return Image.open(ref_path / "expected_style_aligned.png").convert(mode="RGB")


@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
Expand All @@ -207,8 +203,8 @@ def controlnet_data(
request: pytest.FixtureRequest,
) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]:
cn_name: str = request.param
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")

weights_fn = {
"depth": controlnet_depth_weights_path,
Expand All @@ -229,8 +225,8 @@ def controlnet_data_scale_decay(
request: pytest.FixtureRequest,
) -> Iterator[tuple[str, Image.Image, Image.Image, Path]]:
cn_name: str = request.param
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}_scale_decay.png").convert("RGB")
condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}_scale_decay.png").convert("RGB")

weights_fn = {
"canny": controlnet_canny_weights_path,
Expand All @@ -245,8 +241,8 @@ def controlnet_data_tile(
ref_path: Path,
controlnet_tiles_weights_path: Path,
) -> tuple[Image.Image, Image.Image, Path]:
condition_image = _img_open(ref_path / f"low_res_dog.png").convert("RGB").resize((1024, 1024)) # type: ignore
expected_image = _img_open(ref_path / f"expected_controlnet_tile.png").convert("RGB")
condition_image = Image.open(ref_path / f"low_res_dog.png").convert("RGB").resize((1024, 1024)) # type: ignore
expected_image = Image.open(ref_path / f"expected_controlnet_tile.png").convert("RGB")
return condition_image, expected_image, controlnet_tiles_weights_path


Expand All @@ -256,8 +252,8 @@ def controlnet_data_canny(
controlnet_canny_weights_path: Path,
) -> tuple[str, Image.Image, Image.Image, Path]:
cn_name = "canny"
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
return cn_name, condition_image, expected_image, controlnet_canny_weights_path


Expand All @@ -267,8 +263,8 @@ def controlnet_data_depth(
controlnet_depth_weights_path: Path,
) -> tuple[str, Image.Image, Image.Image, Path]:
cn_name = "depth"
condition_image = _img_open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
expected_image = _img_open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
return cn_name, condition_image, expected_image, controlnet_depth_weights_path


Expand Down Expand Up @@ -336,12 +332,12 @@ def controllora_sdxl_config(
) -> tuple[Image.Image, dict[str, ControlLoraResolvedConfig]]:
name: str = request.param[0]
configs: dict[str, ControlLoraConfig] = request.param[1]
expected_image = _img_open(ref_path / name).convert("RGB")
expected_image = Image.open(ref_path / name).convert("RGB")

loaded_configs = {
config_name: ControlLoraResolvedConfig(
scale=config.scale,
condition_image=_img_open(ref_path / config.condition_path).convert("RGB"),
condition_image=Image.open(ref_path / config.condition_path).convert("RGB"),
weights_path=get_path(config.weights, use_local_weights),
)
for config_name, config in configs.items()
Expand All @@ -356,8 +352,8 @@ def t2i_adapter_data_depth(
t2i_depth_weights_path: Path,
) -> tuple[str, Image.Image, Image.Image, Path]:
name = "depth"
condition_image = _img_open(ref_path / f"cutecat_guide_{name}.png").convert("RGB")
expected_image = _img_open(ref_path / f"expected_t2i_adapter_{name}.png").convert("RGB")
condition_image = Image.open(ref_path / f"cutecat_guide_{name}.png").convert("RGB")
expected_image = Image.open(ref_path / f"expected_t2i_adapter_{name}.png").convert("RGB")
return name, condition_image, expected_image, t2i_depth_weights_path


Expand All @@ -367,8 +363,8 @@ def t2i_adapter_xl_data_canny(
t2i_sdxl_canny_weights_path: Path,
) -> tuple[str, Image.Image, Image.Image, Path]:
name = "canny"
condition_image = _img_open(ref_path / f"fairy_guide_{name}.png").convert("RGB")
expected_image = _img_open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB")
condition_image = Image.open(ref_path / f"fairy_guide_{name}.png").convert("RGB")
expected_image = Image.open(ref_path / f"expected_t2i_adapter_xl_{name}.png").convert("RGB")
return name, condition_image, expected_image, t2i_sdxl_canny_weights_path


Expand All @@ -377,7 +373,7 @@ def lora_data_pokemon(
ref_path: Path,
lora_pokemon_weights_path: Path,
) -> tuple[Image.Image, dict[str, torch.Tensor]]:
expected_image = _img_open(ref_path / "expected_lora_pokemon.png").convert("RGB")
expected_image = Image.open(ref_path / "expected_lora_pokemon.png").convert("RGB")
tensors = load_tensors(lora_pokemon_weights_path)
return expected_image, tensors

Expand All @@ -387,7 +383,7 @@ def lora_data_dpo(
ref_path: Path,
lora_dpo_weights_path: Path,
) -> tuple[Image.Image, dict[str, torch.Tensor]]:
expected_image = _img_open(ref_path / "expected_sdxl_dpo_lora.png").convert("RGB")
expected_image = Image.open(ref_path / "expected_sdxl_dpo_lora.png").convert("RGB")
tensors = load_from_safetensors(lora_dpo_weights_path)
return expected_image, tensors

Expand All @@ -411,62 +407,62 @@ def lora_sliders(

@pytest.fixture
def scene_image_inpainting_refonly(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "inpainting-scene.png").convert("RGB")
return Image.open(ref_path / "inpainting-scene.png").convert("RGB")


@pytest.fixture
def mask_image_inpainting_refonly(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "inpainting-mask.png").convert("RGB")
return Image.open(ref_path / "inpainting-mask.png").convert("RGB")


@pytest.fixture
def target_image_inpainting_refonly(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "inpainting-target.png").convert("RGB")
return Image.open(ref_path / "inpainting-target.png").convert("RGB")


@pytest.fixture
def expected_image_inpainting_refonly(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_inpainting_refonly.png").convert("RGB")
return Image.open(ref_path / "expected_inpainting_refonly.png").convert("RGB")


@pytest.fixture
def expected_image_refonly(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_refonly.png").convert("RGB")
return Image.open(ref_path / "expected_refonly.png").convert("RGB")


@pytest.fixture
def condition_image_refonly(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "cyberpunk_guide.png").convert("RGB")
return Image.open(ref_path / "cyberpunk_guide.png").convert("RGB")


@pytest.fixture
def expected_image_textual_inversion_random_init(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_textual_inversion_random_init.png").convert("RGB")
return Image.open(ref_path / "expected_textual_inversion_random_init.png").convert("RGB")


@pytest.fixture
def expected_multi_diffusion(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_multi_diffusion.png").convert(mode="RGB")
return Image.open(ref_path / "expected_multi_diffusion.png").convert(mode="RGB")


@pytest.fixture
def expected_multi_diffusion_dpm(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_multi_diffusion_dpm.png").convert(mode="RGB")
return Image.open(ref_path / "expected_multi_diffusion_dpm.png").convert(mode="RGB")


@pytest.fixture
def expected_restart(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_restart.png").convert(mode="RGB")
return Image.open(ref_path / "expected_restart.png").convert(mode="RGB")


@pytest.fixture
def expected_freeu(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_freeu.png").convert(mode="RGB")
return Image.open(ref_path / "expected_freeu.png").convert(mode="RGB")


@pytest.fixture
def expected_sdxl_multi_loras(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_sdxl_multi_loras.png").convert(mode="RGB")
return Image.open(ref_path / "expected_sdxl_multi_loras.png").convert(mode="RGB")


@pytest.fixture
Expand All @@ -476,10 +472,10 @@ def hello_world_assets(ref_path: Path) -> tuple[Image.Image, Image.Image, Image.
image_prompt = assets / "dragon_quest_slime.jpg"
condition_image = assets / "dropy_canny.png"
return (
_img_open(dropy).convert(mode="RGB"),
_img_open(image_prompt).convert(mode="RGB"),
_img_open(condition_image).convert(mode="RGB"),
_img_open(ref_path / "expected_dropy_slime_9752.png").convert(mode="RGB"),
Image.open(dropy).convert(mode="RGB"),
Image.open(image_prompt).convert(mode="RGB"),
Image.open(condition_image).convert(mode="RGB"),
Image.open(ref_path / "expected_dropy_slime_9752.png").convert(mode="RGB"),
)


Expand Down Expand Up @@ -2639,7 +2635,7 @@ def test_multi_upscaler_small(

@pytest.fixture(scope="module")
def expected_ic_light(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_ic_light.png").convert("RGB")
return Image.open(ref_path / "expected_ic_light.png").convert("RGB")


@pytest.fixture(scope="module")
Expand Down
Loading

0 comments on commit d48e1df

Please sign in to comment.