'''
RunPod | serverless-ckpt-template | runpod_infer.py
Entry point for job requests from RunPod serverless platform.
'''
import os
import sd_runner
import runpod
from runpod.serverless.utils import rp_download, rp_cleanup, rp_upload
from runpod.serverless.utils.rp_validator import validate
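
# Instantiate and warm up the model once at import time (worker cold start)
# so every job handled by this worker reuses the same loaded weights.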
MODEL_RUNNER = sd_runner.Predictor()
MODEL_RUNNER.setup()

INPUT_SCHEMA = {
    'prompt': {
        'type': str,
        'required': True
    },
    'negative_prompt': {
        'type': str,
        'required': False,
        'default': None
    },
    'width': {
        'type': int,
        'required': False,
        'default': 768,
        'constraints': lambda width: width in [128, 256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024]
    },
    'height': {
        'type': int,
        'required': False,
        'default': 768,
        'constraints': lambda height: height in [128, 256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024]
    },
    'num_outputs': {
        'type': int,
        'required': False,
        'default': 1,
        'constraints': lambda num_outputs: num_outputs in range(1, 4)
    },
    'num_inference_steps': {
        'type': int,
        'required': False,
        'default': 50,
        'constraints': lambda num_inference_steps: num_inference_steps in range(1, 500)
    },
    'guidance_scale': {
        'type': float,
        'required': False,
        'default': 7.5,
        'constraints': lambda guidance_scale: 0 <= guidance_scale <= 20
    },
    'scheduler': {
        'type': str,
        'required': False,
        'default': 'DPMSolverMultistep',
        'constraints': lambda scheduler: scheduler in ['DDIM', 'K_EULER', 'DPMSolverMultistep', 'K_EULER_ANCESTRAL', 'PNDM', 'KLMS']
    },
    'seed': {
        'type': int,
        'required': False,
        'default': int.from_bytes(os.urandom(2), "big")
    }
}
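
# NOTE: the default seed above is drawn once at import time, so jobs that omit
# 'seed' share the same default for the lifetime of this worker.
#
# For illustration only (not part of the original template), a request body
# that satisfies INPUT_SCHEMA; only 'prompt' is required, everything else
# falls back to the defaults declared above:
#
# {
#     "input": {
#         "prompt": "a photo of an astronaut riding a horse on mars",
#         "negative_prompt": "blurry, low quality",
#         "width": 512,
#         "height": 512,
#         "num_outputs": 1,
#         "num_inference_steps": 50,
#         "guidance_scale": 7.5,
#         "scheduler": "K_EULER",
#         "seed": 42
#     }
# }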


def handler(job):
    '''
    Takes in raw data from the API call, prepares it for the model.
    Passes the data to the model to get the results.
    Prepares the resulting output to be returned to the API call.
    '''
    job_input = job['input']
    job_output = []

    # -------------------------------- Validation -------------------------------- #
    validated_input = validate(job_input, INPUT_SCHEMA)

    if 'errors' in validated_input:
        return {"errors": validated_input['errors']}
    valid_input = validated_input['validated_input']

    image_paths = MODEL_RUNNER.predict(
        prompt=valid_input['prompt'],
        negative_prompt=valid_input['negative_prompt'],
        width=valid_input['width'],
        height=valid_input['height'],
        num_outputs=valid_input['num_outputs'],
        num_inference_steps=valid_input['num_inference_steps'],
        guidance_scale=valid_input['guidance_scale'],
        scheduler=valid_input['scheduler'],
        seed=valid_input['seed']
    )

    for index, img_path in enumerate(image_paths):
        image_url = rp_upload.upload_image(job['id'], img_path)

        # Use the validated input (with defaults filled in) rather than the raw
        # job input, so optional fields the caller omitted are still present.
        job_output.append({
            "image": image_url,
            "prompt": valid_input['prompt'],
            "negative_prompt": valid_input['negative_prompt'],
            "width": valid_input['width'],
            "height": valid_input['height'],
            "num_inference_steps": valid_input['num_inference_steps'],
            "guidance_scale": valid_input['guidance_scale'],
            "scheduler": valid_input['scheduler'],
            "seed": valid_input['seed'] + index
        })

    # Remove downloaded input objects
    # rp_cleanup.clean(['input_objects'])

    return job_output
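
# Illustrative shape of the list returned by handler(), one entry per image
# (the URL comes from rp_upload.upload_image; values here are placeholders):
# [
#     {
#         "image": "<uploaded image URL>",
#         "prompt": "a photo of an astronaut riding a horse on mars",
#         "negative_prompt": None,
#         "width": 768,
#         "height": 768,
#         "num_inference_steps": 50,
#         "guidance_scale": 7.5,
#         "scheduler": "DPMSolverMultistep",
#         "seed": 1234
#     }
# ]
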
# ---------------------------------------------------------------------------- #
# Start the Worker #
# ---------------------------------------------------------------------------- #
runpod.serverless.start({"handler": handler})
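
# Local smoke test (a sketch; assumes the installed runpod SDK supports the
# --test_input flag for running the handler once without the RunPod queue):
#
#   python runpod_infer.py --test_input '{"input": {"prompt": "a red bicycle"}}'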