-
Notifications
You must be signed in to change notification settings - Fork 1
/
latent_colors.py
67 lines (55 loc) · 2.27 KB
/
latent_colors.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import gc
import torch
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
# VAEs to probe, keyed by the short name used in the generated COLS_<name>
# constants. A tuple value is (repo_id, subfolder); a bare string is a repo
# whose VAE weights live at the repository root.
VAES = {
    "FLUX": ("black-forest-labs/FLUX.1-dev", "vae"),
    "FTMSE": "stabilityai/sd-vae-ft-mse",
    "SD3": ("stabilityai/stable-diffusion-3-medium-diffusers", "vae"),
    "XL": "stabilityai/sdxl-vae",
}

# Solid RGB colours to encode, with channels given as floats in the 0..1 range.
# Insertion order is the order the measured results are emitted in.
CARDINALS = dict(
    black=[0.0, 0.0, 0.0],
    white=[1.0, 1.0, 1.0],
    red=[1.0, 0.0, 0.0],
    green=[0.0, 1.0, 0.0],
    blue=[0.0, 0.0, 1.0],
    cyan=[0.0, 1.0, 1.0],
    magenta=[1.0, 0.0, 1.0],
    yellow=[1.0, 1.0, 0.0],
)

# Full precision on the GPU. Making these the torch defaults means tensors
# created without an explicit dtype/device (e.g. torch.tensor(...) in measure)
# are placed there as well.
dtype = torch.float32
device = torch.device("cuda")
torch.set_default_dtype(dtype)
torch.set_default_device(device)
@torch.inference_mode()
def measure(vae: str, subfolder: str | None = None) -> dict[str, list[float]]:
    """Measure the latent value of every colour in ``CARDINALS`` for one VAE.

    Loads an ``AutoencoderKL`` and encodes one solid-colour square per colour.
    Each square has a side of ``config.sample_size`` pixels. The function then
    takes the per-channel median over all latent positions.

    Args:
        vae: Repo id (or path) passed to ``AutoencoderKL.from_pretrained``.
        subfolder: Optional subfolder of the repo that holds the VAE weights.

    Returns:
        Mapping of colour name to one value per latent channel. Each value is
        transformed as ``(latent - shift_factor) * scaling_factor``, using the
        values from the VAE config.
    """
    # Free memory held by any previously loaded VAE before loading this one.
    gc.collect()
    torch.cuda.empty_cache()

    encoder = AutoencoderKL.from_pretrained(
        vae, use_safetensors=True, torch_dtype=dtype, subfolder=subfolder
    ).to(device)
    config = encoder.config

    # Either key may be missing or explicitly null; fall back to the identity.
    factor = config.get("scaling_factor", 1)
    shift = config.get("shift_factor", 0)
    factor = 1 if factor is None else factor
    shift = 0 if shift is None else shift

    # VaeImageProcessor's first positional parameter is `do_resize`, so the
    # spatial downscale factor must be passed by keyword to take effect.
    processor = VaeImageProcessor(vae_scale_factor=2 ** (len(config.block_out_channels) - 1))
    size = config.sample_size

    cols = {}
    for name, rgb in CARDINALS.items():
        # Build an (H, W, C) solid image like a PIL/numpy array, then move the
        # channels first to get the (C, H, W) layout of a torch image tensor.
        image = torch.tensor(rgb).expand([size, size, len(rgb)]).permute(2, 0, 1).contiguous()
        # NOTE(review): sample() draws from the posterior, so values jitter
        # slightly between runs; the median over all positions damps this.
        # `.mode()` would be deterministic if that is preferred.
        latents = encoder.encode(processor.preprocess(image)).latent_dist.sample()
        # (1, C, h, w) -> (C, 1, h, w) -> (C, h*w): one row per latent channel.
        per_channel = latents.permute(1, 0, 2, 3).flatten(start_dim=1)
        cols[name] = per_channel.quantile(0.5, dim=1).sub(shift).mul(factor).tolist()
    return cols
def format_cols(name: str, cols: dict[str, list[float]]) -> str:
    """Render one VAE's measured colours as a single Python assignment line.

    Values are printed with 4 decimals: anything below 1e-6 is pure noise and
    anything below 1e-4 is unstable, so more digits would not be meaningful.
    The expressions are built separately (not as nested same-quote f-strings)
    so the module also parses on Python versions before 3.12.

    Args:
        name: Short VAE name; the line assigns ``COLS_<name>``.
        cols: Mapping of colour name to per-channel latent values.

    Returns:
        The line, ending with a ``# fmt: off`` marker so formatters leave it alone.
    """
    entries = []
    for key, values in cols.items():
        numbers = ", ".join(f"{x:.4f}" for x in values)
        entries.append(f"'{key}': [{numbers}]")
    body = ", ".join(entries)
    return f"COLS_{name} = {{{body}}} # fmt: off"


if __name__ == "__main__":
    print("# Auto-generated by latent_colors.py")
    for name, vae in VAES.items():
        # Entries are either "repo" or ("repo", "subfolder").
        if isinstance(vae, tuple):
            vae, subfolder = vae
        else:
            subfolder = None
        print(format_cols(name, measure(vae, subfolder)))