- Python 3.x
- PyTorch
- Clear folder structure suitable for many projects
- Separate `trainer`, `model`, and `data_loader` for more structured code
- `BaseDataLoader` handles batch generation, data shuffling, and validation data splitting for you
- `BaseTrainer` handles checkpoint saving/loading and training process logging
The code in this repo is an MNIST example of the template. To try it, run:

```
python train.py
```
The default arguments are listed below:

```
usage: train.py [-h] [-b BATCH_SIZE] [-e EPOCHS] [--resume RESUME]
                [--verbosity VERBOSITY] [--save-dir SAVE_DIR]
                [--save-freq SAVE_FREQ] [--data-dir DATA_DIR]
                [--validation-split VALIDATION_SPLIT] [--no-cuda]

PyTorch Template

optional arguments:
  -h, --help            show this help message and exit
  -b BATCH_SIZE, --batch-size BATCH_SIZE
                        mini-batch size (default: 32)
  -e EPOCHS, --epochs EPOCHS
                        number of total epochs (default: 32)
  --resume RESUME       path to latest checkpoint (default: none)
  --verbosity VERBOSITY
                        verbosity, 0: quiet, 1: per epoch, 2: complete
                        (default: 2)
  --save-dir SAVE_DIR   directory of saved model (default: saved)
  --save-freq SAVE_FREQ
                        training checkpoint frequency (default: 1)
  --data-dir DATA_DIR   directory of training/testing data (default: datasets)
  --validation-split VALIDATION_SPLIT
                        ratio of split validation data, [0.0, 1.0) (default: 0.1)
  --no-cuda             use CPU instead of GPU
```
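For example, to train for 50 epochs with batch size 64 on CPU, or to resume from a saved checkpoint (the checkpoint path below is illustrative, not a fixed name):

```
python train.py -b 64 -e 50 --no-cuda
python train.py --resume saved/MnistModel/model_best.pth.tar
```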
```
pytorch-template/
│
├── base/ - abstract base classes
│   ├── base_data_loader.py - abstract base class for data loaders
│   ├── base_model.py - abstract base class for models
│   └── base_trainer.py - abstract base class for trainers
│
├── data_loader/ - anything about data loading goes here
│   └── data_loaders.py
│
├── datasets/ - default datasets folder
│
├── logger/ - for training process logging
│   └── logger.py
│
├── model/ - models, losses, and metrics
│   ├── modules/ - submodules of your model
│   ├── loss.py
│   ├── metric.py
│   └── model.py
│
├── saved/ - default checkpoints folder
│
├── trainer/ - trainers
│   └── trainer.py
│
└── utils/
    ├── util.py
    └── ...
```
- Writing your own data loader

  - Inherit `BaseDataLoader`

    `BaseDataLoader` handles:
    - Generating the next batch
    - Data shuffling
    - Generating a validation data loader via `BaseDataLoader.split_validation()`

  - Implementing abstract methods

    There are some abstract methods you need to implement before using the methods in `BaseDataLoader` (a minimal sketch follows this section's example):
    - `_pack_data()`: pack data members into a list of tuples
    - `_unpack_data()`: unpack packed data
    - `_update_data()`: update data members
    - `_n_samples()`: total number of samples
  - DataLoader usage

    `BaseDataLoader` is an iterator; to iterate through batches:

    ```python
    for batch_idx, (x_batch, y_batch) in data_loader:
        pass
    ```
  - Example

    Please refer to `data_loader/data_loaders.py` for an MNIST example.
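As a rough illustration of the abstract methods listed above, a custom loader over in-memory arrays might look like the sketch below. The `BaseDataLoader` constructor signature and the exact method signatures are assumptions for this example; check `base/base_data_loader.py` for the real interface.

```python
import numpy as np
from base.base_data_loader import BaseDataLoader


class ArrayDataLoader(BaseDataLoader):
    """Hypothetical loader over in-memory arrays x (samples) and y (labels)."""
    def __init__(self, x, y, batch_size):
        super(ArrayDataLoader, self).__init__(batch_size)  # signature assumed
        self.x = np.asarray(x)
        self.y = np.asarray(y)

    def _pack_data(self):
        # Pack data members into a list of (sample, label) tuples
        return list(zip(self.x, self.y))

    def _unpack_data(self, packed):
        # Unpack packed data back into separate members
        x, y = zip(*packed)
        return list(x), list(y)

    def _update_data(self, unpacked):
        # Update data members, e.g. after shuffling or validation splitting
        self.x, self.y = unpacked

    def _n_samples(self):
        # Total number of samples
        return len(self.x)
```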
- Writing your own trainer

  - Inherit `BaseTrainer`

    `BaseTrainer` handles:
    - Training process logging
    - Checkpoint saving
    - Checkpoint resuming
    - Reconfigurable monitored value for saving the current best model
      - Controlled by the arguments `monitor` and `monitor_mode`; if `monitor_mode == 'min'`, the trainer saves a checkpoint `model_best.pth.tar` whenever `monitor` reaches a new minimum
  - Implementing abstract methods

    You need to implement `_train_epoch()` for your training process; if you need validation, you can implement `_valid_epoch()` as in `trainer/trainer.py`. A minimal sketch follows the example below.
  - Example

    Please refer to `trainer/trainer.py`.
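A minimal `_train_epoch()` sketch, iterating batches the same way as the data loader usage shown earlier. The attribute names (`self.model`, `self.optimizer`, `self.loss`, `self.data_loader`) are assumptions about what your trainer's constructor stores; the returned dict is what gets logged.

```python
from base.base_trainer import BaseTrainer


class MyTrainer(BaseTrainer):
    def _train_epoch(self, epoch):
        """Run one training epoch and return a log dict."""
        self.model.train()
        total_loss, n_batches = 0.0, 0
        for batch_idx, (data, target) in self.data_loader:
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.loss(output, target)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        return {'loss': total_loss / n_batches}
```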
- Writing your own model

  - Inherit `BaseModel`

    `BaseModel` handles:
    - Everything inherited from `torch.nn.Module`
    - `summary()`: model summary

  - Implementing abstract methods

    Implement the forward pass method `forward()`. A minimal sketch follows the example below.
  - Example

    Please refer to `model/model.py`.
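For instance, a model might look like the sketch below, assuming `BaseModel` needs no extra constructor arguments; the layer sizes here are illustrative and fit flattened 28x28 MNIST images:

```python
import torch.nn as nn
import torch.nn.functional as F
from base.base_model import BaseModel


class MyModel(BaseModel):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Flatten the input, then apply two fully connected layers
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)
```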
If you need to change the loss function or metrics, first import those functions in `train.py`, then modify:

```python
loss = my_loss
metrics = [my_metric]
```

They will appear in the logging during training.
If you have multiple metrics for your project, just add them to the `metrics` list:

```python
loss = my_loss
metrics = [my_metric, my_metric2]
```

The additional metric will be shown in the logging.
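For reference, `my_loss` and `my_metric` could be plain functions over model output and targets, e.g. in `model/loss.py` and `model/metric.py`. The `(output, target)` signature below is an assumption for illustration, not prescribed by the template:

```python
import torch
import torch.nn.functional as F


def my_loss(output, target):
    # Negative log-likelihood, pairing with a log_softmax model output
    return F.nll_loss(output, target)


def my_metric(output, target):
    # Classification accuracy
    with torch.no_grad():
        pred = output.argmax(dim=1)
        return (pred == target).float().mean().item()
```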
If you have additional information to be logged, merge it with `log` in `_train_epoch()` of your trainer class before returning, as shown below:

```python
additional_log = {"gradient_norm": g, "sensitivity": s}
log = {**log, **additional_log}
return log
```
If you need to split validation data from a data loader, call `BaseDataLoader.split_validation(validation_split)`; it will return a validation data loader whose number of samples corresponds to the specified ratio.

Note: the `split_validation()` method will modify the original data loader.
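A quick usage sketch (the variable names are illustrative):

```python
# After this call, data_loader iterates only over the remaining 90% of samples
valid_data_loader = data_loader.split_validation(0.1)
```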
You can specify the name of the training session in `train.py`:

```python
training_name = type(model).__name__
```

The checkpoints will then be saved under `saved/training_name`.
Feel free to contribute any kind of feature or enhancement. The coding style here follows PEP 8. Planned features and improvements:
- Multi-GPU support
- TensorboardX support
- Iteration-based training (instead of epoch-based)
- Loading settings from `config` files
- Configurable logging layout
- Configurable checkpoint naming
- Options to save logs to file
- Clearer trainer structure
This project is licensed under the MIT License. See LICENSE for more details.

This project is inspired by the project Tensorflow-Project-Template by Mahmoud Gemy.