-
Notifications
You must be signed in to change notification settings - Fork 2
/
turbo.py
82 lines (67 loc) · 2.48 KB
/
turbo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from fastapi import Response, Request
from modal import Image, Stub, gpu, web_endpoint, build, enter
# Container image for the GPU worker: a slim Debian base plus the inference
# stack, each package pinned with a "compatible release" (~=) specifier.
inference_image = (
    Image.debian_slim()
    .pip_install(
        "Pillow~=10.1.0",
        "diffusers~=0.24.0",
        "transformers~=4.35.2",
        "accelerate~=0.25",
        "safetensors~=0.4.1",
    )
)
# NOTE(review): `torch` and `huggingface_hub` are imported inside the container
# but are not listed above; presumably they arrive as transitive dependencies
# of these packages — consider pinning them explicitly for reproducible builds.

# Modal app object, created with `inference_image` as its image.
stub = Stub("stable-diffusion-xl-turbo", image=inference_image)
# Imports used by the container-side code below. Modal's `Image.imports()`
# guard (per Modal's docs) lets this module be imported locally even when these
# packages are not installed there.
# NOTE(review): if one import in this block fails, the statements after it are
# skipped, so locally `BytesIO`/`base64` may be undefined as well.
with inference_image.imports():
    import torch
    from diffusers import AutoencoderKL, AutoPipelineForImage2Image
    from diffusers.utils import load_image
    from huggingface_hub import snapshot_download
    from PIL import Image as PILImage  # aliased: `Image` is already modal.Image
    from io import BytesIO
    import base64
@stub.cls(gpu=gpu.T4(), container_idle_timeout=240)
class Model:
    """SDXL-Turbo image-to-image model served as a Modal web endpoint.

    Runs on a T4 GPU. Weights are downloaded into the image at build time
    (``download_models``), the pipeline is loaded once per container
    (``enter``), and each POST to ``inference`` performs one generation.
    """

    # Fixed seed so that identical requests produce identical images.
    SEED = 42
    # Square resolution every input image is resized to before inference.
    IMAGE_SIZE = (512, 512)

    @build()
    def download_models(self):
        """Pre-fetch both model repositories with ``snapshot_download``.

        Runs at image build time (``@build``) so containers do not download
        weights on startup. ``*.bin`` and ``*.onnx_data`` files, and any
        ``diffusion_pytorch_model.safetensors`` inside a subfolder, are
        skipped; presumably only the fp16-variant weights requested in
        ``enter`` are needed — re-check if the loading code changes.
        """
        ignore = [
            "*.bin",
            "*.onnx_data",
            "*/diffusion_pytorch_model.safetensors",
        ]
        snapshot_download("stabilityai/sdxl-turbo", ignore_patterns=ignore)
        snapshot_download("madebyollin/sdxl-vae-fp16-fix", ignore_patterns=ignore)

    # NOTE: this method shares its name with Modal's ``enter`` decorator. The
    # decorator expression is evaluated before the class-body name is rebound,
    # so ``@enter()`` still refers to the imported decorator.
    @enter()
    def enter(self):
        """Load the fp16 SDXL-Turbo img2img pipeline into ``self.pipe``."""
        self.pipe = AutoPipelineForImage2Image.from_pretrained(
            "stabilityai/sdxl-turbo",
            torch_dtype=torch.float16,
            variant="fp16",
            device_map="auto",
            # Replacement VAE; presumably used because the stock SDXL VAE is
            # not stable in float16 (hence the "fp16-fix" repo) — confirm.
            vae=AutoencoderKL.from_pretrained(
                "madebyollin/sdxl-vae-fp16-fix",
                torch_dtype=torch.float16,
                device_map="auto",
            ),
        )

    @staticmethod
    def _bad_request(message):
        """Return a plain-text HTTP 400 response carrying ``message``."""
        return Response(content=message, status_code=400, media_type="text/plain")

    @classmethod
    def _decode_image(cls, image):
        """Decode a base64 image string into a PIL image of ``IMAGE_SIZE``.

        Everything up to the last comma is dropped, so data URLs
        (``data:image/png;base64,...``) are accepted as well as bare base64.

        Raises:
            ValueError: ``image`` is not valid base64 (``binascii.Error``).
            OSError: the decoded bytes are not a readable image
                (``PIL.UnidentifiedImageError``, truncated-file errors).
        """
        raw = base64.b64decode(image.split(",")[-1])
        pil_image = PILImage.open(BytesIO(raw))
        # diffusers' load_image normalises the PIL image (in 0.24: EXIF
        # orientation + RGB conversion) before it is resized.
        return load_image(pil_image).resize(cls.IMAGE_SIZE)

    @web_endpoint(method="POST")
    async def inference(self, request: Request):
        """Generate an image from an input image and a text prompt.

        JSON body fields:
            image: base64-encoded input image, optionally as a data URL.
            num_iterations: number of inference steps (int or numeric string).
            prompt: text prompt.

        Returns:
            The generated image as an ``image/jpeg`` response, or a plain-text
            400 response when the body is malformed or the step count is too
            small for the chosen strength.
        """
        try:
            data = await request.json()
        except ValueError:  # json.JSONDecodeError and bad UTF-8 both subclass it
            return self._bad_request("request body must be valid JSON")
        if not isinstance(data, dict):
            return self._bad_request("request body must be a JSON object")

        image = data.get("image")
        prompt = data.get("prompt")
        if not isinstance(image, str) or not image:
            return self._bad_request("'image' must be a non-empty base64 string")
        if prompt is None:
            return self._bad_request("'prompt' is required")
        try:
            num_inference_steps = int(data.get("num_iterations"))
        except (TypeError, ValueError):
            return self._bad_request("'num_iterations' must be an integer")

        # Pick the strength from the *parsed* step count, not the raw JSON
        # value, so that "2" and 2 behave the same.
        strength = 0.999 if num_inference_steps == 2 else 0.75
        # Per the diffusers SDXL-Turbo img2img guidance, steps * strength must
        # be >= 1 or no denoising step runs. This is client input, so reject it
        # with a 400 instead of an `assert` (which `python -O` strips).
        if num_inference_steps * strength < 1:
            return self._bad_request(
                "'num_iterations' is too small: num_iterations * strength "
                f"({num_inference_steps} * {strength}) must be >= 1"
            )

        try:
            init_image = self._decode_image(image)
        except (ValueError, OSError):
            return self._bad_request("'image' is not a valid base64-encoded image")

        # The pipeline call takes no `seed` argument (passing `seed=` seeds
        # nothing), so reproducibility comes from an explicitly seeded
        # generator; a CPU generator is what the diffusers reproducibility
        # guide recommends.
        generator = torch.Generator("cpu").manual_seed(self.SEED)
        # NOTE: this call is synchronous and blocks the event loop while the
        # model runs.
        output = self.pipe(
            prompt,
            image=init_image,
            num_inference_steps=num_inference_steps,
            strength=strength,
            # SDXL-Turbo is meant to run without classifier-free guidance.
            guidance_scale=0.0,
            generator=generator,
        ).images[0]

        buffer = BytesIO()
        output.save(buffer, format="jpeg")
        return Response(content=buffer.getvalue(), media_type="image/jpeg")