Skip to content

Commit

Permalink
Add: Add download model function.
Browse files Browse the repository at this point in the history
  • Loading branch information
chairc committed Dec 10, 2024
1 parent 6f230d6 commit 9ad1bd5
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 1 deletion.
3 changes: 3 additions & 0 deletions config/setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
"""
from config.choices import image_type_choices

# Temp files download path
DOWNLOAD_FILE_TEMP_PATH = "../.temp/download_files"

# Train
MASTER_ADDR = "localhost"
MASTER_PORT = "12345"
Expand Down
18 changes: 17 additions & 1 deletion test/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from config.choices import parse_image_size_type
from utils.dataset import get_dataset, set_resize_images_size
from utils.utils import delete_files
from utils.utils import delete_files, download_files, download_model_pretrain_model
from utils.initializer import device_initializer, network_initializer, sample_initializer, generate_initializer
from utils.lr_scheduler import set_cosine_lr
from utils.check import check_is_nan, check_image_size
Expand Down Expand Up @@ -319,6 +319,22 @@ def test_image_to_base64_to_image(self):
for y in range(image.height):
assert image.getpixel((x, y)) == re_image.getpixel((x, y))

def test_download_files(self):
"""
Test download files
"""
url = ["https://github.com/chairc/Integrated-Design-Diffusion-Model/archive/refs/tags/v1.1.6.zip"]
url_list = ["https://github.com/chairc/Integrated-Design-Diffusion-Model/archive/refs/tags/v1.1.6.zip",
"https://github.com/chairc/Integrated-Design-Diffusion-Model/archive/refs/tags/v1.1.6.tar.gz"]
download_files(url_list=url)
download_files(url_list=url_list)

def test_download_pretrain_model(self):
"""
Test download pretrain model
"""
download_model_pretrain_model(pretrain_type="df", network="unet", conditional=True, image_size=64)


if __name__ == "__main__":
pass
35 changes: 35 additions & 0 deletions utils/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

import torch

from urllib.parse import urlparse

from config.setting import DEFAULT_IMAGE_SIZE

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -69,3 +71,36 @@ def check_image_size(image_size):
raise ValueError(f"Invalid 'image_size' tuple and list format: {image_size}")
else:
raise TypeError(f"Invalid 'image_size' format: {image_size}")


def check_url(url=""):
"""
Check the url is valid
:param url: Url
"""
try:
# Parse URL
parsed_url = urlparse(url)
# Check that all parts of the parsed URL make sense
# Here we mainly check whether the network location part (netloc) is empty
# And whether the URL scheme is a common network protocol (such as http, https, etc.)
if all([parsed_url.scheme, parsed_url.netloc]):
file_name = parsed_url.path.split("/")[-1]
logger.info(msg=f"The URL: {url} is legal.")
return file_name
else:
raise ValueError(f"Invalid 'url' format: {url}")
except ValueError:
# If a Value Error exception is thrown when parsing the URL, it means that the URL format is illegal.
raise ValueError("Invalid 'url' format.")


def check_pretrain_path(pretrain_path):
"""
Check the pretrain path is valid
:param pretrain_path: Pretrain path
:return: Boolean
"""
if pretrain_path is None or not os.path.exists(pretrain_path):
return True
return False
92 changes: 92 additions & 0 deletions utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,17 @@
import time

import coloredlogs
import requests
import torch
import torchvision

from PIL import Image
from matplotlib import pyplot as plt
from tqdm import tqdm

from config.model_list import pretrain_model_choices
from config.setting import DOWNLOAD_FILE_TEMP_PATH
from utils.check import check_url

logger = logging.getLogger(__name__)
coloredlogs.install(level="INFO")
Expand Down Expand Up @@ -148,3 +154,89 @@ def check_and_create_dir(path):
"""
logger.info(msg=f"Check and create folder '{path}'.")
os.makedirs(name=path, exist_ok=True)


def download_files(url_list=None, save_path=None):
"""
Downloads files
:param url_list: url list
:param save_path: Save path
"""
# Temp download path
if save_path is None:
download_file_temp_path = DOWNLOAD_FILE_TEMP_PATH
else:
download_file_temp_path = save_path
# Check and create
check_and_create_dir(path=download_file_temp_path)
# Check url list
for url in url_list:
logger.info(msg=f"Current download url is {url}")
file_name = check_url(url=url)

# Send the request with stream=True
with requests.get(url, stream=True) as response:
# Check for HTTP errors
response.raise_for_status()
# Get the total size of the file (if possible)
total_size_in_bytes = int(response.headers.get("content-length", 0))
block_size = 1024 # 1KB for each read
progress_bar = tqdm(total=total_size_in_bytes, unit="B", unit_scale=True, desc=f"Downloading {file_name}")

# Open the file in binary write mode
with open(os.path.join(download_file_temp_path, file_name), "wb") as file:
for data in response.iter_content(block_size):
# Write the data to the file
file.write(data)
# Update the progress bar
progress_bar.update(len(data))
# Close the progress bar
progress_bar.close()
logger.info(msg=f"Current {url} is download successfully.")
logger.info(msg="Everything is downloaded.")


def download_model_pretrain_model(pretrain_type="df", network="unet", image_size=64, **kwargs):
"""
Download pre-trained model in GitHub repository
:param pretrain_type: Type of pre-trained model
:param network: Network
:param image_size: Image size
:param kwargs: Other parameters
:return new_pretrain_path
"""
# Check image size
if isinstance(image_size, int):
image_size = str(image_size)
else:
raise ValueError("Official pretrain model's image size must be int, such as 64 or 120.")
# Download diffusion pretrain model
if pretrain_type == "df":
df_type = kwargs.get("df_type", "default")
conditional_type = "conditional" if kwargs.get("df_type", True) else "unconditional"
# Download pretrain model
if df_type == "default":
pretrain_model_url = pretrain_model_choices[pretrain_type][df_type][network][conditional_type][image_size]
# Download sample model.
# If use cifar-10 dataset, you can set cifar10 pretrain model
elif df_type == "exp":
model_name = kwargs.get("model_name", "cifar10")
pretrain_model_url = pretrain_model_choices[pretrain_type][df_type][network][conditional_type][image_size][
model_name]
else:
raise TypeError(f"Diffusion model type '{df_type}' is not supported.")
# Download super resolution pretrain model
elif pretrain_type == "sr":
act = kwargs.get("act", "silu")
pretrain_model_url = pretrain_model_choices[pretrain_type][network][act][image_size]
else:
raise TypeError(f"Pretrain type '{pretrain_type}' is not supported.")

# Download model
download_files(url_list=[pretrain_model_url])
logger.info(msg=f"Current pretrain model path '{pretrain_model_url}' is download successfully.")
# Get file name
parts = pretrain_model_url.split("/")
filename = parts[-1]
new_pretrain_path = os.path.join(DOWNLOAD_FILE_TEMP_PATH, filename)
return new_pretrain_path

0 comments on commit 9ad1bd5

Please sign in to comment.