diff --git a/README.md b/README.md index fe18b1b..0d18155 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,9 @@ Integrated Design Diffusion Model │ ├── class_1 │ ├── class_2 │ └── class_3 +├── deploy +│ ├── deploy_socket.py +│ └── deploy_server.py ├── model │ ├── modules │ │ ├── activation.py @@ -84,7 +87,6 @@ Integrated Design Diffusion Model │ │ └── noise │ └── test_module.py ├── tools -│ ├── deploy.py │ ├── FID_calculator.py │ ├── FID_calculator_plus.py │ ├── generate.py @@ -119,6 +121,7 @@ Integrated Design Diffusion Model - [x] 12. Adding PLMS Sampling Method. (2024-03-12) - [x] 13. Adding FID calculator to verify image quality. (2024-05-06) - [x] 14. Adding the deployment of image-generating Sockets +- [x] 14. Adding the deployment of image-generating Sockets and Web server. (2024-11-13) ### Training diff --git a/README_zh.md b/README_zh.md index 605dfaa..e30064a 100644 --- a/README_zh.md +++ b/README_zh.md @@ -52,6 +52,9 @@ Integrated Design Diffusion Model │ ├── class_1 │ ├── class_2 │ └── class_3 +├── deploy +│ ├── deploy_socket.py +│ └── deploy_server.py ├── model │ ├── modules │ │ ├── activation.py @@ -83,7 +86,6 @@ Integrated Design Diffusion Model │ │ └── noise │ └── test_module.py ├── tools -│ ├── deploy.py │ ├── FID_calculator.py │ ├── FID_calculator_plus.py │ ├── generate.py @@ -118,6 +120,7 @@ Integrated Design Diffusion Model - [x] 12. 增加PLMS采样方法(2024-03-12) - [x] 13. 增加FID方法验证图像质量(2024-05-06) - [x] 14. 增加生成图像Socket部署(2024-11-12) +- [x] 14. 增加生成图像Socket和网站服务部署(2024-11-13) ### 训练 diff --git a/deploy/deploy_server.py b/deploy/deploy_server.py new file mode 100644 index 0000000..d99599c --- /dev/null +++ b/deploy/deploy_server.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +""" + @Date : 2024/11/4 23:04 + @Author : chairc + @Site : https://github.com/chairc +""" +import os +import sys +import json +import logging +import uuid + +import coloredlogs + +from flask import Flask, request, jsonify +from torchvision import transforms + +sys.path.append(os.path.dirname(sys.path[0])) +from config.version import get_version_banner +from tools.generate import Generator, init_generate_args +from utils import save_images +from utils.processing import image_to_base64 + +logger = logging.getLogger(__name__) +coloredlogs.install(level="INFO") +app = Flask(__name__) + + +@app.route("/") +def index(): + logger.info(msg="Route -> Hello IDDM") + return "Hello, IDDM!" + + +@app.route("/api/generate/df", methods=["POST"]) +def generate_diffusion_model_api(): + """ + Generate a diffusion model + """ + logger.info(msg="Route -> /api/df") + + data = request.json + logger.info(msg=f"Send json: {data}") + + # Sample type + sample = data["sample"] + # Image size + image_size = data["image_size"] + # Number of images + num_images = data["num_images"] if data["num_images"] >= 1 else 1 + # Weight path + weight_path = data["weight_path"] + result_path = data["result_path"] + # Recommend use base64 in server app + # Return mode, base64 or url + re_type = data["type"] + + logger.info(msg="[Web]: Start generation.") + # Type is url or base64 + re_json = {"image": [], "type": str(re_type)} + + if any(param is None for param in [sample, image_size, num_images, weight_path, result_path, re_type]): + return jsonify({"code": 400, "msg": "Illegal parameters.", "data": None}), 400 + + # Init args + args = init_generate_args() + args.sample = sample + args.image_size = image_size + args.weight_path = weight_path + args.result_path = result_path + # Only generate 1 image per + args.num_images = 1 + + try: + # Init server model + server_model = Generator(gen_args=args, deploy=True) + + logger.info(msg=f"[Web]: A total of {num_images} images are generated.") + # Generate images by diffusion models + for i in range(num_images): + logger.info(msg=f"[Web]: Current generate {i + 1} of {num_images}.") + # Generation name + generate_name = uuid.uuid1() + # Generate image + x = server_model.generate(index=i) + + # Select mode + # Recommend use base64 + if re_type == "base64": + x = transforms.ToPILImage()(x[0]) + re_x = image_to_base64(image=x) + else: + # Save images + re_x = os.path.join(result_path, f"{generate_name}.png") + save_images(images=x, path=re_x) + # Append return json + image_json = {"image_id": str(generate_name), "type": re_type, + "image": str(re_x)} + re_json["image"].append(image_json) + + logger.info(msg="[Web]: Finish generation.") + + return jsonify({"code": 200, "msg": "success!", "data": json.dumps(re_json, ensure_ascii=False)}), 200 + except Exception as e: + return jsonify({"code": 500, "msg": str(e), "data": None}), 500 + + +@app.route("/api/generate/sr") +def generate_super_resolution_model_api(): + logger.info(msg="Route -> /api/sr") + # TODO: super resolution api + return "SR!" + + +if __name__ == "__main__": + host = "127.0.0.1" + port = 12341 + logger.info(msg=f"Run -> {host}:{port}") + get_version_banner() + app.run(host=host, port=port, debug=True) diff --git a/tools/deploy.py b/deploy/deploy_socket.py similarity index 72% rename from tools/deploy.py rename to deploy/deploy_socket.py index 172340e..7bd9109 100644 --- a/tools/deploy.py +++ b/deploy/deploy_socket.py @@ -14,13 +14,13 @@ import threading import coloredlogs -import torch +from torchvision import transforms sys.path.append(os.path.dirname(sys.path[0])) -from model.networks.unet import UNet +from tools.generate import init_generate_args, Generator from utils.utils import save_images -from utils.initializer import device_initializer, sample_initializer -from utils.checkpoint import load_ckpt +from utils.processing import image_to_base64 +from config.version import get_version_banner logger = logging.getLogger(__name__) coloredlogs.install(level="INFO") @@ -32,57 +32,56 @@ def generate(parse_json_data): :param parse_json_data: Parse send json message :return: JSON """ - logger.info(msg="[Client]: Start generation.") - re_json = {"image": []} - # Get the incoming json value information - # Enable conditional generation - conditional = parse_json_data["conditional"] # Sample type sample = parse_json_data["sample"] # Image size image_size = parse_json_data["image_size"] # Number of images num_images = parse_json_data["num_images"] if parse_json_data["num_images"] >= 1 else 1 - # Activation function - act = parse_json_data["act"] # Weight path weight_path = parse_json_data["weight_path"] - # Saving path result_path = parse_json_data["result_path"] - # Run device initializer - device = device_initializer() - # Initialize the diffusion model - diffusion = sample_initializer(sample=sample, image_size=image_size, device=device) - # Initialize model - if conditional: - # Number of classes - num_classes = parse_json_data["num_classes"] - # Generation class name - class_name = parse_json_data["class_name"] - # classifier-free guidance interpolation weight - cfg_scale = parse_json_data["cfg_scale"] - model = UNet(num_classes=num_classes, device=device, image_size=image_size, act=act).to(device) - load_ckpt(ckpt_path=weight_path, model=model, device=device, is_train=False) - y = torch.Tensor([class_name]).long().to(device) - else: - model = UNet(device=device, image_size=image_size, act=act).to(device) - load_ckpt(ckpt_path=weight_path, model=model, device=device, is_train=False) - y = None - cfg_scale = None + # Return mode, base64 or url + re_type = parse_json_data["type"] + + logger.info(msg="[Client]: Start generation.") + # Type is url or base64 + re_json = {"image": [], "type": str(re_type)} + + # Init args + args = init_generate_args() + args.sample = sample + args.image_size = image_size + args.weight_path = weight_path + args.result_path = result_path + # Only generate 1 image per + args.num_images = 1 + + # Init model + model = Generator(gen_args=args, deploy=True) + + logger.info(msg=f"[Client]: A total of {num_images} images are generated.") # Generate images by diffusion models for i in range(num_images): + logger.info(msg=f"[Client]: Current generate {i + 1} of {num_images}.") # Generation name generate_name = uuid.uuid1() - x = diffusion.sample(model=model, n=1, labels=y, cfg_scale=cfg_scale) - # TODO: Convert to base64 - # Save images - save_images(images=x, path=os.path.join(result_path, f"{generate_name}.jpg")) - # Append return json - image_json = {"image_id": str(generate_name), - "image_name": f"{generate_name}.jpg"} + x = model.generate(index=i) + # Select mode + if re_type == "base64": + x = transforms.ToPILImage()(x[0]) + re_x = image_to_base64(image=x) + else: + # Save images + re_x = os.path.join(result_path, f"{generate_name}.png") + save_images(images=x, path=re_x) + # Append return json in data + image_json = {"image_id": str(generate_name), "type": re_type, + "image": str(re_x)} + re_json["image"].append(image_json) logger.info(msg="[Client]: Finish generation.") - return re_json + return json.dumps(re_json, ensure_ascii=False) def main(): @@ -122,6 +121,7 @@ class ServerThreading(threading.Thread): """ ServerThreading class """ + def __init__(self, client_socket, address, receive_size=1024 * 1024, encoding="utf-8"): """ ServerThreading initialization @@ -178,4 +178,5 @@ def __del__(self): if __name__ == "__main__": + get_version_banner() main() diff --git a/requirements.txt b/requirements.txt index 4816bc9..53561ff 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,9 +4,11 @@ matplotlib==3.7.1 numpy==1.25.0 Pillow==10.3.0 Requests==2.31.0 +scikit-image==0.22.0 torch_summary==1.4.5 tqdm==4.65.0 pytorch_fid==0.3.0 +flask==3.0.3 # If that fails use: pip install torch==1.13.0+cu116 torchvision==0.14.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html # About more torch information please click: https://pytorch.org/get-started/previous-versions/#linux-and-windows-25