remote_server.py
import io
import cv2
import numpy as np
from PIL import Image
from argparse import ArgumentParser
from fastapi import FastAPI, WebSocket
from fastapi.websockets import WebSocketDisconnect
from demo_utils import FaceAnimationClass

parser = ArgumentParser()
parser.add_argument("--source_image", default="./assets/source.jpg", help="path to source image")
parser.add_argument("--restore_face", default="False", type=str, help="restore the face with super-resolution: 'True' or 'False'")
args = parser.parse_args()
if args.restore_face not in ("True", "False"):
    exit("restore_face must be 'True' or 'False'")
restore_face = args.restore_face == "True"

faceanimation = FaceAnimationClass(source_image_path=args.source_image, use_sr=restore_face)
# The remote server runs at a lower fps than the local camera, so detect faces more
# frequently and use a stronger smoothing factor.
faceanimation.detect_interval = 2
faceanimation.smooth_factor = 0.8

app = FastAPI()
websocket_port = 8066


# WebSocket endpoint to receive and process images
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:
            # Receive the image as a binary stream
            image_data = await websocket.receive_bytes()
            processed_image = process_image(image_data)
            # Send the processed image back to the client
            await websocket.send_bytes(processed_image)
    except WebSocketDisconnect:
        pass


def process_image(image_data):
    image = Image.open(io.BytesIO(image_data))
    image_cv2 = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    face, result = faceanimation.inference(image_cv2)
    # resize to 256x256
    if face.shape[1] != 256 or face.shape[0] != 256:
        face = cv2.resize(face, (256, 256))
    if result.shape[0] != 256 or result.shape[1] != 256:
        result = cv2.resize(result, (256, 256))
    # concatenate the cropped face and the animated result side by side, then encode as JPEG
    result = cv2.hconcat([face, result])
    _, processed_image_data = cv2.imencode(".jpg", result, [cv2.IMWRITE_JPEG_QUALITY, 95])
    return processed_image_data.tobytes()


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=websocket_port)
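

# --- Usage sketch (not part of the server) ---
# A minimal client, assuming the server is reachable at ws://localhost:8066/ws and the
# third-party `websockets` package is installed; the file names "frame.jpg" and
# "result.jpg" are placeholders.
#
# import asyncio
# import websockets
#
# async def animate_frame(jpeg_bytes: bytes) -> bytes:
#     # Open a connection, send one JPEG-encoded frame, and wait for the processed frame.
#     async with websockets.connect("ws://localhost:8066/ws") as ws:
#         await ws.send(jpeg_bytes)
#         # The reply is a JPEG of the cropped face and the animated result side by side.
#         return await ws.recv()
#
# if __name__ == "__main__":
#     with open("frame.jpg", "rb") as f:
#         output = asyncio.run(animate_frame(f.read()))
#     with open("result.jpg", "wb") as f:
#         f.write(output)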