Skip to content

Commit

Permalink
Add: Add classes_initializer function, auto check the number of classes.
Browse files Browse the repository at this point in the history
  • Loading branch information
chairc committed Apr 15, 2024
1 parent 4a35333 commit cec82ec
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions utils/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
@Author : chairc
@Site : https://github.com/chairc
"""
import os
import random
import numpy as np
import torch
Expand All @@ -21,6 +22,7 @@
from model.samples.ddim import DDIMDiffusion
from model.samples.ddpm import DDPMDiffusion
from model.samples.plms import PLMSDiffusion
from utils.check import check_path_is_exist
from utils.lr_scheduler import set_cosine_lr

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -258,3 +260,22 @@ def check_param_in_dict(param, dict_params, args_param):
act = check_param_in_dict(param="act", dict_params=ckpt_state, args_param=args.act)
logger.info(msg=f"[{device}]: Successfully checked parameters.")
return conditional, network, image_size, num_classes, act


def classes_initializer(dataset_path):
"""
Initialize number of classes
:param dataset_path: Dataset path
:return: num_classes
"""
check_path_is_exist(path=dataset_path)
num_classes = 0
# Loop dataset path
for classes_dir in os.listdir(path=dataset_path):
# Check current dir
if os.path.isdir(s=os.path.join(dataset_path, classes_dir)):
num_classes += 1
logger.info(msg=f"Current number of classes is {num_classes}.")
if num_classes == 0:
raise Exception(f"No dataset folders found in '{dataset_path}'.")
return num_classes

0 comments on commit cec82ec

Please sign in to comment.