Skip to content

Commit

Permalink
Merge pull request #59 from naseemap47/qat
Browse files Browse the repository at this point in the history
Qat
  • Loading branch information
naseemap47 authored Mar 24, 2024
2 parents 1685fe2 + 6f838b0 commit 597e6db
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 13 deletions.
15 changes: 13 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,18 @@ dmypy.json
.pyre/

# Data
Data
*.jpg
*.jpeg
*.png
*.xml
*.txt
*.json
*.yaml
*.mp4

# Train & Inference
runs
*.pth
*.onnx
*.pkl
*.0
*.1
22 changes: 17 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,31 @@ git clone https://github.com/naseemap47/YOLO-NAS.git
cd YOLO-NAS
```
### Install dependencies
**Recommended**:
Create anaconda python environment
```
conda create -n yolo-nas python=3.9 -y
conda activate yolo-nas
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch -y
```
**PyTorch v1.11.0** Installation
```
# conda installation
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch -y
/// OR
# PIP installation
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
```
**Quantization Aware Training**
```
# For Quantization Aware Training
pip install pytorch-quantization==2.1.2 --extra-index-url https://pypi.ngc.nvidia.com
pip install super-gradients==3.1.3
```
#### OR
Install **Super-Gradients**
```
pip3 install -r requirements.txt
pip install super-gradients==3.1.3
```

### 🎒 Prepare Dataset
Your custom dataset should be in **COCO JSON** data format.<br>
To convert **YOLO (.txt) / PASCAL VOC (.XML)** format to **COCO JSON**.<br>
Expand Down
8 changes: 4 additions & 4 deletions qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,16 @@
print(f"\033[1m[INFO] Number of Classes: {no_class}\033[0m")

# Training on GPU or CPU
name, ckpt_dir = args['weight'].split('/')[-3:-1]
_, name = args['weight'].split('/')[-3:-1]
if args['cpu']:
print('[INFO] Training on \033[1mCPU\033[0m')
trainer = Trainer(experiment_name=name, ckpt_root_dir=ckpt_dir, device='cpu')
trainer = Trainer(experiment_name=name, ckpt_root_dir='qat', device='cpu')
elif args['gpus']:
print(f'[INFO] Training on GPU: \033[1m{torch.cuda.get_device_name()}\033[0m')
trainer = Trainer(experiment_name=name, ckpt_root_dir=ckpt_dir, multi_gpu=args['gpus'])
trainer = Trainer(experiment_name=name, ckpt_root_dir='qat', multi_gpu=args['gpus'])
else:
print(f'[INFO] Training on GPU: \033[1m{torch.cuda.get_device_name()}\033[0m')
trainer = Trainer(experiment_name=name, ckpt_root_dir=ckpt_dir)
trainer = Trainer(experiment_name=name, ckpt_root_dir='qat')

# Load best model
best_model = models.get(args['model'],
Expand Down
3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
super-gradients==3.1.3
# urllib3==1.25.9
super-gradients==3.1.3

0 comments on commit 597e6db

Please sign in to comment.