diff --git a/README.md b/README.md index 10c4638..8a014a1 100644 --- a/README.md +++ b/README.md @@ -134,6 +134,12 @@ You can train your **YOLO-NAS** model with **Single Command Line** ``` python3 train.py --data /dir/dataset/data.yaml --batch 6 --epoch 100 --model yolo_nas_m --size 640 ``` +### If your training ends in 65th epoch (total 100 epochs), now you can start from 65th epoch and complete your 100 epochs training. +**Example:** +``` +python3 train.py --data /dir/dataset/data.yaml --batch 6 --epoch 100 --model yolo_nas_m --size 640 \ + --weight runs/train2/ckpt_latest.pth --resume +``` ## 📺 Inference You can Inference your **YOLO-NAS** model with **Single Command Line** diff --git a/train.py b/train.py index 896d093..27f43b0 100644 --- a/train.py +++ b/train.py @@ -60,20 +60,25 @@ s_time = time.time() + if args['name'] is None: name = 'train' else: name = args['name'] - n = 0 - while True: - if not os.path.exists(os.path.join('runs', f'{name}{n}')): - name = f'{name}{n}' - os.makedirs(os.path.join('runs', name)) - print(f"[INFO] Checkpoints saved in \033[1m{os.path.join('runs', name)}\033[0m") - break - else: - n += 1 - + + if args['resume']: + name = os.path.split(args['weight'])[0].split('/')[-1] + else: + n = 0 + while True: + if not os.path.exists(os.path.join('runs', f'{name}{n}')): + name = f'{name}{n}' + os.makedirs(os.path.join('runs', name)) + break + else: + n += 1 + + print(f"[INFO] Checkpoints saved in \033[1m{os.path.join('runs', name)}\033[0m") # Training on GPU or CPU if args['cpu']: print('[INFO] Training on \033[1mCPU\033[0m') @@ -184,6 +189,13 @@ "metric_to_watch": 'mAP@0.50' } + # to Resume Training + if args['resume']: + train_params['resume'] = True + + # Print Training Params + print('[INFO] Training Params:\n', train_params) + trainer.train( model=model, training_params=train_params,