From 471536473307c53f2f99c3946c17afdfb0cd0de3 Mon Sep 17 00:00:00 2001 From: chairc <974833488@qq.com> Date: Tue, 7 May 2024 00:15:13 +0800 Subject: [PATCH] Add: Add better FID calculator to verify image quality. --- README.md | 2 + README_zh.md | 2 + tools/FID_calculator_plus.py | 71 ++++++++++++++++++++++++++++++++++++ 3 files changed, 75 insertions(+) create mode 100644 tools/FID_calculator_plus.py diff --git a/README.md b/README.md index 2dd191f..006f457 100644 --- a/README.md +++ b/README.md @@ -85,6 +85,7 @@ Integrated Design Diffusion Model ├── tools │ ├── deploy.py │ ├── FID_calculator.py +│ ├── FID_calculator_plus.py │ ├── generate.py │ └── train.py ├── utils @@ -114,6 +115,7 @@ Integrated Design Diffusion Model - [x] 10. Reconstruct the overall structure of the model (2023-12-06) - [x] 11. Write visual webui interface. (2024-01-23) - [x] 12. Adding PLMS Sampling Method. (2024-03-12) +- [x] 13. Adding FID calculator to verify image quality. (2024-05-06) ### Training diff --git a/README_zh.md b/README_zh.md index 5c92563..2531260 100644 --- a/README_zh.md +++ b/README_zh.md @@ -84,6 +84,7 @@ Integrated Design Diffusion Model ├── tools │ ├── deploy.py │ ├── FID_calculator.py +│ ├── FID_calculator_plus.py │ ├── generate.py │ └── train.py ├── utils @@ -113,6 +114,7 @@ Integrated Design Diffusion Model - [x] 10. 重构model整体结构(2023-12-06) - [x] 11. 编写可视化webui界面(2024-01-23) - [x] 12. 增加PLMS采样方法(2024-03-12) +- [x] 13. 增加FID方法验证图像质量(2024-05-06) ### 训练 diff --git a/tools/FID_calculator_plus.py b/tools/FID_calculator_plus.py new file mode 100644 index 0000000..c01765d --- /dev/null +++ b/tools/FID_calculator_plus.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +""" + @Date : 2024/5/4 23:36 + @Author : chairc + @Site : https://github.com/chairc +""" +import os +import sys +import argparse +import logging + +import coloredlogs + +from pytorch_fid.fid_score import save_fid_stats, calculate_fid_given_paths +from pytorch_fid.inception import InceptionV3 + +sys.path.append(os.path.dirname(sys.path[0])) +from utils.initializer import device_initializer + +logger = logging.getLogger(__name__) +coloredlogs.install(level="INFO") + + +def main(args): + logger.info(msg=f"[Note]: Input params: {args}") + device_id = args.use_gpu + paths = args.path + batch_size = args.batch_size + num_workers = args.num_workers + dims = args.dims + device = device_initializer(device_id=device_id) + # TODO: Check image size + # Compute fid + if args.save_stats: + save_fid_stats(paths=paths, batch_size=batch_size, device=device, dims=dims, num_workers=num_workers) + return + + fid_value = calculate_fid_given_paths(paths=paths, batch_size=batch_size, device=device, dims=dims, + num_workers=num_workers) + + logger.info(msg=f"The result of FID: {fid_value}") + + +if __name__ == "__main__": + # Before calculating + # [Note]: We recommend resizing both sets of images to the same format, the same size, and the same number + parser = argparse.ArgumentParser() + # Function1: Generated image folder and dataset image folder + # Function2: Save stats input path and output path (use `--save_stats`) + parser.add_argument("path", type=str, nargs="*", + default=["/your/generated/image/folder/or/stats/input/path", + "/your/dataset/image/folder/or/stats/output/path"], + help="Paths to the generated images or to .npz statistic files") + # Batch size + parser.add_argument("--batch_size", type=int, default=8, + help="Batch size for calculation.") + # Number of workers + parser.add_argument("--num-workers", type=int, default=0) + # Dimensionality of Inception features to use + # Option: 64/192/768/2048 + parser.add_argument("--dims", type=int, default=2048, + choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), + help="Dimensionality of Inception features to use. By default, uses pool3 features") + parser.add_argument("--save_stats", action="store_true", + help="Generate an npz archive from a directory of samples. " + "The first path is used as input and the second as output.") + # Set the use GPU in normal training (required) + parser.add_argument("--use_gpu", type=int, default=0) + args = parser.parse_args() + main(args)