Skip to content

Commit

Permalink
Merge pull request #69 from chairc/dev
Browse files Browse the repository at this point in the history
Add: Add better FID calculator to verify image quality.
  • Loading branch information
chairc authored May 6, 2024
2 parents e321920 + 4715364 commit f9d68e3
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 0 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ Integrated Design Diffusion Model
├── tools
│ ├── deploy.py
│ ├── FID_calculator.py
│ ├── FID_calculator_plus.py
│ ├── generate.py
│ └── train.py
├── utils
Expand Down Expand Up @@ -114,6 +115,7 @@ Integrated Design Diffusion Model
- [x] 10. Reconstruct the overall structure of the model (2023-12-06)
- [x] 11. Write visual webui interface. (2024-01-23)
- [x] 12. Adding PLMS Sampling Method. (2024-03-12)
- [x] 13. Adding FID calculator to verify image quality. (2024-05-06)

### Training

Expand Down
2 changes: 2 additions & 0 deletions README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ Integrated Design Diffusion Model
├── tools
│ ├── deploy.py
│ ├── FID_calculator.py
│ ├── FID_calculator_plus.py
│ ├── generate.py
│ └── train.py
├── utils
Expand Down Expand Up @@ -113,6 +114,7 @@ Integrated Design Diffusion Model
- [x] 10. 重构model整体结构(2023-12-06)
- [x] 11. 编写可视化webui界面(2024-01-23)
- [x] 12. 增加PLMS采样方法(2024-03-12)
- [x] 13. 增加FID方法验证图像质量(2024-05-06)

### 训练

Expand Down
71 changes: 71 additions & 0 deletions tools/FID_calculator_plus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@Date : 2024/5/4 23:36
@Author : chairc
@Site : https://github.com/chairc
"""
import os
import sys
import argparse
import logging

import coloredlogs

from pytorch_fid.fid_score import save_fid_stats, calculate_fid_given_paths
from pytorch_fid.inception import InceptionV3

sys.path.append(os.path.dirname(sys.path[0]))
from utils.initializer import device_initializer

logger = logging.getLogger(__name__)
coloredlogs.install(level="INFO")


def main(args):
logger.info(msg=f"[Note]: Input params: {args}")
device_id = args.use_gpu
paths = args.path
batch_size = args.batch_size
num_workers = args.num_workers
dims = args.dims
device = device_initializer(device_id=device_id)
# TODO: Check image size
# Compute fid
if args.save_stats:
save_fid_stats(paths=paths, batch_size=batch_size, device=device, dims=dims, num_workers=num_workers)
return

fid_value = calculate_fid_given_paths(paths=paths, batch_size=batch_size, device=device, dims=dims,
num_workers=num_workers)

logger.info(msg=f"The result of FID: {fid_value}")


if __name__ == "__main__":
# Before calculating
# [Note]: We recommend resizing both sets of images to the same format, the same size, and the same number
parser = argparse.ArgumentParser()
# Function1: Generated image folder and dataset image folder
# Function2: Save stats input path and output path (use `--save_stats`)
parser.add_argument("path", type=str, nargs="*",
default=["/your/generated/image/folder/or/stats/input/path",
"/your/dataset/image/folder/or/stats/output/path"],
help="Paths to the generated images or to .npz statistic files")
# Batch size
parser.add_argument("--batch_size", type=int, default=8,
help="Batch size for calculation.")
# Number of workers
parser.add_argument("--num-workers", type=int, default=0)
# Dimensionality of Inception features to use
# Option: 64/192/768/2048
parser.add_argument("--dims", type=int, default=2048,
choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
help="Dimensionality of Inception features to use. By default, uses pool3 features")
parser.add_argument("--save_stats", action="store_true",
help="Generate an npz archive from a directory of samples. "
"The first path is used as input and the second as output.")
# Set the use GPU in normal training (required)
parser.add_argument("--use_gpu", type=int, default=0)
args = parser.parse_args()
main(args)

0 comments on commit f9d68e3

Please sign in to comment.