test_lora.py
from t2i_adapters import patch_pipe, Adapter, sketch_extracter
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
import torch
from PIL import Image
import numpy as np

if __name__ == "__main__":
    device = "cuda:0"

    # 0. Define model
    model_id = "Linaqruf/anything-v3.0"
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16
    ).to(device)

    # Patch the pipeline's UNet so it can consume T2I-Adapter features
    # via set_adapter_features() below.
    patch_pipe(pipe)

    from lora_diffusion import LoRAManager, image_grid

    # Load the LoRA and set its scale (one value per loaded LoRA).
    manager = LoRAManager(["./contents/lora_krk.safetensors"], pipe)
    manager.tune([0.9])

    for ext_type, prompt in [("keypose", "a photo of <s0-0><s0-1> sitting down")]:
        # 1. Define Adapter feature extractor
        adapter = Adapter.from_pretrained(ext_type).to(device)

        # 2. Prepare condition via adapter.
        cond_img_src = Image.open(f"./contents/examples/{ext_type}_1.png")

        if ext_type == "sketch":
            # Grayscale, normalized, binarized to a 1x1xHxW tensor.
            cond_img = cond_img_src.convert("L")
            cond_img = np.array(cond_img) / 255.0
            cond_img = torch.from_numpy(cond_img).unsqueeze(0).unsqueeze(0).to(device)
            cond_img = (cond_img > 0.5).float()
        else:
            # RGB, normalized to a 1x3xHxW tensor.
            cond_img = cond_img_src.convert("RGB")
            cond_img = np.array(cond_img) / 255.0
            cond_img = (
                torch.from_numpy(cond_img)
                .permute(2, 0, 1)
                .unsqueeze(0)
                .to(device)
                .float()
            )

        with torch.no_grad():
            adapter_features = adapter(cond_img)

        pipe.unet.set_adapter_features(adapter_features)
        pipe.safety_checker = None  # disabled for this test script

        neg_prompt = "out of frame, duplicate, watermark "

        torch.manual_seed(1)
        n = 1
        imgs = pipe(
            [prompt] * n,
            negative_prompt=[neg_prompt] * n,
            num_inference_steps=50,
            guidance_scale=4.5,
            # Match the output resolution to the condition image.
            height=cond_img.shape[2],
            width=cond_img.shape[3],
        ).images
        out_imgs = imgs[0]

        # Save condition and result side by side.
        image_grid([cond_img_src, out_imgs], 1, 2).save(
            f"./contents/{ext_type}_lora.jpg"
        )