GitHub - victoresque/pytorch-template: PyTorch template project
source link: https://github.com/victoresque/pytorch-template
PyTorch Template Project
A simple PyTorch template project that can be adapted to fit many deep learning projects.
Basic Usage
The code in this repo is an MNIST example of the template. Try running:
python main.py
The default arguments list is shown below:
usage: main.py [-h] [-b BATCH_SIZE] [-e EPOCHS] [--resume RESUME]
[--verbosity VERBOSITY] [--save-dir SAVE_DIR]
[--save-freq SAVE_FREQ] [--data-dir DATA_DIR]
[--validation-split VALIDATION_SPLIT] [--no-cuda]
PyTorch Template
optional arguments:
-h, --help show this help message and exit
-b BATCH_SIZE, --batch-size BATCH_SIZE
mini-batch size (default: 32)
-e EPOCHS, --epochs EPOCHS
number of total epochs (default: 32)
--resume RESUME
path to latest checkpoint (default: none)
--verbosity VERBOSITY
verbosity, 0: quiet, 1: per epoch, 2: complete (default: 2)
--save-dir SAVE_DIR
directory of saved model (default: model/saved)
--save-freq SAVE_FREQ
training checkpoint frequency (default: 1)
--data-dir DATA_DIR
directory of training/testing data (default: datasets)
--validation-split VALIDATION_SPLIT
ratio of split validation data, [0.0, 1.0) (default: 0.0)
--no-cuda use CPU in case there's no GPU support
You can add your own arguments.
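For example, adding a custom argument is a matter of another add_argument call on the parser in main.py. A minimal sketch (the --batch-size default mirrors the help text above; --lr is a hypothetical addition, not an existing flag):

```python
import argparse

# Sketch of extending the argument parser in main.py.
# '--lr' is a hypothetical custom argument added for illustration.
parser = argparse.ArgumentParser(description='PyTorch Template')
parser.add_argument('-b', '--batch-size', default=32, type=int,
                    help='mini-batch size (default: 32)')
parser.add_argument('--lr', default=0.001, type=float,
                    help='learning rate (hypothetical custom argument)')

# Parse an explicit argv list here so the sketch runs standalone.
args = parser.parse_args(['--lr', '0.01'])
print(args.lr)          # the custom argument is available as args.lr
```

Unspecified arguments keep their defaults, so existing call sites keep working when you add new flags.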
Structure
├── base/ - abstract base classes
│ ├── base_data_loader.py - abstract base class for data loaders.
│ ├── base_model.py - abstract base class for models.
│ └── base_trainer.py - abstract base class for trainers
│
├── data_loader/ - anything about data loading goes here
│ └── data_loader.py
│
├── datasets/ - default dataset folder
│
├── logger/ - for training process logging
│ └── logger.py
│
├── model/ - models, losses, and metrics
│ ├── modules/ - submodules of your model
│ ├── saved/ - default checkpoint folder
│ ├── loss.py
│ ├── metric.py
│ └── model.py
│
├── trainer/ - trainers for your project
│ └── trainer.py
│
└── utils
├── utils.py
└── ...
Customization
Training
In most cases, you need to modify trainer/trainer.py
to fit your project's training logic.
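The overall shape of that change: the trainer subclasses the abstract base and implements the per-epoch method. The class and method names below follow the repo's BaseTrainer/Trainer convention, but the bodies are simplified stand-ins in plain Python, not the actual torch implementation:

```python
# Simplified sketch of the pattern in trainer/trainer.py. The real
# _train_epoch() iterates over torch batches; these bodies are stand-ins.
class BaseTrainer:
    def train(self, epochs):
        for epoch in range(1, epochs + 1):
            log = self._train_epoch(epoch)   # each epoch returns a log dict
            print(log)

    def _train_epoch(self, epoch):
        raise NotImplementedError

class Trainer(BaseTrainer):
    def _train_epoch(self, epoch):
        # Replace this loop body with your project's training logic.
        total_loss = 0.0
        n_batches = 3                        # stand-in for the data loader
        for _ in range(n_batches):
            total_loss += 0.1                # stand-in for loss.item()
        return {'epoch': epoch, 'loss': total_loss / n_batches}
```

The base class owns the loop over epochs, checkpointing, and logging; your subclass only supplies the per-epoch logic.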
Data loading
You can customize the data loader to fit your project: modify data_loader/data_loader.py
or add other files.
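A custom loader typically extends the abstract base in base/base_data_loader.py. A torch-free sketch of that shape (the base class here is a simplified stand-in, not the repo's actual implementation):

```python
# Sketch of the data-loader pattern: a base class handles batching,
# a project-specific subclass adds loading/augmentation logic.
class BaseDataLoader:
    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size

    def __iter__(self):
        # Yield the data in consecutive batches of batch_size.
        for i in range(0, len(self.data), self.batch_size):
            yield self.data[i:i + self.batch_size]

class MyDataLoader(BaseDataLoader):
    pass  # add project-specific loading or preprocessing here
```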
Model
Implement your model under model/
Loss/metrics
If you need to change the loss function or metrics, first import
those functions in main.py
, then modify this part:
loss = my_loss
metrics = [my_metric]
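In this template a loss and a metric are both plain functions of (output, target). The real implementations in model/loss.py and model/metric.py operate on torch tensors; the stand-ins below use plain lists purely to show the expected signature:

```python
# Hedged sketch: illustrative loss and metric functions with the
# (output, target) signature the template expects. Real versions
# would use torch tensors instead of Python lists.
def my_loss(output, target):
    # Mean squared error over the batch.
    return sum((o - t) ** 2 for o, t in zip(output, target)) / len(target)

def my_metric(output, target):
    # Fraction of predictions that round to the correct label.
    correct = sum(1 for o, t in zip(output, target) if round(o) == t)
    return correct / len(target)
```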
You'll see the logging has changed during training:
⋯
Train Epoch: 1 [53920/53984 (100%)] Loss: 0.033256
{'epoch': 1, 'loss': 0.14182623870152963, 'my_metric': 0.9568761114404268, 'val_loss': 0.06394806604976841, 'val_my_metric': 0.9804478609625669}
Saving checkpoint: model/saved/Model_checkpoint_epoch01_loss_0.14183.pth.tar ...
Train Epoch: 2 [0/53984 (0%)] Loss: 0.013225
⋯
Multiple metrics
If you have multiple metrics in your project, just add them to the metrics
list:
loss = my_loss
metrics = [my_metric, my_metric2]
Now the logging shows two metrics:
⋯
Train Epoch: 1 [53920/53984 (100%)] Loss: 0.003278
{'epoch': 1, 'loss': 0.13541310020907665, 'my_metric': 0.9590804682868999, 'my_metric2': 1.9181609365737997, 'val_loss': 0.05264156081223173, 'val_my_metric': 0.9837901069518716, 'val_my_metric2': 1.9675802139037433}
Saving checkpoint: model/saved/Model_checkpoint_epoch01_loss_0.13541.pth.tar ...
Train Epoch: 2 [0/53984 (0%)] Loss: 0.023072
⋯
Currently, the name shown in the log is the name of the metric function.
Additional logging
If you have additional information to be logged, you can modify _train_epoch()
in class Trainer
. For example, say you have an additional log saved as a dictionary:
additional_log = {"x": x, "y": y}
Just merge it with log
before returning, as shown below:
log = {**log, **additional_log}
return log
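A minimal illustration of the merge: dictionary unpacking keeps all keys from both dictionaries, and on a key collision the later dictionary wins.

```python
# Dictionary unpacking merges both logs into a single dict.
log = {'epoch': 1, 'loss': 0.14}
additional_log = {'x': 3, 'y': 7}
log = {**log, **additional_log}
print(log)  # {'epoch': 1, 'loss': 0.14, 'x': 3, 'y': 7}
```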
Validation data
If you have separate validation data, try implementing another data loader for validation. Otherwise, if you just want to split validation data off from the training data, pass --validation-split 0.1
. In some cases you might need to modify utils/util.py
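Conceptually, a split like --validation-split 0.1 carves validation indices out of the training set. A sketch of that logic (a self-contained helper written for illustration, not the template's actual function):

```python
import random

def split_validation(n_samples, split, seed=0):
    """Return (train_indices, val_indices) for a split ratio in [0.0, 1.0)."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)     # shuffle before splitting
    n_val = int(n_samples * split)
    return indices[n_val:], indices[:n_val]

train_idx, val_idx = split_validation(100, 0.1)
```

A fixed seed keeps the split reproducible across runs, which matters when resuming from a checkpoint.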
Checkpoint naming
If you need to add a prefix to your checkpoint names, modify this line in main.py
identifier = type(model).__name__ + '_'
The prefix of the checkpoint name will change. If you need to further change the naming of checkpoints, try modifying _save_checkpoint()
in class BaseTrainer
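For example, prepending a run tag to the identifier (the identifier variable follows main.py; the tag itself and the Model stand-in class are made up for illustration):

```python
class Model:
    pass  # stand-in for your model class

model = Model()
run_tag = 'exp01_'                               # hypothetical prefix
identifier = run_tag + type(model).__name__ + '_'
print(identifier)  # exp01_Model_
```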
Contributing
Feel free to contribute any kind of function or enhancement. The coding style follows PEP8.
Acknowledgments
This project is heavily inspired by the project Tensorflow-Project-Template by Mahmoud Gemy. Be sure to star it!