-
-
Notifications
You must be signed in to change notification settings - Fork 143
/
globals.py
136 lines (105 loc) · 3.61 KB
/
globals.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import struct
from enum import Enum
import aiohttp
from typing import List, Union, Any, Optional
from PIL import Image, ImageOps
from io import BytesIO
from pydantic import BaseModel as PydanticBaseModel
class BaseModel(PydanticBaseModel):
    """Project-wide pydantic base model.

    Enables ``arbitrary_types_allowed`` so that subclasses may declare
    fields whose types pydantic has no built-in validator for (for
    example ``PIL.Image.Image`` in ``StreamingPrompt.inputs``).
    """

    class Config:
        arbitrary_types_allowed = True
class Status(Enum):
    """Lifecycle states of a prompt; each value is its string form."""

    NOT_STARTED = "not-started"
    RUNNING = "running"
    SUCCESS = "success"
    FAILED = "failed"
    UPLOADING = "uploading"
class StreamingPrompt(BaseModel):
    """Bookkeeping record for a streaming prompt session.

    NOTE(review): the field descriptions below are inferred from the
    field names and types only; confirm against the code that creates
    and consumes these records.
    """

    # Workflow definition in API form (untyped here).
    workflow_api: Any
    auth_token: str
    # Named inputs; a value may be text, raw bytes or a PIL image. The
    # PIL type is accepted because BaseModel allows arbitrary types.
    inputs: dict[str, Union[str, bytes, Image.Image]]
    # presumably the ids of prompts currently executing for this session
    # -- TODO confirm. pydantic is expected to copy this default for each
    # instance rather than share it; verify for the installed version.
    running_prompt_ids: set[str] = set()
    status_endpoint: Optional[str]
    file_upload_endpoint: Optional[str]
    workflow: Any
    gpu_event_id: Optional[str] = None
class SimplePrompt(BaseModel):
    """Bookkeeping record for a single (non-streaming) prompt.

    NOTE(review): the field descriptions below are inferred from the
    field names and types only; confirm against the code that creates
    and consumes these records.
    """

    status_endpoint: Optional[str]
    file_upload_endpoint: Optional[str]
    token: Optional[str]
    workflow_api: dict

    # Mutable run state; a fresh record starts as NOT_STARTED.
    status: Status = Status.NOT_STARTED
    # presumably the set of nodes that have reported progress -- TODO confirm
    progress: set = set()
    last_updated_node: Optional[str] = None
    # presumably the nodes whose outputs are still being uploaded -- TODO confirm
    uploading_nodes: set = set()
    done: bool = False
    is_realtime: bool = False
    # NOTE(review): float timestamp; clock source and unit are not visible here.
    start_time: Optional[float] = None
    gpu_event_id: Optional[str] = None
# Connection objects keyed by socket id; send_bytes looks entries up by
# ``sid`` and calls ``.send_bytes(...)`` on the stored value.
sockets = {}

# Module-level registries of prompt records.
# presumably keyed by prompt id -- verify against the callers.
prompt_metadata: dict[str, SimplePrompt] = {}
streaming_prompt_metadata: dict[str, StreamingPrompt] = {}
class BinaryEventTypes:
    """Integer event codes for binary socket messages.

    ``encode_bytes`` packs the chosen code as the leading 4-byte header
    of every binary message.
    """

    PREVIEW_IMAGE = 1
    UNENCODED_PREVIEW_IMAGE = 2
# Fixed width of the output-id field that send_image writes ahead of the
# image data: ids are truncated or NUL-padded to exactly this many characters
# (one byte each after ASCII encoding).
max_output_id_length = 24
async def send_image(image_data, sid=None, output_id: str = None):
    """Encode a preview image and send it over the socket(s).

    The payload handed to ``send_bytes`` is laid out as:

    * 4 bytes: big-endian unsigned int image type
      (1 = JPEG, 2 = PNG, 3 = WEBP; any other format name maps to 1);
    * ``max_output_id_length`` bytes: ``output_id`` encoded as ASCII
      (non-ASCII characters become ``?``), truncated or NUL-padded to
      the fixed width;
    * the rest: the image as written by ``image.save``.

    Args:
        image_data: Indexable holding, in order, the format name, the
            PIL image, the maximum edge size (or ``None`` to skip
            resizing) and the save quality.
        sid: Socket id to send to; ``None`` sends to every registered
            socket (see ``send_bytes``).
        output_id: Identifier embedded in the payload. ``None`` (the
            default) is treated as an empty id, i.e. an all-NUL field.
    """
    max_length = max_output_id_length
    # The default of None used to crash on the slice below; treat it as
    # an empty id so the declared default is actually usable.
    if output_id is None:
        output_id = ""
    padded_output_id = output_id[:max_length].ljust(max_length, "\x00")
    encoded_output_id = padded_output_id.encode("ascii", "replace")

    image_type = image_data[0]
    image = image_data[1]
    max_size = image_data[2]
    quality = image_data[3]

    if max_size is not None:
        # Newer Pillow exposes filters under Image.Resampling; older
        # releases only have the module-level constants.
        if hasattr(Image, "Resampling"):
            resampling = Image.Resampling.BILINEAR
        else:
            resampling = Image.ANTIALIAS
        image = ImageOps.contain(image, (max_size, max_size), resampling)

    # Unknown format names fall back to the JPEG code, as before.
    type_num = {"JPEG": 1, "PNG": 2, "WEBP": 3}.get(image_type, 1)

    bytesIO = BytesIO()
    # 4 bytes for the type
    bytesIO.write(struct.pack(">I", type_num))
    # max_output_id_length bytes for the output_id
    bytesIO.write(encoded_output_id)
    print(f"Bytes written: {len(encoded_output_id)}")

    image.save(bytesIO, format=image_type, quality=quality, compress_level=1)
    preview_bytes = bytesIO.getvalue()
    await send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
async def send_socket_catch_exception(function, message):
    """Await ``function(message)`` and report connection errors.

    aiohttp client errors and connection resets are printed and
    swallowed so a failed send does not propagate to the caller; any
    other exception still propagates.

    Args:
        function: Async callable taking the message as its only argument.
        message: Payload passed through to ``function`` unchanged.
    """
    try:
        await function(message)
    except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as exc:
        # Best-effort delivery: log and carry on.
        print("send error:", exc)
def encode_bytes(event, data):
    """Build a binary message: 4-byte big-endian event code, then ``data``.

    Args:
        event: Integer event code (see ``BinaryEventTypes``).
        data: Payload appended after the header via ``bytearray.extend``.

    Returns:
        A new ``bytearray`` holding header plus payload.

    Raises:
        RuntimeError: If ``event`` is not an ``int``.
    """
    if isinstance(event, int):
        frame = bytearray(struct.pack(">I", event))
        frame.extend(data)
        return frame
    raise RuntimeError(f"Binary event types must be integers, got {event}")
async def send_bytes(event, data, sid=None):
    """Send an encoded binary event to one socket or to all of them.

    Args:
        event: Integer event code, packed by ``encode_bytes``.
        data: Payload bytes following the event header.
        sid: Key into ``sockets``. ``None`` sends to every registered
            socket; an unknown id sends nothing.
    """
    message = encode_bytes(event, data)
    print("sending image to ", event, sid)

    if sid is None:
        # Snapshot the values so the registry can change while we await.
        targets = list(sockets.values())
    elif sid in sockets:
        targets = [sockets[sid]]
    else:
        targets = []

    for ws in targets:
        await send_socket_catch_exception(ws.send_bytes, message)