Skip to content

Commit

Permalink
removed class number input option to training, it will take from anno…
Browse files Browse the repository at this point in the history
…tation data
  • Loading branch information
naseemap47 committed Oct 1, 2023
1 parent 318b334 commit 8c050ee
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 19 deletions.
4 changes: 0 additions & 4 deletions data.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
names:
- Paper
- Rock
- Scissors
Dir: 'Data'
images:
test: test
Expand Down
34 changes: 19 additions & 15 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import torch
import time
import yaml
import json
import os


Expand Down Expand Up @@ -83,8 +84,8 @@
break
else:
n += 1

print(f"[INFO] Checkpoints saved in \033[1m{os.path.join('runs', name)}\033[0m")

# Training on GPU or CPU
if args['cpu']:
print('[INFO] Training on \033[1mCPU\033[0m')
Expand All @@ -96,9 +97,14 @@
print(f'[INFO] Training on GPU: \033[1m{torch.cuda.get_device_name()}\033[0m')
trainer = Trainer(experiment_name=name, ckpt_root_dir='runs')

# Load Path Params
yaml_params = yaml.safe_load(open(args['data'], 'r'))

# Load Dataset
with open(os.path.join(yaml_params['Dir'], yaml_params['labels']['train'])) as f:
no_class = len(json.load(f)['categories'])
f.close()
print(f"\033[1m[INFO] Number of Classes: {no_class}\033[0m")

# Reain Dataset
trainset = COCOFormatDetectionDataset(data_dir=yaml_params['Dir'],
images_dir=yaml_params['images']['train'],
json_annotation_file=yaml_params['labels']['train'],
Expand All @@ -116,7 +122,6 @@
DetectionTargetsFormatTransform(max_targets=300, input_dim=(args['size'], args['size']),
output_format="LABEL_CXCYWH")
])

train_loader = dataloaders.get(dataset=trainset, dataloader_params={
"shuffle": True,
"batch_size": args['batch'],
Expand All @@ -126,7 +131,7 @@
"worker_init_fn": worker_init_reset_seed,
"min_samples": 512
})

# Valid Data
valset = COCOFormatDetectionDataset(data_dir=yaml_params['Dir'],
images_dir=yaml_params['images']['val'],
json_annotation_file=yaml_params['labels']['val'],
Expand All @@ -148,7 +153,7 @@
"worker_init_fn": worker_init_reset_seed
})


# Test Data
if 'test' in (yaml_params['images'].keys() or yaml_params['labels'].keys()):
testset = COCOFormatDetectionDataset(data_dir=yaml_params['Dir'],
images_dir=yaml_params['images']['test'],
Expand All @@ -175,18 +180,17 @@
if args['resume']:
model = models.get(
args['model'],
num_classes=len(yaml_params['names']),
num_classes=no_class,
checkpoint_path=args["weight"]
)
else:
model = models.get(
args['model'],
num_classes=len(yaml_params['names']),
num_classes=no_class,
pretrained_weights=args["weight"]
)

train_params = {
# ENABLING SILENT MODE
'silent_mode': False,
"average_best_models":True,
"warmup_mode": args['warmup_mode'],
Expand All @@ -204,15 +208,14 @@
"mixed_precision": True,
"loss": PPYoloELoss(
use_static_assigner=False,
num_classes=len(yaml_params['names']),
num_classes=no_class,
reg_max=16
),
"valid_metrics_list": [
DetectionMetrics_050(
score_thres=0.1,
top_k_predictions=300,
# NOTE: num_classes needs to be defined here
num_cls=len(yaml_params['names']),
num_cls=no_class,
normalize_targets=True,
post_prediction_callback=PPYoloEPostPredictionCallback(
score_threshold=0.01,
Expand All @@ -232,6 +235,7 @@
# Print Training Params
print('[INFO] Training Params:\n', train_params)

# Model Training...
trainer.train(
model=model,
training_params=train_params,
Expand All @@ -241,15 +245,15 @@

# Load best model
best_model = models.get(args['model'],
num_classes=len(yaml_params['names']),
num_classes=no_class,
checkpoint_path=os.path.join('runs', name, 'ckpt_best.pth'))

# Evaluating on Val Dataset
eval_model = trainer.test(model=best_model,
test_loader=valid_loader,
test_metrics_list=DetectionMetrics_050(score_thres=0.1,
top_k_predictions=300,
num_cls=len(yaml_params['names']),
num_cls=no_class,
normalize_targets=True,
post_prediction_callback=PPYoloEPostPredictionCallback(score_threshold=0.01,
nms_top_k=1000,
Expand All @@ -266,7 +270,7 @@
test_loader=test_loader,
test_metrics_list=DetectionMetrics_050(score_thres=0.1,
top_k_predictions=300,
num_cls=len(yaml_params['names']),
num_cls=no_class,
normalize_targets=True,
post_prediction_callback=PPYoloEPostPredictionCallback(score_threshold=0.01,
nms_top_k=1000,
Expand Down

0 comments on commit 8c050ee

Please sign in to comment.