From 387545f24bafc975ae6eaaecb156c4399bce56cd Mon Sep 17 00:00:00 2001 From: ChairC <974833488@qq.com> Date: Tue, 5 Dec 2023 23:20:40 +0800 Subject: [PATCH 1/2] Update: Rewrite checkpoint storage and loading functions, add checkpoint files. --- README.md | 21 ++++++-- README_zh.md | 19 ++++++-- tools/deploy.py | 7 +-- tools/generate.py | 7 +-- tools/train.py | 76 +++++++++++++---------------- utils/checkpoint.py | 111 +++++++++++++++++++++++++++++++++++++++++++ utils/initializer.py | 30 ------------ 7 files changed, 184 insertions(+), 87 deletions(-) create mode 100644 utils/checkpoint.py diff --git a/README.md b/README.md index c567407..c988304 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ We named this project IDDM: Industrial Defect Diffusion Model. It aims to reprod │ ├── generate.py │ └── train.py ├── utils +│ ├── checkpoint.py │ ├── initializer.py │ ├── lr_scheduler.py │ └── utils.py @@ -139,13 +140,26 @@ The training GPU implements environment for this README is as follows: models ar **Conditional Resume Training Command** ```bash - python train.py --resume True --start_epoch 10 --load_model_dir df --sample ddpm --conditional True --run_name df --epochs 300 --batch_size 16 --image_size 64 --num_classes 10 --dataset_path /your/dataset/path --result_path /your/save/path + # This is using --start_epoch, default use current epoch checkpoint + python train.py --resume True --start_epoch 10 --sample ddpm --conditional True --run_name df --epochs 300 --batch_size 16 --image_size 64 --num_classes 10 --dataset_path /your/dataset/path --result_path /your/save/path + ``` + + ```bash + # This is not using --start_epoch, default use last checkpoint + python train.py --resume True --sample ddpm --conditional True --run_name df --epochs 300 --batch_size 16 --image_size 64 --num_classes 10 --dataset_path /your/dataset/path --result_path /your/save/path ``` **Unconditional Resume Training Command** ```bash - python train.py --resume True --start_epoch 10 --load_model_dir df --sample ddpm --conditional False --run_name df --epochs 300 --batch_size 16 --image_size 64 --dataset_path /your/dataset/path --result_path /your/save/path + # This is using --start_epoch, default use current epoch checkpoint + python train.py --resume True --start_epoch 10 --sample ddpm --conditional False --run_name df --epochs 300 --batch_size 16 --image_size 64 --dataset_path /your/dataset/path --result_path /your/save/path + ``` + + ```bash + # This is not using --start_epoch, default use last checkpoint + python train.py --resume True --sample ddpm --conditional False --run_name df --epochs 300 --batch_size 16 --image_size 64 --dataset_path /your/dataset/path --result_path /your/save/path ``` + #### Distributed Training 1. The basic configuration is similar to regular training, but note that enabling distributed training requires setting `--distributed` to `True`. To prevent arbitrary use of distributed training, we have several conditions for enabling distributed training, such as `args.distributed`, `torch.cuda.device_count() > 1`, and `torch.cuda.is_available()`. @@ -197,8 +211,7 @@ The training GPU implements environment for this README is as follows: models ar | --vis | | Visualize dataset information | bool | Enable visualization of dataset information for model selection based on visualization | | --num_vis | | Number of visualization images generated | int | Number of visualization images generated. If not filled, the default is the number of image classes | | --resume | | Resume interrupted training | bool | Set to "True" to resume interrupted training. Note: If the epoch number of interruption is outside the condition of --start_model_interval, it will not take effect. For example, if the start saving model time is 100 and the interruption number is 50, we cannot set any loading epoch points because we did not save the model. We save the xxx_last.pt file every training, so we need to use the last saved model for interrupted training | -| --start_epoch | | Epoch number of interruption | int | Epoch number where the training was interrupted | -| --load_model_dir | | Folder name of the loaded model | str | Folder name of the previously loaded model | +| --start_epoch | | Epoch number of interruption | int | Epoch number where the training was interrupted, the model will load current checkpoint | | --distributed | | Distributed training | bool | Enable distributed training | | --main_gpu | | Main GPU for distributed | int | Set the main GPU for distributed training | | --world_size | | Number of distributed nodes | int | Number of distributed nodes, corresponds to the actual number of GPUs or distributed nodes being used | diff --git a/README_zh.md b/README_zh.md index 3dd01fa..c96da69 100644 --- a/README_zh.md +++ b/README_zh.md @@ -38,6 +38,7 @@ │ ├── generate.py │ └── train.py ├── utils +│ ├── checkpoint.py │ ├── initializer.py │ ├── lr_scheduler.py │ └── utils.py @@ -138,12 +139,23 @@ **有条件恢复训练命令** ```bash - python train.py --resume True --start_epoch 10 --load_model_dir df --sample ddpm --conditional True --run_name df --epochs 300 --batch_size 16 --image_size 64 --num_classes 10 --dataset_path /your/dataset/path --result_path /your/save/path + # 此处为输入--start_epoch参数,使用当前编号权重 + python train.py --resume True --start_epoch 10 --sample ddpm --conditional True --run_name df --epochs 300 --batch_size 16 --image_size 64 --num_classes 10 --dataset_path /your/dataset/path --result_path /your/save/path + ``` + + ```bash + # 此处为不输入--start_epoch参数,默认使用last权重 + python train.py --resume True --sample ddpm --conditional True --run_name df --epochs 300 --batch_size 16 --image_size 64 --num_classes 10 --dataset_path /your/dataset/path --result_path /your/save/path ``` **无条件恢复训练命令** ```bash - python train.py --resume True --start_epoch 10 --load_model_dir df --sample ddpm --conditional False --run_name df --epochs 300 --batch_size 16 --image_size 64 --dataset_path /your/dataset/path --result_path /your/save/path + python train.py --resume True --start_epoch 10 --sample ddpm --conditional False --run_name df --epochs 300 --batch_size 16 --image_size 64 --dataset_path /your/dataset/path --result_path /your/save/path + ``` + + ```bash + # 此处为不输入--start_epoch参数,默认使用last权重 + python train.py --resume True --sample ddpm --conditional False --run_name df --epochs 300 --batch_size 16 --image_size 64 --dataset_path /your/dataset/path --result_path /your/save/path ``` #### 分布式训练 @@ -200,8 +212,7 @@ | --vis | | 可视化数据集信息 | bool | 打开可视化数据集信息,根据可视化生成样本信息筛选模型 | | --num_vis | | 生成的可视化图像数量 | int | 生成的可视化图像数量。如果不填写,则默认生成图片个数为数据集类别的个数 | | --resume | | 中断恢复训练 | bool | 恢复训练将设置为“True”。注意:设置异常中断的epoch编号若在--start_model_interval参数条件外,则不生效。例如开始保存模型时间为100,中断编号为50,由于我们没有保存模型,所以无法设置任意加载epoch点。每次训练我们都会保存xxx_last.pt文件,所以我们需要使用最后一次保存的模型进行中断训练 | -| --start_epoch | | 中断迭代编号 | int | 设置异常中断的epoch编号 | -| --load_model_dir | | 加载模型所在文件夹 | str | 写入中断的epoch上一个加载模型的所在文件夹 | +| --start_epoch | | 中断迭代编号 | int | 设置异常中断的epoch编号,模型会自动加载当前编号的检查点 | | --distributed | | 分布式训练 | bool | 开启分布式训练 | | --main_gpu | | 分布式训练主显卡 | int | 设置分布式中主显卡 | | --world_size | | 分布式训练的节点等级 | int | 分布式训练的节点等级, world_size的值会与实际使用的GPU数量或分布式节点数量相对应 | diff --git a/tools/deploy.py b/tools/deploy.py index f5b473c..172340e 100644 --- a/tools/deploy.py +++ b/tools/deploy.py @@ -19,7 +19,8 @@ sys.path.append(os.path.dirname(sys.path[0])) from model.networks.unet import UNet from utils.utils import save_images -from utils.initializer import device_initializer, load_model_weight_initializer, sample_initializer +from utils.initializer import device_initializer, sample_initializer +from utils.checkpoint import load_ckpt logger = logging.getLogger(__name__) coloredlogs.install(level="INFO") @@ -61,11 +62,11 @@ def generate(parse_json_data): # classifier-free guidance interpolation weight cfg_scale = parse_json_data["cfg_scale"] model = UNet(num_classes=num_classes, device=device, image_size=image_size, act=act).to(device) - load_model_weight_initializer(model=model, weight_path=weight_path, device=device, is_train=False) + load_ckpt(ckpt_path=weight_path, model=model, device=device, is_train=False) y = torch.Tensor([class_name]).long().to(device) else: model = UNet(device=device, image_size=image_size, act=act).to(device) - load_model_weight_initializer(model=model, weight_path=weight_path, device=device, is_train=False) + load_ckpt(ckpt_path=weight_path, model=model, device=device, is_train=False) y = None cfg_scale = None # Generate images by diffusion models diff --git a/tools/generate.py b/tools/generate.py index f9f57b1..2f93665 100644 --- a/tools/generate.py +++ b/tools/generate.py @@ -14,8 +14,9 @@ import coloredlogs sys.path.append(os.path.dirname(sys.path[0])) -from utils.initializer import device_initializer, load_model_weight_initializer, network_initializer, sample_initializer +from utils.initializer import device_initializer, network_initializer, sample_initializer from utils.utils import plot_images, save_images, save_one_image_in_images, check_and_create_dir +from utils.checkpoint import load_ckpt logger = logging.getLogger(__name__) coloredlogs.install(level="INFO") @@ -64,7 +65,7 @@ def generate(args): # classifier-free guidance interpolation weight cfg_scale = args.cfg_scale model = Network(num_classes=num_classes, device=device, image_size=image_size, act=act).to(device) - load_model_weight_initializer(model=model, weight_path=weight_path, device=device, is_train=False) + load_ckpt(ckpt_path=weight_path, model=model, device=device, is_train=False) if class_name == -1: y = torch.arange(num_classes).long().to(device) num_images = num_classes @@ -73,7 +74,7 @@ def generate(args): x = diffusion.sample(model=model, n=num_images, labels=y, cfg_scale=cfg_scale) else: model = Network(device=device, image_size=image_size, act=act).to(device) - load_model_weight_initializer(model=model, weight_path=weight_path, device=device, is_train=False) + load_ckpt(ckpt_path=weight_path, model=model, device=device, is_train=False) x = diffusion.sample(model=model, n=num_images) # If there is no path information, it will only be displayed # If it exists, it will be saved to the specified path and displayed diff --git a/tools/train.py b/tools/train.py index 8890ea5..34c7abe 100644 --- a/tools/train.py +++ b/tools/train.py @@ -23,9 +23,10 @@ sys.path.append(os.path.dirname(sys.path[0])) from model.modules.module import EMA -from utils.initializer import device_initializer, seed_initializer, load_model_weight_initializer, network_initializer, \ - optimizer_initializer, sample_initializer, lr_initializer, fp16_initializer +from utils.initializer import device_initializer, seed_initializer, network_initializer, optimizer_initializer, \ + sample_initializer, lr_initializer, fp16_initializer from utils.utils import plot_images, save_images, get_dataset, setup_logging, save_train_logging +from utils.checkpoint import load_ckpt, save_ckpt logger = logging.getLogger(__name__) coloredlogs.install(level="INFO") @@ -126,18 +127,17 @@ def train(rank=None, args=None): optimizer = optimizer_initializer(model=model, optim=optim, init_lr=init_lr, device=device) # Resume training if resume: - load_model_dir = args.load_model_dir + ckpt_path = None start_epoch = args.start_epoch - # Load the previous model - load_epoch = str(start_epoch - 1).zfill(3) - model_path = os.path.join(result_path, load_model_dir, f"model_{load_epoch}.pt") - optim_path = os.path.join(result_path, load_model_dir, f"optim_model_{load_epoch}.pt") - load_model_weight_initializer(model=model, weight_path=model_path, device=device) - logger.info(msg=f"[{device}]: Successfully load model model_{load_epoch}.pt") - # Load the previous model optimizer - optim_weights_dict = torch.load(f=optim_path, map_location=device) - optimizer.load_state_dict(state_dict=optim_weights_dict) - logger.info(msg=f"[{device}]: Successfully load optimizer optim_model_{load_epoch}.pt") + # Determine which checkpoint to load + # 'start_epoch' is correct + if start_epoch is not None: + ckpt_path = os.path.join(results_dir, f"ckpt_{str(start_epoch - 1).zfill(3)}.pt") + # Parameter 'ckpt_path' is None in the train mode + if ckpt_path is None: + ckpt_path = os.path.join(results_dir, "ckpt_last.pt") + start_epoch = load_ckpt(ckpt_path=ckpt_path, model=model, device=device, optimizer=optimizer) + logger.info(msg=f"[{device}]: Successfully load resume model checkpoint.") else: start_epoch = 0 # Set harf-precision @@ -232,46 +232,37 @@ def train(rank=None, args=None): # Saving and validating models in the main process if save_models: - # Saving model - save_name = f"model_{str(epoch).zfill(3)}" + # Saving model, set the checkpoint name + save_name = f"ckpt_{str(epoch).zfill(3)}" + # Init ckpt params + ckpt_model, ckpt_ema_model, ckpt_optimizer = None, None, None if not conditional: - # Saving pt files - torch.save(obj=model.state_dict(), f=os.path.join(results_dir, f"model_last.pt")) - torch.save(obj=optimizer.state_dict(), f=os.path.join(results_dir, f"optim_last.pt")) + ckpt_model = model.state_dict() + ckpt_optimizer = optimizer.state_dict() # Enable visualization if vis: # images.shape[0] is the number of images in the current batch n = num_vis if num_vis > 0 else images.shape[0] sampled_images = diffusion.sample(model=model, n=n) save_images(images=sampled_images, path=os.path.join(results_vis_dir, f"{save_name}.jpg")) - # Saving pt files in epoch interval - if save_model_interval and epoch > start_model_interval: - torch.save(obj=model.state_dict(), f=os.path.join(results_dir, f"{save_name}.pt")) - torch.save(obj=optimizer.state_dict(), f=os.path.join(results_dir, f"optim_{save_name}.pt")) - logger.info(msg=f"Save the {save_name}.pt, and optim_{save_name}.pt.") - logger.info(msg="Save the model.") else: - # Saving pt files - torch.save(obj=model.state_dict(), f=os.path.join(results_dir, f"model_last.pt")) - torch.save(obj=ema_model.state_dict(), f=os.path.join(results_dir, f"ema_model_last.pt")) - torch.save(obj=optimizer.state_dict(), f=os.path.join(results_dir, f"optim_last.pt")) + ckpt_model = model.state_dict() + ckpt_ema_model = ema_model.state_dict() + ckpt_optimizer = optimizer.state_dict() # Enable visualization if vis: labels = torch.arange(num_classes).long().to(device) n = num_vis if num_vis > 0 else len(labels) sampled_images = diffusion.sample(model=model, n=n, labels=labels, cfg_scale=cfg_scale) - ema_sampled_images = diffusion.sample(model=ema_model, n=n, labels=labels, - cfg_scale=cfg_scale) + ema_sampled_images = diffusion.sample(model=ema_model, n=n, labels=labels, cfg_scale=cfg_scale) # This is a method to display the results of each model during training and can be commented out # plot_images(images=sampled_images) save_images(images=sampled_images, path=os.path.join(results_vis_dir, f"{save_name}.jpg")) - save_images(images=ema_sampled_images, path=os.path.join(results_vis_dir, f"{save_name}_ema.jpg")) - if save_model_interval and epoch > start_model_interval: - torch.save(obj=model.state_dict(), f=os.path.join(results_dir, f"{save_name}.pt")) - torch.save(obj=ema_model.state_dict(), f=os.path.join(results_dir, f"ema_{save_name}.pt")) - torch.save(obj=optimizer.state_dict(), f=os.path.join(results_dir, f"optim_{save_name}.pt")) - logger.info(msg=f"Save the {save_name}.pt, ema_{save_name}.pt, and optim_{save_name}.pt.") - logger.info(msg="Save the model.") + save_images(images=ema_sampled_images, path=os.path.join(results_vis_dir, f"ema_{save_name}.jpg")) + # Save checkpoint + save_ckpt(epoch=epoch, save_name=save_name, ckpt_model=ckpt_model, ckpt_ema_model=ckpt_ema_model, + ckpt_optimizer=ckpt_optimizer, results_dir=results_dir, save_model_interval=save_model_interval, + start_model_interval=start_model_interval, num_classes=num_classes) logger.info(msg=f"[{device}]: Finish epoch {epoch}:") # Synchronization during distributed training @@ -365,16 +356,15 @@ def main(args): # If not filled, the default is the number of image classes (unconditional) or images.shape[0] (conditional) parser.add_argument("--num_vis", type=int, default=-1) # Resume interrupted training (needed) - # 1. Set to 'True' to resume interrupted training. - # 2. Set the resume interrupted epoch number - # 3. Set the directory of the previous loaded model from the interrupted epoch. + # 1. Set to 'True' to resume interrupted training and check if the parameter 'run_name' is correct. + # 2. Set the resume interrupted epoch number. (If not, we would select the last) # Note: If the epoch number of interruption is outside the condition of '--start_model_interval', # it will not take effect. For example, if the start saving model time is 100 and the interruption number is 50, # we cannot set any loading epoch points because we did not save the model. - # We save the 'xxx_last.pt' file every training, so we need to use the last saved model for interrupted training + # We save the 'ckpt_last.pt' file every training, so we need to use the last saved model for interrupted training + # If you do not know what epoch the checkpoint is, rename this checkpoint is 'ckpt_last'.pt parser.add_argument("--resume", type=bool, default=False) - parser.add_argument("--start_epoch", type=int, default=-1) - parser.add_argument("--load_model_dir", type=str, default="") + parser.add_argument("--start_epoch", type=int, default=None) # =================================Enable distributed training (if applicable)================================= # Enable distributed training (needed) diff --git a/utils/checkpoint.py b/utils/checkpoint.py new file mode 100644 index 0000000..0ed739e --- /dev/null +++ b/utils/checkpoint.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +""" + @Date : 2023/12/2 22:43 + @Author : chairc + @Site : https://github.com/chairc +""" +import os +import numpy as np +import logging +import torch +import shutil +import coloredlogs + +from collections import OrderedDict + +logger = logging.getLogger(__name__) +coloredlogs.install(level="INFO") + + +def load_ckpt(ckpt_path, model, device, optimizer=None, is_train=True): + """ + Load checkpoint weight files + :param ckpt_path: Checkpoint path + :param model: Network + :param optimizer: Optimizer + :param device: GPU or CPU + :param is_train: Whether to train mode + :return: start_epoch + 1 + """ + # Load checkpoint + ckpt_state = torch.load(f=ckpt_path, map_location=device) + logger.info(msg=f"[{device}]: Successfully load checkpoint, path is '{ckpt_path}'.") + # Load the current model + ckpt_model = ckpt_state["model"] + load_model_ckpt(model=model, model_ckpt=ckpt_model, is_train=is_train) + logger.info(msg=f"[{device}]: Successfully load model checkpoint.") + # Train mode + if is_train: + # Load the previous model optimizer + optim_weights_dict = ckpt_state["optimizer"] + optimizer.load_state_dict(state_dict=optim_weights_dict) + logger.info(msg=f"[{device}]: Successfully load optimizer checkpoint.") + # Current checkpoint epoch + start_epoch = ckpt_state["start_epoch"] + # Next epoch + return start_epoch + 1 + + +def load_model_ckpt(model, model_ckpt, is_train=True): + """ + Initialize weight loading + :param model: Model + :param model_ckpt: Model checkpoint + :param is_train: Whether to train mode + :return: None + """ + model_dict = model.state_dict() + model_weights_dict = model_ckpt + # Check if key contains 'module.' prefix. + # This method is the name after training in the distribution, check the weight and delete + if not is_train: + new_model_weights_dict = {} + for key, value in model_weights_dict.items(): + if key.startswith("module."): + new_key = key[len("module."):] + new_model_weights_dict[new_key] = value + else: + new_model_weights_dict[key] = value + model_weights_dict = new_model_weights_dict + logger.info(msg="Successfully check the load weight and rename.") + model_weights_dict = {k: v for k, v in model_weights_dict.items() if np.shape(model_dict[k]) == np.shape(v)} + model_dict.update(model_weights_dict) + model.load_state_dict(state_dict=OrderedDict(model_dict)) + + +def save_ckpt(epoch, save_name, ckpt_model, ckpt_ema_model, ckpt_optimizer, results_dir, save_model_interval, + start_model_interval, num_classes, classes_name=None, **kwargs): + """ + Save the model checkpoint weight files + :param epoch: Current epoch + :param save_name: Save the model's name + :param ckpt_model: Model + :param ckpt_ema_model: EMA model + :param ckpt_optimizer: Optimizer + :param results_dir: Results dir + :param save_model_interval: Whether to save weight each training + :param start_model_interval: Start epoch for saving models + :param num_classes: Number of classes + :param classes_name: All classes name + :return: None + """ + # Checkpoint + ckpt_state = { + "start_epoch": epoch, + "model": ckpt_model, + "ema_model": ckpt_ema_model, + "optimizer": ckpt_optimizer, + "num_classes": num_classes, + "classes_name": classes_name, + } + # Save last checkpoint, it must be done + last_filename = os.path.join(results_dir, f"ckpt_last.pt") + torch.save(obj=ckpt_state, f=last_filename) + logger.info(msg=f"Save the ckpt_last.pt") + # If save each checkpoint, just copy the last saved checkpoint and rename it + if save_model_interval and epoch > start_model_interval: + filename = os.path.join(results_dir, f"{save_name}.pt") + shutil.copyfile(last_filename, filename) + logger.info(msg=f"Save the {save_name}.pt") + logger.info(msg="Finish saving the model.") diff --git a/utils/initializer.py b/utils/initializer.py index d5a8e32..a1560c2 100644 --- a/utils/initializer.py +++ b/utils/initializer.py @@ -11,8 +11,6 @@ import logging import coloredlogs -from collections import OrderedDict - from torch.cuda.amp import GradScaler from model.networks.unet import UNet @@ -69,34 +67,6 @@ def seed_initializer(seed_id=0): logger.info(msg=f"The seed is initialized, and the seed ID is {seed_id}.") -def load_model_weight_initializer(model, weight_path, device, is_train=True): - """ - Initialize weight loading - :param model: Model - :param weight_path: Weight model path - :param device: GPU or CPU - :param is_train: Whether to train mode - :return: None - """ - model_dict = model.state_dict() - model_weights_dict = torch.load(f=weight_path, map_location=device) - # Check if key contains 'module.' prefix. - # This method is the name after training in the distribution, check the weight and delete - if not is_train: - new_model_weights_dict = {} - for key, value in model_weights_dict.items(): - if key.startswith("module."): - new_key = key[len("module."):] - new_model_weights_dict[new_key] = value - else: - new_model_weights_dict[key] = value - model_weights_dict = new_model_weights_dict - logger.info(msg="Successfully check the load weight and rename.") - model_weights_dict = {k: v for k, v in model_weights_dict.items() if np.shape(model_dict[k]) == np.shape(v)} - model_dict.update(model_weights_dict) - model.load_state_dict(state_dict=OrderedDict(model_dict)) - - def network_initializer(network, device): """ Initialize base network From b2caa377b27be1b49f13696a351a07d8c08fce39 Mon Sep 17 00:00:00 2001 From: ChairC <974833488@qq.com> Date: Tue, 5 Dec 2023 23:24:16 +0800 Subject: [PATCH 2/2] Update: Reconstruct modules.py, and update package path. --- README.md | 6 + README_zh.md | 6 + model/modules/activation.py | 36 ++++ model/modules/attention.py | 53 ++++++ model/modules/block.py | 131 ++++++++++++++ model/modules/conv.py | 98 ++++++++++ model/modules/ema.py | 69 +++++++ model/modules/module.py | 331 +--------------------------------- model/networks/cspdarkunet.py | 4 +- model/networks/unet.py | 4 +- tools/train.py | 2 +- 11 files changed, 408 insertions(+), 332 deletions(-) create mode 100644 model/modules/activation.py create mode 100644 model/modules/attention.py create mode 100644 model/modules/block.py create mode 100644 model/modules/conv.py create mode 100644 model/modules/ema.py diff --git a/README.md b/README.md index c988304..81284a6 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@ We named this project IDDM: Industrial Defect Diffusion Model. It aims to reprod **Repository Structure** ```yaml +Industrial Defect Diffusion Model ├── datasets │ └── dataset_demo │ ├── class_1 @@ -18,6 +19,11 @@ We named this project IDDM: Industrial Defect Diffusion Model. It aims to reprod │ └── class_3 ├── model │ ├── modules +│ │ ├── activation.py +│ │ ├── attention.py +│ │ ├── block.py +│ │ ├── conv.py +│ │ ├── ema.py │ │ └── module.py │ ├── networks │ │ ├── base.py diff --git a/README_zh.md b/README_zh.md index c96da69..c479ea3 100644 --- a/README_zh.md +++ b/README_zh.md @@ -11,6 +11,7 @@ **本仓库整体结构** ```yaml +Industrial Defect Diffusion Model ├── datasets │ └── dataset_demo │ ├── class_1 @@ -18,6 +19,11 @@ │ └── class_3 ├── model │ ├── modules +│ │ ├── activation.py +│ │ ├── attention.py +│ │ ├── block.py +│ │ ├── conv.py +│ │ ├── ema.py │ │ └── module.py │ ├── networks │ │ ├── base.py diff --git a/model/modules/activation.py b/model/modules/activation.py new file mode 100644 index 0000000..f844394 --- /dev/null +++ b/model/modules/activation.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +""" + @Date : 2023/12/5 10:19 + @Author : chairc + @Site : https://github.com/chairc +""" +import logging +import coloredlogs +import torch.nn as nn + +logger = logging.getLogger(__name__) +coloredlogs.install(level="INFO") + + +def get_activation_function(name="silu", inplace=False): + """ + Get activation function + :param name: Activation function name + :param inplace: can optionally do the operation in-place + :return Activation function + """ + if name == "relu": + act = nn.ReLU(inplace=inplace) + elif name == "relu6": + act = nn.ReLU6(inplace=inplace) + elif name == "silu": + act = nn.SiLU(inplace=inplace) + elif name == "lrelu": + act = nn.LeakyReLU(0.1, inplace=inplace) + elif name == "gelu": + act = nn.GELU() + else: + logger.warning(msg=f"Unsupported activation function type: {name}") + act = nn.SiLU(inplace=inplace) + return act diff --git a/model/modules/attention.py b/model/modules/attention.py new file mode 100644 index 0000000..9d4e230 --- /dev/null +++ b/model/modules/attention.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +""" + @Date : 2023/12/5 10:19 + @Author : chairc + @Site : https://github.com/chairc +""" +import torch.nn as nn +from model.modules.activation import get_activation_function + + +class SelfAttention(nn.Module): + """ + SelfAttention block + """ + + def __init__(self, channels, size, act="silu"): + """ + Initialize the self-attention block + :param channels: Channels + :param size: Size + :param act: Activation function + """ + super(SelfAttention, self).__init__() + self.channels = channels + self.size = size + # batch_first is not supported in pytorch 1.8. + # If you want to support upgrading to 1.9 and above, or use the following code to transpose + self.mha = nn.MultiheadAttention(embed_dim=channels, num_heads=4, batch_first=True) + self.ln = nn.LayerNorm(normalized_shape=[channels]) + self.ff_self = nn.Sequential( + nn.LayerNorm(normalized_shape=[channels]), + nn.Linear(in_features=channels, out_features=channels), + get_activation_function(name=act), + nn.Linear(in_features=channels, out_features=channels), + ) + + def forward(self, x): + """ + SelfAttention forward + :param x: Input + :return: attention_value + """ + # First perform the shape transformation, and then use 'swapaxes' to exchange the first + # second dimensions of the new tensor + x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2) + x_ln = self.ln(x) + # batch_first is not supported in pytorch 1.8. + # If you want to support upgrading to 1.9 and above, or use the following code to transpose + attention_value, _ = self.mha(x_ln, x_ln, x_ln) + attention_value = attention_value + x + attention_value = self.ff_self(attention_value) + attention_value + return attention_value.swapaxes(2, 1).view(-1, self.channels, self.size, self.size) diff --git a/model/modules/block.py b/model/modules/block.py new file mode 100644 index 0000000..1b1ad84 --- /dev/null +++ b/model/modules/block.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +""" + @Date : 2023/12/5 10:21 + @Author : chairc + @Site : https://github.com/chairc +""" +import torch +import torch.nn as nn + +from model.modules.conv import BaseConv, DoubleConv +from model.modules.module import CSPLayer + + +class DownBlock(nn.Module): + """ + Downsample block + """ + + def __init__(self, in_channels, out_channels, emb_channels=256, act="silu"): + """ + Initialize the downsample block + :param in_channels: Input channels + :param out_channels: Output channels + :param emb_channels: Embed channels + :param act: Activation function + """ + super().__init__() + self.maxpool_conv = nn.Sequential( + nn.MaxPool2d(2), + DoubleConv(in_channels=in_channels, out_channels=in_channels, residual=True, act=act), + DoubleConv(in_channels=in_channels, out_channels=out_channels, act=act), + ) + + self.emb_layer = nn.Sequential( + nn.SiLU(), + nn.Linear(in_features=emb_channels, out_features=out_channels), + ) + + def forward(self, x, time): + """ + DownBlock forward + :param x: Input + :param time: Time + :return: x + emb + """ + x = self.maxpool_conv(x) + emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1]) + return x + emb + + +class UpBlock(nn.Module): + """ + Upsample Block + """ + + def __init__(self, in_channels, out_channels, emb_channels=256, act="silu"): + """ + Initialize the upsample block + :param in_channels: Input channels + :param out_channels: Output channels + :param emb_channels: Embed channels + :param act: Activation function + """ + super().__init__() + + self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) + self.conv = nn.Sequential( + DoubleConv(in_channels=in_channels, out_channels=in_channels, residual=True, act=act), + DoubleConv(in_channels=in_channels, out_channels=out_channels, mid_channels=in_channels // 2, act=act), + ) + + self.emb_layer = nn.Sequential( + nn.SiLU(), + nn.Linear(in_features=emb_channels, out_features=out_channels), + ) + + def forward(self, x, skip_x, time): + """ + UpBlock forward + :param x: Input + :param skip_x: Merged input + :param time: Time + :return: x + emb + """ + x = self.up(x) + x = torch.cat([skip_x, x], dim=1) + x = self.conv(x) + emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1]) + return x + emb + + +class CSPDarkDownBlock(nn.Module): + def __init__(self, in_channels, out_channels, emb_channels=256, n=1, act="silu"): + super().__init__() + self.conv_csp = nn.Sequential( + BaseConv(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, act=act), + CSPLayer(in_channels=out_channels, out_channels=out_channels, n=n, act=act) + ) + + self.emb_layer = nn.Sequential( + nn.SiLU(), + nn.Linear(in_features=emb_channels, out_features=out_channels), + ) + + def forward(self, x, time): + x = self.conv_csp(x) + emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1]) + return x + emb + + +class CSPDarkUpBlock(nn.Module): + + def __init__(self, in_channels, out_channels, emb_channels=256, n=1, act="silu"): + super().__init__() + self.up = nn.Upsample(scale_factor=2, mode="nearest") + self.conv = BaseConv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, act=act) + self.csp = CSPLayer(in_channels=in_channels, out_channels=out_channels, n=n, shortcut=False, act=act) + + self.emb_layer = nn.Sequential( + nn.SiLU(), + nn.Linear(in_features=emb_channels, out_features=out_channels), + ) + + def forward(self, x, skip_x, time): + x = self.conv(x) + x = self.up(x) + x = torch.cat([skip_x, x], dim=1) + x = self.conv(x) + emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1]) + return x + emb diff --git a/model/modules/conv.py b/model/modules/conv.py new file mode 100644 index 0000000..6114bf0 --- /dev/null +++ b/model/modules/conv.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +""" + @Date : 2023/12/5 10:22 + @Author : chairc + @Site : https://github.com/chairc +""" +import logging +import coloredlogs + +import torch.nn as nn +import torch.nn.functional as F + +from model.modules.activation import get_activation_function + +logger = logging.getLogger(__name__) +coloredlogs.install(level="INFO") + + +class DoubleConv(nn.Module): + """ + Double convolution + """ + + def __init__(self, in_channels, out_channels, mid_channels=None, residual=False, act="silu"): + """ + Initialize the double convolution block + :param in_channels: Input channels + :param out_channels: Output channels + :param mid_channels: Middle channels + :param residual: Whether residual + :param act: Activation function + """ + super().__init__() + self.residual = residual + if not mid_channels: + mid_channels = out_channels + self.act = act + self.double_conv = nn.Sequential( + nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=3, padding=1, bias=False), + nn.GroupNorm(num_groups=1, num_channels=mid_channels), + get_activation_function(name=self.act), + nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False), + nn.GroupNorm(num_groups=1, num_channels=out_channels), + ) + + def forward(self, x): + """ + DoubleConv forward + :param x: Input + :return: Residual or non-residual results + """ + if self.residual: + out = x + self.double_conv(x) + if self.act == "relu": + return F.relu(out) + elif self.act == "relu6": + return F.relu6(out) + elif self.act == "silu": + return F.silu(out) + elif self.act == "lrelu": + return F.leaky_relu(out) + elif self.act == "gelu": + return F.gelu(out) + else: + logger.warning(msg=f"Unsupported activation function type: {self.act}") + return F.silu(out) + else: + return self.double_conv(x) + + +class BaseConv(nn.Module): + """ + Base convolution + Conv2d -> BatchNorm -> Activation function block + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1, bias=False, act="silu"): + """ + Initialize the Base convolution + :param in_channels: Input channels + :param out_channels: Output channels + :param kernel_size: Kernel size + :param stride: Stride + :param groups: Groups + :param bias: Bias + :param act: Activation function + """ + super().__init__() + # Same padding + pad = (kernel_size - 1) // 2 + self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, padding=pad, groups=groups, bias=bias) + self.gn = nn.GroupNorm(num_groups=1, num_channels=out_channels) + self.act = get_activation_function(name=act, inplace=True) + + def forward(self, x): + return self.act(self.gn(self.conv(x))) diff --git a/model/modules/ema.py b/model/modules/ema.py new file mode 100644 index 0000000..f04fdba --- /dev/null +++ b/model/modules/ema.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +""" + @Date : 2023/12/5 10:18 + @Author : chairc + @Site : https://github.com/chairc +""" + + +class EMA: + """ + Exponential Moving Average + """ + + def __init__(self, beta): + """ + Initialize EMA + :param beta: β + """ + super().__init__() + self.beta = beta + self.step = 0 + + def update_model_average(self, ema_model, current_model): + """ + Update model average + :param ema_model: EMA model + :param current_model: Current model + :return: None + """ + for current_params, ema_params in zip(current_model.parameters(), ema_model.parameters()): + old_weight, up_weight = ema_params, current_params.data + ema_params.data = self.update_average(old_weight, up_weight) + + def update_average(self, old_weight, new_weight): + """ + Update average + :param old_weight: Old weight + :param new_weight: New weight + :return: new_weight or old_weight * self.beta + (1 - self.beta) * new_weight + """ + if old_weight is None: + return new_weight + return old_weight * self.beta + (1 - self.beta) * new_weight + + def step_ema(self, ema_model, model, step_start_ema=2000): + """ + EMA step + :param ema_model: EMA model + :param model: Original model + :param step_start_ema: Start EMA step + :return: None + """ + if self.step < step_start_ema: + self.reset_parameters(ema_model, model) + self.step += 1 + return + self.update_model_average(ema_model, model) + self.step += 1 + + @staticmethod + def reset_parameters(ema_model, model): + """ + Reset parameters + :param ema_model: EMA model + :param model: Original model + :return: None + """ + ema_model.load_state_dict(model.state_dict()) diff --git a/model/modules/module.py b/model/modules/module.py index b073518..024b2b7 100644 --- a/model/modules/module.py +++ b/model/modules/module.py @@ -10,299 +10,13 @@ import torch import torch.nn as nn -import torch.nn.functional as F + +from model.modules.conv import BaseConv logger = logging.getLogger(__name__) coloredlogs.install(level="INFO") -def get_activation_function(name="silu", inplace=False): - """ - Get activation function - :param name: Activation function name - :param inplace: can optionally do the operation in-place - :return Activation function - """ - if name == "relu": - act = nn.ReLU(inplace=inplace) - elif name == "relu6": - act = nn.ReLU6(inplace=inplace) - elif name == "silu": - act = nn.SiLU(inplace=inplace) - elif name == "lrelu": - act = nn.LeakyReLU(0.1, inplace=inplace) - elif name == "gelu": - act = nn.GELU() - else: - logger.warning(msg=f"Unsupported activation function type: {name}") - act = nn.SiLU(inplace=inplace) - return act - - -class EMA: - """ - Exponential Moving Average - """ - - def __init__(self, beta): - """ - Initialize EMA - :param beta: β - """ - super().__init__() - self.beta = beta - self.step = 0 - - def update_model_average(self, ema_model, current_model): - """ - Update model average - :param ema_model: EMA model - :param current_model: Current model - :return: None - """ - for current_params, ema_params in zip(current_model.parameters(), ema_model.parameters()): - old_weight, up_weight = ema_params, current_params.data - ema_params.data = self.update_average(old_weight, up_weight) - - def update_average(self, old_weight, new_weight): - """ - Update average - :param old_weight: Old weight - :param new_weight: New weight - :return: new_weight or old_weight * self.beta + (1 - self.beta) * new_weight - """ - if old_weight is None: - return new_weight - return old_weight * self.beta + (1 - self.beta) * new_weight - - def step_ema(self, ema_model, model, step_start_ema=2000): - """ - EMA step - :param ema_model: EMA model - :param model: Original model - :param step_start_ema: Start EMA step - :return: None - """ - if self.step < step_start_ema: - self.reset_parameters(ema_model, model) - self.step += 1 - return - self.update_model_average(ema_model, model) - self.step += 1 - - def reset_parameters(self, ema_model, model): - """ - Reset parameters - :param ema_model: EMA model - :param model: Original model - :return: None - """ - ema_model.load_state_dict(model.state_dict()) - - -class SelfAttention(nn.Module): - """ - SelfAttention block - """ - - def __init__(self, channels, size, act="silu"): - """ - Initialize the self-attention block - :param channels: Channels - :param size: Size - :param act: Activation function - """ - super(SelfAttention, self).__init__() - self.channels = channels - self.size = size - # batch_first is not supported in pytorch 1.8. - # If you want to support upgrading to 1.9 and above, or use the following code to transpose - self.mha = nn.MultiheadAttention(embed_dim=channels, num_heads=4, batch_first=True) - self.ln = nn.LayerNorm(normalized_shape=[channels]) - self.ff_self = nn.Sequential( - nn.LayerNorm(normalized_shape=[channels]), - nn.Linear(in_features=channels, out_features=channels), - get_activation_function(name=act), - nn.Linear(in_features=channels, out_features=channels), - ) - - def forward(self, x): - """ - SelfAttention forward - :param x: Input - :return: attention_value - """ - # First perform the shape transformation, and then use 'swapaxes' to exchange the first - # second dimensions of the new tensor - x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2) - x_ln = self.ln(x) - # batch_first is not supported in pytorch 1.8. - # If you want to support upgrading to 1.9 and above, or use the following code to transpose - attention_value, _ = self.mha(x_ln, x_ln, x_ln) - attention_value = attention_value + x - attention_value = self.ff_self(attention_value) + attention_value - return attention_value.swapaxes(2, 1).view(-1, self.channels, self.size, self.size) - - -class DoubleConv(nn.Module): - """ - Double convolution - """ - - def __init__(self, in_channels, out_channels, mid_channels=None, residual=False, act="silu"): - """ - Initialize the double convolution block - :param in_channels: Input channels - :param out_channels: Output channels - :param mid_channels: Middle channels - :param residual: Whether residual - :param act: Activation function - """ - super().__init__() - self.residual = residual - if not mid_channels: - mid_channels = out_channels - self.act = act - self.double_conv = nn.Sequential( - nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=3, padding=1, bias=False), - nn.GroupNorm(num_groups=1, num_channels=mid_channels), - get_activation_function(name=self.act), - nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False), - nn.GroupNorm(num_groups=1, num_channels=out_channels), - ) - - def forward(self, x): - """ - DoubleConv forward - :param x: Input - :return: Residual or non-residual results - """ - if self.residual: - out = x + self.double_conv(x) - if self.act == "relu": - return F.relu(out) - elif self.act == "relu6": - return F.relu6(out) - elif self.act == "silu": - return F.silu(out) - elif self.act == "lrelu": - return F.leaky_relu(out) - elif self.act == "gelu": - return F.gelu(out) - else: - logger.warning(msg=f"Unsupported activation function type: {self.act}") - return F.silu(out) - else: - return self.double_conv(x) - - -class DownBlock(nn.Module): - """ - Downsample block - """ - - def __init__(self, in_channels, out_channels, emb_channels=256, act="silu"): - """ - Initialize the downsample block - :param in_channels: Input channels - :param out_channels: Output channels - :param emb_channels: Embed channels - :param act: Activation function - """ - super().__init__() - self.maxpool_conv = nn.Sequential( - nn.MaxPool2d(2), - DoubleConv(in_channels=in_channels, out_channels=in_channels, residual=True, act=act), - DoubleConv(in_channels=in_channels, out_channels=out_channels, act=act), - ) - - self.emb_layer = nn.Sequential( - nn.SiLU(), - nn.Linear(in_features=emb_channels, out_features=out_channels), - ) - - def forward(self, x, time): - """ - DownBlock forward - :param x: Input - :param time: Time - :return: x + emb - """ - x = self.maxpool_conv(x) - emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1]) - return x + emb - - -class UpBlock(nn.Module): - """ - Upsample Block - """ - - def __init__(self, in_channels, out_channels, emb_channels=256, act="silu"): - """ - Initialize the upsample block - :param in_channels: Input channels - :param out_channels: Output channels - :param emb_channels: Embed channels - :param act: Activation function - """ - super().__init__() - - self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) - self.conv = nn.Sequential( - DoubleConv(in_channels=in_channels, out_channels=in_channels, residual=True, act=act), - DoubleConv(in_channels=in_channels, out_channels=out_channels, mid_channels=in_channels // 2, act=act), - ) - - self.emb_layer = nn.Sequential( - nn.SiLU(), - nn.Linear(in_features=emb_channels, out_features=out_channels), - ) - - def forward(self, x, skip_x, time): - """ - UpBlock forward - :param x: Input - :param skip_x: Merged input - :param time: Time - :return: x + emb - """ - x = self.up(x) - x = torch.cat([skip_x, x], dim=1) - x = self.conv(x) - emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1]) - return x + emb - - -class BaseConv(nn.Module): - """ - Base convolution - Conv2d -> BatchNorm -> Activation function block - """ - - def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1, bias=False, act="silu"): - """ - Initialize the Base convolution - :param in_channels: Input channels - :param out_channels: Output channels - :param kernel_size: Kernel size - :param stride: Stride - :param groups: Groups - :param bias: Bias - :param act: Activation function - """ - super().__init__() - # Same padding - pad = (kernel_size - 1) // 2 - self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, - stride=stride, padding=pad, groups=groups, bias=bias) - self.gn = nn.GroupNorm(num_groups=1, num_channels=out_channels) - self.act = get_activation_function(name=act, inplace=True) - - def forward(self, x): - return self.act(self.gn(self.conv(x))) - - class Bottleneck(nn.Module): """ Standard bottleneck @@ -394,44 +108,3 @@ def forward(self, x): x_1 = self.m(x_1) x = torch.cat([x_1, x_2], dim=1) return self.conv3(x) - - -class CSPDarkDownBlock(nn.Module): - def __init__(self, in_channels, out_channels, emb_channels=256, n=1, act="silu"): - super().__init__() - self.conv_csp = nn.Sequential( - BaseConv(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, act=act), - CSPLayer(in_channels=out_channels, out_channels=out_channels, n=n, act=act) - ) - - self.emb_layer = nn.Sequential( - nn.SiLU(), - nn.Linear(in_features=emb_channels, out_features=out_channels), - ) - - def forward(self, x, time): - x = self.conv_csp(x) - emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1]) - return x + emb - - -class CSPDarkUpBlock(nn.Module): - - def __init__(self, in_channels, out_channels, emb_channels=256, n=1, act="silu"): - super().__init__() - self.up = nn.Upsample(scale_factor=2, mode="nearest") - self.conv = BaseConv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, act=act) - self.csp = CSPLayer(in_channels=in_channels, out_channels=out_channels, n=n, shortcut=False, act=act) - - self.emb_layer = nn.Sequential( - nn.SiLU(), - nn.Linear(in_features=emb_channels, out_features=out_channels), - ) - - def forward(self, x, skip_x, time): - x = self.conv(x) - x = self.up(x) - x = torch.cat([skip_x, x], dim=1) - x = self.conv(x) - emb = self.emb_layer(time)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1]) - return x + emb diff --git a/model/networks/cspdarkunet.py b/model/networks/cspdarkunet.py index ec9b7cf..21f052e 100644 --- a/model/networks/cspdarkunet.py +++ b/model/networks/cspdarkunet.py @@ -9,7 +9,9 @@ import torch.nn as nn from model.networks.base import BaseNet -from model.modules.module import SelfAttention, CSPDarkUpBlock, CSPDarkDownBlock, BaseConv +from model.modules.attention import SelfAttention +from model.modules.block import CSPDarkDownBlock,CSPDarkUpBlock +from model.modules.conv import BaseConv class CSPDarkUnet(BaseNet): diff --git a/model/networks/unet.py b/model/networks/unet.py index e0189b0..9bb99c8 100644 --- a/model/networks/unet.py +++ b/model/networks/unet.py @@ -9,7 +9,9 @@ import torch.nn as nn from model.networks.base import BaseNet -from model.modules.module import UpBlock, DownBlock, DoubleConv, SelfAttention +from model.modules.attention import SelfAttention +from model.modules.block import DownBlock, UpBlock +from model.modules.conv import DoubleConv class UNet(BaseNet): diff --git a/tools/train.py b/tools/train.py index 34c7abe..ec61aea 100644 --- a/tools/train.py +++ b/tools/train.py @@ -22,7 +22,7 @@ from tqdm import tqdm sys.path.append(os.path.dirname(sys.path[0])) -from model.modules.module import EMA +from model.modules.ema import EMA from utils.initializer import device_initializer, seed_initializer, network_initializer, optimizer_initializer, \ sample_initializer, lr_initializer, fp16_initializer from utils.utils import plot_images, save_images, get_dataset, setup_logging, save_train_logging