From 14881f9b8a077a392dafbd684c440398de59767a Mon Sep 17 00:00:00 2001 From: ChairC <974833488@qq.com> Date: Mon, 9 Dec 2024 20:53:45 +0800 Subject: [PATCH 1/4] Update: Update deploy init. --- deploy/__init__.py | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 deploy/__init__.py diff --git a/deploy/__init__.py b/deploy/__init__.py new file mode 100644 index 0000000..363b0f0 --- /dev/null +++ b/deploy/__init__.py @@ -0,0 +1,9 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +""" + @Date : 2024/11/12 20:50 + @Author : chairc + @Site : https://github.com/chairc +""" +from deploy.deploy_server import generate_diffusion_model_api, generate_super_resolution_model_api +from deploy.deploy_socket import generate From 40697afc54a65a997eef54d2b60889df0db8895b Mon Sep 17 00:00:00 2001 From: ChairC <974833488@qq.com> Date: Mon, 9 Dec 2024 20:54:39 +0800 Subject: [PATCH 2/4] Update: Update gitignore. --- .gitignore | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.gitignore b/.gitignore index 2dc53ca..d7b72fd 100644 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,8 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. .idea/ + +# IDDM +/weights/ +/results/ +/datasets/ From 6f230d642b6c0ae340eed224f0e41e0f62872d1a Mon Sep 17 00:00:00 2001 From: ChairC <974833488@qq.com> Date: Tue, 10 Dec 2024 20:51:31 +0800 Subject: [PATCH 3/4] Add: Add official model list. --- config/model_list.py | 87 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 config/model_list.py diff --git a/config/model_list.py b/config/model_list.py new file mode 100644 index 0000000..c4a98ee --- /dev/null +++ b/config/model_list.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +""" + @Date : 2024/11/18 19:57 + @Author : chairc + @Site : https://github.com/chairc +""" + +pretrain_model_choices = { + "df": { + "default": { + "unet": { + "conditional": { + "64": "https://github.com/chairc/Integrated-Design-Diffusion-Model/releases/download/v1.1.5/cifar10-64-weight.pt", + "120": "https://github.com/chairc/Integrated-Design-Diffusion-Model/releases/download/v1.1.5/celebahq-120-weight.pt", + }, + "unconditional": { + "64": "", + "120": "", + }, + }, + "unetv2": { + "conditional": { + "64": "", + "120": "", + }, + "unconditional": { + "64": "", + "120": "", + }, + }, + "cspdarkunet": { + "conditional": { + "64": "", + "120": "", + }, + "unconditional": { + "64": "", + "120": "", + }, + } + }, + "exp": { + "unet": { + "gelu": { + "64": { + "neu-cls": "https://github.com/chairc/Integrated-Design-Diffusion-Model/releases/download/v1.1.7/neu-cls-64-weight.pt", + "cifar10": "https://github.com/chairc/Integrated-Design-Diffusion-Model/releases/download/v1.1.5/cifar10-64-weight.pt", + "animate-face": "https://github.com/chairc/Integrated-Design-Diffusion-Model/releases/download/v1.1.5/animate-face-64-weight.pt" + }, + "120": { + "neu": "https://github.com/chairc/Integrated-Design-Diffusion-Model/releases/download/v1.1.5/neu-120-weight.pt", + "animate-ganyu": "https://github.com/chairc/Integrated-Design-Diffusion-Model/releases/download/v1.1.5/animate-ganyu-120-weight.pt", + "celebahq": "https://github.com/chairc/Integrated-Design-Diffusion-Model/releases/download/v1.1.5/celebahq-120-weight.pt", + } + } + }, + } + }, + "sr": { + "srv1": { + "gelu": { + "64": "", + "120": "", + }, + "silu": { + "64": "", + "120": "", + }, + "relu": { + "64": "", + "120": "", + }, + "relu6": { + "64": "", + "120": "", + }, + "lrelu": { + "64": "", + "120": "", + }, + } + } +} + +if __name__ == "__main__": + pass \ No newline at end of file From 9ad1bd5ffd53f6642eedca74164b1be12b22549b Mon Sep 17 00:00:00 2001 From: ChairC <974833488@qq.com> Date: Tue, 10 Dec 2024 20:53:54 +0800 Subject: [PATCH 4/4] Add: Add download model function. --- config/setting.py | 3 ++ test/test_module.py | 18 ++++++++- utils/check.py | 35 +++++++++++++++++ utils/utils.py | 92 +++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 147 insertions(+), 1 deletion(-) diff --git a/config/setting.py b/config/setting.py index 9b7ccf5..da62a22 100644 --- a/config/setting.py +++ b/config/setting.py @@ -7,6 +7,9 @@ """ from config.choices import image_type_choices +# Temp files download path +DOWNLOAD_FILE_TEMP_PATH = "../.temp/download_files" + # Train MASTER_ADDR = "localhost" MASTER_PORT = "12345" diff --git a/test/test_module.py b/test/test_module.py index 664beaf..0402ba6 100644 --- a/test/test_module.py +++ b/test/test_module.py @@ -23,7 +23,7 @@ from config.choices import parse_image_size_type from utils.dataset import get_dataset, set_resize_images_size -from utils.utils import delete_files +from utils.utils import delete_files, download_files, download_model_pretrain_model from utils.initializer import device_initializer, network_initializer, sample_initializer, generate_initializer from utils.lr_scheduler import set_cosine_lr from utils.check import check_is_nan, check_image_size @@ -319,6 +319,22 @@ def test_image_to_base64_to_image(self): for y in range(image.height): assert image.getpixel((x, y)) == re_image.getpixel((x, y)) + def test_download_files(self): + """ + Test download files + """ + url = ["https://github.com/chairc/Integrated-Design-Diffusion-Model/archive/refs/tags/v1.1.6.zip"] + url_list = ["https://github.com/chairc/Integrated-Design-Diffusion-Model/archive/refs/tags/v1.1.6.zip", + "https://github.com/chairc/Integrated-Design-Diffusion-Model/archive/refs/tags/v1.1.6.tar.gz"] + download_files(url_list=url) + download_files(url_list=url_list) + + def test_download_pretrain_model(self): + """ + Test download pretrain model + """ + download_model_pretrain_model(pretrain_type="df", network="unet", conditional=True, image_size=64) + if __name__ == "__main__": pass diff --git a/utils/check.py b/utils/check.py index edf37aa..517b664 100644 --- a/utils/check.py +++ b/utils/check.py @@ -11,6 +11,8 @@ import torch +from urllib.parse import urlparse + from config.setting import DEFAULT_IMAGE_SIZE logger = logging.getLogger(__name__) @@ -69,3 +71,36 @@ def check_image_size(image_size): raise ValueError(f"Invalid 'image_size' tuple and list format: {image_size}") else: raise TypeError(f"Invalid 'image_size' format: {image_size}") + + +def check_url(url=""): + """ + Check the url is valid + :param url: Url + """ + try: + # Parse URL + parsed_url = urlparse(url) + # Check that all parts of the parsed URL make sense + # Here we mainly check whether the network location part (netloc) is empty + # And whether the URL scheme is a common network protocol (such as http, https, etc.) + if all([parsed_url.scheme, parsed_url.netloc]): + file_name = parsed_url.path.split("/")[-1] + logger.info(msg=f"The URL: {url} is legal.") + return file_name + else: + raise ValueError(f"Invalid 'url' format: {url}") + except ValueError: + # If a Value Error exception is thrown when parsing the URL, it means that the URL format is illegal. + raise ValueError("Invalid 'url' format.") + + +def check_pretrain_path(pretrain_path): + """ + Check the pretrain path is valid + :param pretrain_path: Pretrain path + :return: Boolean + """ + if pretrain_path is None or not os.path.exists(pretrain_path): + return True + return False diff --git a/utils/utils.py b/utils/utils.py index 87a65e3..6c3f438 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -11,11 +11,17 @@ import time import coloredlogs +import requests import torch import torchvision from PIL import Image from matplotlib import pyplot as plt +from tqdm import tqdm + +from config.model_list import pretrain_model_choices +from config.setting import DOWNLOAD_FILE_TEMP_PATH +from utils.check import check_url logger = logging.getLogger(__name__) coloredlogs.install(level="INFO") @@ -148,3 +154,89 @@ def check_and_create_dir(path): """ logger.info(msg=f"Check and create folder '{path}'.") os.makedirs(name=path, exist_ok=True) + + +def download_files(url_list=None, save_path=None): + """ + Downloads files + :param url_list: url list + :param save_path: Save path + """ + # Temp download path + if save_path is None: + download_file_temp_path = DOWNLOAD_FILE_TEMP_PATH + else: + download_file_temp_path = save_path + # Check and create + check_and_create_dir(path=download_file_temp_path) + # Check url list + for url in url_list: + logger.info(msg=f"Current download url is {url}") + file_name = check_url(url=url) + + # Send the request with stream=True + with requests.get(url, stream=True) as response: + # Check for HTTP errors + response.raise_for_status() + # Get the total size of the file (if possible) + total_size_in_bytes = int(response.headers.get("content-length", 0)) + block_size = 1024 # 1KB for each read + progress_bar = tqdm(total=total_size_in_bytes, unit="B", unit_scale=True, desc=f"Downloading {file_name}") + + # Open the file in binary write mode + with open(os.path.join(download_file_temp_path, file_name), "wb") as file: + for data in response.iter_content(block_size): + # Write the data to the file + file.write(data) + # Update the progress bar + progress_bar.update(len(data)) + # Close the progress bar + progress_bar.close() + logger.info(msg=f"Current {url} is download successfully.") + logger.info(msg="Everything is downloaded.") + + +def download_model_pretrain_model(pretrain_type="df", network="unet", image_size=64, **kwargs): + """ + Download pre-trained model in GitHub repository + :param pretrain_type: Type of pre-trained model + :param network: Network + :param image_size: Image size + :param kwargs: Other parameters + :return new_pretrain_path + """ + # Check image size + if isinstance(image_size, int): + image_size = str(image_size) + else: + raise ValueError("Official pretrain model's image size must be int, such as 64 or 120.") + # Download diffusion pretrain model + if pretrain_type == "df": + df_type = kwargs.get("df_type", "default") + conditional_type = "conditional" if kwargs.get("df_type", True) else "unconditional" + # Download pretrain model + if df_type == "default": + pretrain_model_url = pretrain_model_choices[pretrain_type][df_type][network][conditional_type][image_size] + # Download sample model. + # If use cifar-10 dataset, you can set cifar10 pretrain model + elif df_type == "exp": + model_name = kwargs.get("model_name", "cifar10") + pretrain_model_url = pretrain_model_choices[pretrain_type][df_type][network][conditional_type][image_size][ + model_name] + else: + raise TypeError(f"Diffusion model type '{df_type}' is not supported.") + # Download super resolution pretrain model + elif pretrain_type == "sr": + act = kwargs.get("act", "silu") + pretrain_model_url = pretrain_model_choices[pretrain_type][network][act][image_size] + else: + raise TypeError(f"Pretrain type '{pretrain_type}' is not supported.") + + # Download model + download_files(url_list=[pretrain_model_url]) + logger.info(msg=f"Current pretrain model path '{pretrain_model_url}' is download successfully.") + # Get file name + parts = pretrain_model_url.split("/") + filename = parts[-1] + new_pretrain_path = os.path.join(DOWNLOAD_FILE_TEMP_PATH, filename) + return new_pretrain_path