diff --git a/utils/initializer.py b/utils/initializer.py index 2bb727c..7d771ed 100644 --- a/utils/initializer.py +++ b/utils/initializer.py @@ -5,6 +5,7 @@ @Author : chairc @Site : https://github.com/chairc """ +import os import random import numpy as np import torch @@ -21,6 +22,7 @@ from model.samples.ddim import DDIMDiffusion from model.samples.ddpm import DDPMDiffusion from model.samples.plms import PLMSDiffusion +from utils.check import check_path_is_exist from utils.lr_scheduler import set_cosine_lr logger = logging.getLogger(__name__) @@ -258,3 +260,22 @@ def check_param_in_dict(param, dict_params, args_param): act = check_param_in_dict(param="act", dict_params=ckpt_state, args_param=args.act) logger.info(msg=f"[{device}]: Successfully checked parameters.") return conditional, network, image_size, num_classes, act + + +def classes_initializer(dataset_path): + """ + Initialize number of classes + :param dataset_path: Dataset path + :return: num_classes + """ + check_path_is_exist(path=dataset_path) + num_classes = 0 + # Loop dataset path + for classes_dir in os.listdir(path=dataset_path): + # Check current dir + if os.path.isdir(s=os.path.join(dataset_path, classes_dir)): + num_classes += 1 + logger.info(msg=f"Current number of classes is {num_classes}.") + if num_classes == 0: + raise Exception(f"No dataset folders found in '{dataset_path}'.") + return num_classes