Commit
Add: Added deploy support (server and socket).
chairc committed Nov 13, 2024
1 parent 941ba44 commit 2e8cd13
Showing 5 changed files with 172 additions and 42 deletions.
5 changes: 4 additions & 1 deletion README.md
@@ -53,6 +53,9 @@ Integrated Design Diffusion Model
│ ├── class_1
│ ├── class_2
│ └── class_3
├── deploy
│ ├── deploy_socket.py
│ └── deploy_server.py
├── model
│ ├── modules
│ │ ├── activation.py
@@ -84,7 +87,6 @@ Integrated Design Diffusion Model
│ │ └── noise
│ └── test_module.py
├── tools
│ ├── deploy.py
│ ├── FID_calculator.py
│ ├── FID_calculator_plus.py
│ ├── generate.py
@@ -119,6 +121,7 @@ Integrated Design Diffusion Model
- [x] 12. Adding PLMS Sampling Method. (2024-03-12)
- [x] 13. Adding FID calculator to verify image quality. (2024-05-06)
- [x] 14. Adding the deployment of image-generating Sockets
- [x] 14. Adding the deployment of image-generating Sockets and Web server. (2024-11-13)

### Training

5 changes: 4 additions & 1 deletion README_zh.md
@@ -52,6 +52,9 @@ Integrated Design Diffusion Model
│ ├── class_1
│ ├── class_2
│ └── class_3
├── deploy
│ ├── deploy_socket.py
│ └── deploy_server.py
├── model
│ ├── modules
│ │ ├── activation.py
@@ -83,7 +86,6 @@ Integrated Design Diffusion Model
│ │ └── noise
│ └── test_module.py
├── tools
│ ├── deploy.py
│ ├── FID_calculator.py
│ ├── FID_calculator_plus.py
│ ├── generate.py
@@ -118,6 +120,7 @@ Integrated Design Diffusion Model
- [x] 12. Adding PLMS Sampling Method. (2024-03-12)
- [x] 13. Adding FID calculator to verify image quality. (2024-05-06)
- [x] 14. Adding the deployment of image-generating Sockets. (2024-11-12)
- [x] 14. Adding the deployment of image-generating Sockets and Web server. (2024-11-13)

### Training

121 changes: 121 additions & 0 deletions deploy/deploy_server.py
@@ -0,0 +1,121 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@Date : 2024/11/4 23:04
@Author : chairc
@Site : https://github.com/chairc
"""
import os
import sys
import json
import logging
import uuid

import coloredlogs

from flask import Flask, request, jsonify
from torchvision import transforms

sys.path.append(os.path.dirname(sys.path[0]))
from config.version import get_version_banner
from tools.generate import Generator, init_generate_args
from utils import save_images
from utils.processing import image_to_base64

logger = logging.getLogger(__name__)
coloredlogs.install(level="INFO")
app = Flask(__name__)


@app.route("/")
def index():
logger.info(msg="Route -> Hello IDDM")
return "Hello, IDDM!"


@app.route("/api/generate/df", methods=["POST"])
def generate_diffusion_model_api():
"""
Generate a diffusion model
"""
logger.info(msg="Route -> /api/df")

data = request.json
logger.info(msg=f"Send json: {data}")

# Sample type
sample = data["sample"]
# Image size
image_size = data["image_size"]
# Number of images
num_images = data["num_images"] if data["num_images"] >= 1 else 1
# Weight path
weight_path = data["weight_path"]
result_path = data["result_path"]
# Recommend use base64 in server app
# Return mode, base64 or url
re_type = data["type"]

logger.info(msg="[Web]: Start generation.")
# Type is url or base64
re_json = {"image": [], "type": str(re_type)}

if any(param is None for param in [sample, image_size, num_images, weight_path, result_path, re_type]):
return jsonify({"code": 400, "msg": "Illegal parameters.", "data": None}), 400

# Init args
args = init_generate_args()
args.sample = sample
args.image_size = image_size
args.weight_path = weight_path
args.result_path = result_path
# Only generate 1 image per
args.num_images = 1

try:
# Init server model
server_model = Generator(gen_args=args, deploy=True)

logger.info(msg=f"[Web]: A total of {num_images} images are generated.")
# Generate images by diffusion models
for i in range(num_images):
logger.info(msg=f"[Web]: Current generate {i + 1} of {num_images}.")
# Generation name
generate_name = uuid.uuid1()
# Generate image
x = server_model.generate(index=i)

# Select mode
# Recommend use base64
if re_type == "base64":
x = transforms.ToPILImage()(x[0])
re_x = image_to_base64(image=x)
else:
# Save images
re_x = os.path.join(result_path, f"{generate_name}.png")
save_images(images=x, path=re_x)
# Append return json
image_json = {"image_id": str(generate_name), "type": re_type,
"image": str(re_x)}
re_json["image"].append(image_json)

logger.info(msg="[Web]: Finish generation.")

return jsonify({"code": 200, "msg": "success!", "data": json.dumps(re_json, ensure_ascii=False)}), 200
except Exception as e:
return jsonify({"code": 500, "msg": str(e), "data": None}), 500

Check warning (Code scanning / CodeQL): Information exposure through an exception (Medium). Stack trace information flows to this location and may be exposed to an external user.

@app.route("/api/generate/sr")
def generate_super_resolution_model_api():
logger.info(msg="Route -> /api/sr")
# TODO: super resolution api
return "SR!"


if __name__ == "__main__":
host = "127.0.0.1"
port = 12341
logger.info(msg=f"Run -> {host}:{port}")
get_version_banner()
app.run(host=host, port=port, debug=True)

Check failure (Code scanning / CodeQL): Flask app is run in debug mode (High). A Flask app appears to be run in debug mode. This may allow an attacker to run arbitrary code through the debugger.
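
For reference, a minimal client for the new /api/generate/df endpoint could look like the sketch below. It assumes the server is running with the host and port from the __main__ block above (127.0.0.1:12341); the payload keys mirror what generate_diffusion_model_api() reads from request.json, while the sampler name, weight path, and result path are placeholders to replace with real values.

import json

import requests  # already pinned in requirements.txt

# Host and port come from deploy_server.py's __main__ block
URL = "http://127.0.0.1:12341/api/generate/df"

# Keys mirror what generate_diffusion_model_api() reads from request.json;
# the values here are placeholders
payload = {
    "sample": "ddpm",
    "image_size": 64,
    "num_images": 2,
    "weight_path": "/path/to/model_weight.pt",
    "result_path": "/path/to/results",
    # "base64" is recommended; any other value saves PNG files under result_path
    "type": "base64",
}

resp = requests.post(URL, json=payload, timeout=600)
body = resp.json()
print(body["code"], body["msg"])

# "data" is itself a JSON string (the handler wraps re_json with json.dumps),
# so it needs a second decode
if body["code"] == 200:
    for item in json.loads(body["data"])["image"]:
        print(item["image_id"], item["type"])

As the CodeQL findings above suggest, a production setup would presumably also turn off debug=True and replace str(e) in the 500 response with a generic error message, logging the stack trace server-side instead.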
81 changes: 41 additions & 40 deletions tools/deploy.py → deploy/deploy_socket.py
@@ -14,13 +14,13 @@
import threading
import coloredlogs

import torch
from torchvision import transforms

sys.path.append(os.path.dirname(sys.path[0]))
from model.networks.unet import UNet
from tools.generate import init_generate_args, Generator
from utils.utils import save_images
from utils.initializer import device_initializer, sample_initializer
from utils.checkpoint import load_ckpt
from utils.processing import image_to_base64
from config.version import get_version_banner

logger = logging.getLogger(__name__)
coloredlogs.install(level="INFO")
@@ -32,57 +32,56 @@ def generate(parse_json_data):
    :param parse_json_data: Parse send json message
    :return: JSON
    """
    logger.info(msg="[Client]: Start generation.")
    re_json = {"image": []}
    # Get the incoming json value information
    # Enable conditional generation
    conditional = parse_json_data["conditional"]
    # Sample type
    sample = parse_json_data["sample"]
    # Image size
    image_size = parse_json_data["image_size"]
    # Number of images
    num_images = parse_json_data["num_images"] if parse_json_data["num_images"] >= 1 else 1
    # Activation function
    act = parse_json_data["act"]
    # Weight path
    weight_path = parse_json_data["weight_path"]
    # Saving path
    result_path = parse_json_data["result_path"]
    # Run device initializer
    device = device_initializer()
    # Initialize the diffusion model
    diffusion = sample_initializer(sample=sample, image_size=image_size, device=device)
    # Initialize model
    if conditional:
        # Number of classes
        num_classes = parse_json_data["num_classes"]
        # Generation class name
        class_name = parse_json_data["class_name"]
        # Classifier-free guidance interpolation weight
        cfg_scale = parse_json_data["cfg_scale"]
        model = UNet(num_classes=num_classes, device=device, image_size=image_size, act=act).to(device)
        load_ckpt(ckpt_path=weight_path, model=model, device=device, is_train=False)
        y = torch.Tensor([class_name]).long().to(device)
    else:
        model = UNet(device=device, image_size=image_size, act=act).to(device)
        load_ckpt(ckpt_path=weight_path, model=model, device=device, is_train=False)
        y = None
        cfg_scale = None
    # Return mode, base64 or url
    re_type = parse_json_data["type"]

    logger.info(msg="[Client]: Start generation.")
    # Type is url or base64
    re_json = {"image": [], "type": str(re_type)}

    # Init args
    args = init_generate_args()
    args.sample = sample
    args.image_size = image_size
    args.weight_path = weight_path
    args.result_path = result_path
    # Only generate 1 image per sampling call
    args.num_images = 1

    # Init model
    model = Generator(gen_args=args, deploy=True)

    logger.info(msg=f"[Client]: A total of {num_images} images are generated.")
    # Generate images by diffusion models
    for i in range(num_images):
        logger.info(msg=f"[Client]: Current generate {i + 1} of {num_images}.")
        # Generation name
        generate_name = uuid.uuid1()
        x = diffusion.sample(model=model, n=1, labels=y, cfg_scale=cfg_scale)
        # TODO: Convert to base64
        # Save images
        save_images(images=x, path=os.path.join(result_path, f"{generate_name}.jpg"))
        # Append return json
        image_json = {"image_id": str(generate_name),
                      "image_name": f"{generate_name}.jpg"}
        x = model.generate(index=i)
        # Select mode
        if re_type == "base64":
            x = transforms.ToPILImage()(x[0])
            re_x = image_to_base64(image=x)
        else:
            # Save images
            re_x = os.path.join(result_path, f"{generate_name}.png")
            save_images(images=x, path=re_x)
        # Append return json in data
        image_json = {"image_id": str(generate_name), "type": re_type,
                      "image": str(re_x)}

        re_json["image"].append(image_json)
    logger.info(msg="[Client]: Finish generation.")
    return re_json
    return json.dumps(re_json, ensure_ascii=False)


def main():
@@ -122,6 +121,7 @@ class ServerThreading(threading.Thread):
"""
ServerThreading class
"""

def __init__(self, client_socket, address, receive_size=1024 * 1024, encoding="utf-8"):
"""
ServerThreading initialization
Expand Down Expand Up @@ -178,4 +178,5 @@ def __del__(self):


if __name__ == "__main__":
    get_version_banner()
    main()
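
A matching socket client is sketched below purely as an illustration: the host, port, and reply framing are assumptions, because main() and the ServerThreading receive loop are collapsed in this diff; only the JSON keys read by generate() and the utf-8 / 1 MiB receive settings are visible above.

import json
import socket

# Assumed endpoint; the real host and port are configured in main(), which is
# collapsed in this diff
HOST, PORT = "127.0.0.1", 12345

# Keys mirror what generate(parse_json_data) reads; values are placeholders
message = {
    "sample": "ddpm",
    "image_size": 64,
    "num_images": 2,
    "weight_path": "/path/to/model_weight.pt",
    "result_path": "/path/to/results",
    "type": "base64",
}

with socket.create_connection((HOST, PORT)) as sock:
    # ServerThreading decodes utf-8 with a default receive_size of 1024 * 1024
    sock.sendall(json.dumps(message).encode("utf-8"))
    # Assumption: the reply (the json.dumps string returned by generate()) fits
    # in one recv; adjust once the collapsed ServerThreading code is in view
    raw = sock.recv(1024 * 1024)

response = json.loads(raw.decode("utf-8"))
for item in response["image"]:
    print(item["image_id"], item["type"])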
2 changes: 2 additions & 0 deletions requirements.txt
@@ -4,9 +4,11 @@ matplotlib==3.7.1
numpy==1.25.0
Pillow==10.3.0
Requests==2.31.0
scikit-image==0.22.0
torch_summary==1.4.5
tqdm==4.65.0
pytorch_fid==0.3.0
flask==3.0.3

# If that fails use: pip install torch==1.13.0+cu116 torchvision==0.14.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html
# About more torch information please click: https://pytorch.org/get-started/previous-versions/#linux-and-windows-25
