From 8c050ee2b54f55ed297d5f69245eeb15612a3355 Mon Sep 17 00:00:00 2001 From: NASEEM A P Date: Sun, 1 Oct 2023 10:19:08 +0530 Subject: [PATCH] removed class number input option to training, it will take from annotation data --- data.yaml | 4 ---- train.py | 34 +++++++++++++++++++--------------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/data.yaml b/data.yaml index 4a5aebd..8fee902 100644 --- a/data.yaml +++ b/data.yaml @@ -1,7 +1,3 @@ -names: -- Paper -- Rock -- Scissors Dir: 'Data' images: test: test diff --git a/train.py b/train.py index a5be6c0..aa7d967 100644 --- a/train.py +++ b/train.py @@ -13,6 +13,7 @@ import torch import time import yaml +import json import os @@ -83,8 +84,8 @@ break else: n += 1 - print(f"[INFO] Checkpoints saved in \033[1m{os.path.join('runs', name)}\033[0m") + # Training on GPU or CPU if args['cpu']: print('[INFO] Training on \033[1mCPU\033[0m') @@ -96,9 +97,14 @@ print(f'[INFO] Training on GPU: \033[1m{torch.cuda.get_device_name()}\033[0m') trainer = Trainer(experiment_name=name, ckpt_root_dir='runs') + # Load Path Params yaml_params = yaml.safe_load(open(args['data'], 'r')) - - # Load Dataset + with open(os.path.join(yaml_params['Dir'], yaml_params['labels']['train'])) as f: + no_class = len(json.load(f)['categories']) + f.close() + print(f"\033[1m[INFO] Number of Classes: {no_class}\033[0m") + + # Reain Dataset trainset = COCOFormatDetectionDataset(data_dir=yaml_params['Dir'], images_dir=yaml_params['images']['train'], json_annotation_file=yaml_params['labels']['train'], @@ -116,7 +122,6 @@ DetectionTargetsFormatTransform(max_targets=300, input_dim=(args['size'], args['size']), output_format="LABEL_CXCYWH") ]) - train_loader = dataloaders.get(dataset=trainset, dataloader_params={ "shuffle": True, "batch_size": args['batch'], @@ -126,7 +131,7 @@ "worker_init_fn": worker_init_reset_seed, "min_samples": 512 }) - + # Valid Data valset = COCOFormatDetectionDataset(data_dir=yaml_params['Dir'], images_dir=yaml_params['images']['val'], json_annotation_file=yaml_params['labels']['val'], @@ -148,7 +153,7 @@ "worker_init_fn": worker_init_reset_seed }) - + # Test Data if 'test' in (yaml_params['images'].keys() or yaml_params['labels'].keys()): testset = COCOFormatDetectionDataset(data_dir=yaml_params['Dir'], images_dir=yaml_params['images']['test'], @@ -175,18 +180,17 @@ if args['resume']: model = models.get( args['model'], - num_classes=len(yaml_params['names']), + num_classes=no_class, checkpoint_path=args["weight"] ) else: model = models.get( args['model'], - num_classes=len(yaml_params['names']), + num_classes=no_class, pretrained_weights=args["weight"] ) train_params = { - # ENABLING SILENT MODE 'silent_mode': False, "average_best_models":True, "warmup_mode": args['warmup_mode'], @@ -204,15 +208,14 @@ "mixed_precision": True, "loss": PPYoloELoss( use_static_assigner=False, - num_classes=len(yaml_params['names']), + num_classes=no_class, reg_max=16 ), "valid_metrics_list": [ DetectionMetrics_050( score_thres=0.1, top_k_predictions=300, - # NOTE: num_classes needs to be defined here - num_cls=len(yaml_params['names']), + num_cls=no_class, normalize_targets=True, post_prediction_callback=PPYoloEPostPredictionCallback( score_threshold=0.01, @@ -232,6 +235,7 @@ # Print Training Params print('[INFO] Training Params:\n', train_params) + # Model Training... trainer.train( model=model, training_params=train_params, @@ -241,7 +245,7 @@ # Load best model best_model = models.get(args['model'], - num_classes=len(yaml_params['names']), + num_classes=no_class, checkpoint_path=os.path.join('runs', name, 'ckpt_best.pth')) # Evaluating on Val Dataset @@ -249,7 +253,7 @@ test_loader=valid_loader, test_metrics_list=DetectionMetrics_050(score_thres=0.1, top_k_predictions=300, - num_cls=len(yaml_params['names']), + num_cls=no_class, normalize_targets=True, post_prediction_callback=PPYoloEPostPredictionCallback(score_threshold=0.01, nms_top_k=1000, @@ -266,7 +270,7 @@ test_loader=test_loader, test_metrics_list=DetectionMetrics_050(score_thres=0.1, top_k_predictions=300, - num_cls=len(yaml_params['names']), + num_cls=no_class, normalize_targets=True, post_prediction_callback=PPYoloEPostPredictionCallback(score_threshold=0.01, nms_top_k=1000,