52
GitHub - dmarnerides/pydlt: PyTorch based Deep Learning Toolbox
source link: https://github.com/dmarnerides/pydlt
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
PyTorch Deep Learning Toolbox
PyDLT is a set of tools aimed to make experimenting with PyTorch easier (than it already is).
Features
- Trainers (currently supporting Vanilla, VanillaGAN, WGAN-GP, BEGAN, FisherGAN)
trainer = dlt.train.VanillaGANTrainer(generator, discriminator, g_optim, d_optim) for batch, (prediction, losses) in trainer(data_loader): # Training happens in the iterator and relevant results are returned for each step
- Built in configurable parser with arguments.
opt = dlt.config.parse() # Has built in options (can add extra) print('Some Settings: ', opt.experiment_name, opt.batch_size, opt.lr)
- Configuration files support and parser compatible functions.
$ python main.py @settings.cfg Some Settings: config_test 32 0.0001
- HDR imaging support (.hdr, .exr, and .pfm formats)
img = dlt.hdr.imread('test.pfm') dlt.hdr.imwrite('test.exr', img)
- Checkpointing of (torch serializable) objects; Network state dicts supported.
data_chkp = Checkpointer('data') data_chkp.save(np.array([1,2,3])) a = data_chkp.load()
- Image operations and easy conversions between multiple library views (torch, cv, plt)
img = cv2.imread('image.jpg') # Height x Width x Channels - BGR dlt.viz.imshow(img, view='cv') # Height x Width x Channels - RGB tensor_with_torch_view = cv2torch(img) # Channels x Height x Width - RGB
- Easy visualization (and make_grid supporting Arrays, Tensors, Variables and lists)
for batch, (prediction, loss) in trainer(loader): grid = dlt.util.make_grid([ batch[0], batch[1], prediction], size(3, opt.batch_size)) dlt.viz.imshow(grid, pause=0.01, title='Training Progress')
- Parameter and input/outputs/gradients layer visualization.
net = nn.Sequential(nn.Linear(10, 10)) dlt.viz.modules.forward_hook(net, [nn.Linear], tag='layer_outputs', histogram=False) net(Variable(torch.Tensor(3,10)))
- CSV Logger.
log = dlt.util.Logger('losses', ['train_loss', 'val_loss']) log({'train_loss': 10, 'val_loss':20})
- Command line tool for easy plotting of CSV files (with live updating).
$ dlt-plot --file losses.csv train_loss val_loss --refresh 5 --loglog True --tail 100
- A minimal Progress bar (with global on/off switch).
from dlt.util import barit barit.silent = False # Default is False for batch in barit(loader, start='Loading'): pass
Installation
Make sure you have PyTorch installed. OpenCV is also required:
conda install -c menpo opencv
conda install (recommended):
conda install -c demetris pydlt
From source:
git clone https://github.com/dmarnerides/pydlt.git cd pydlt python setup.py install
About
I created this toolbox while learning Python and PyTorch, after working with (Lua) Torch, to help speed up experiment prototyping.
If you notice something is wrong or missing please do a pull request or open up an issue.
Contact
Demetris Marnerides: [email protected]
Recommend
About Joyk
Aggregate valuable and interesting links.
Joyk means Joy of geeK