OCaml bindings for PyTorch
source link: https://www.tuicool.com/articles/hit/iqeQnu7
ocaml-torch
ocaml-torch provides some OCaml bindings for the PyTorch tensor library. This brings NumPy-like tensor computations with GPU acceleration and tape-based automatic differentiation to OCaml.
These bindings use the PyTorch C++ API and are mostly automatically generated. The current GitHub tip corresponds to PyTorch 1.0.0.
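Tape-based automatic differentiation records every operation as it runs, then walks that record backwards to accumulate gradients with the chain rule. The sketch below illustrates the idea on scalars in plain OCaml. It does not use ocaml-torch, and the names (new_tape, variable, add, mul, vsin, backward) are hypothetical, chosen for illustration only. PyTorch's own autograd engine works on tensors and is far more general.

```ocaml
(* Scalar tape-based reverse-mode automatic differentiation (illustrative only). *)

(* Each tape entry lists the node's parents as (parent index, local derivative). *)
type tape = {
  mutable nodes : (int * float) list list; (* newest node first *)
  mutable count : int;
}

(* A variable is a position on the tape plus its forward value. *)
type var = { tape : tape; index : int; value : float }

let new_tape () = { nodes = []; count = 0 }

(* Record a new node; its inputs have always been pushed earlier. *)
let push tape parents value =
  let index = tape.count in
  tape.nodes <- parents :: tape.nodes;
  tape.count <- index + 1;
  { tape; index; value }

let variable tape value = push tape [] value
let add a b = push a.tape [ (a.index, 1.0); (b.index, 1.0) ] (a.value +. b.value)
let mul a b = push a.tape [ (a.index, b.value); (b.index, a.value) ] (a.value *. b.value)
let vsin a = push a.tape [ (a.index, cos a.value) ] (sin a.value)

(* Reverse sweep: gradient of [out] with respect to every node on the tape. *)
let backward out =
  let nodes = Array.of_list (List.rev out.tape.nodes) in
  let grads = Array.make (Array.length nodes) 0.0 in
  grads.(out.index) <- 1.0;
  for i = Array.length nodes - 1 downto 0 do
    (* Chain rule: push this node's gradient to its parents. *)
    List.iter (fun (p, d) -> grads.(p) <- grads.(p) +. (d *. grads.(i))) nodes.(i)
  done;
  grads

let () =
  (* f(x, y) = x * y + sin x, so df/dx = y + cos x and df/dy = x. *)
  let t = new_tape () in
  let x = variable t 2.0 in
  let y = variable t 3.0 in
  let f = add (mul x y) (vsin x) in
  let g = backward f in
  Printf.printf "f = %f  df/dx = %f  df/dy = %f\n" f.value g.(x.index) g.(y.index)
```

Because every node is pushed after its inputs, a single reverse pass over the tape reaches each node only after all of its consumers have contributed, so one sweep is enough.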
Opam Installation
The opam package can be installed using the following command. This automatically installs the CPU version of libtorch.
opam install torch
You can then compile some sample code; see the instructions below. ocaml-torch can also be used in interactive mode via utop or ocaml-jupyter.
Here is a sample utop session.
Build a Simple Example
To build a first torch program, create a file example.ml with the following content.
open Torch
let () =
  let tensor = Tensor.randn [ 4; 2 ] in
  Tensor.print tensor
Then create a dune file with the following content:
(executables
 (names example)
 (libraries torch))
Run dune build example.exe to compile the program and _build/default/example.exe to run it!
Tutorials
- MNIST tutorial.
- Finetuning a ResNet-18 model.
- Generative Adversarial Networks.
- Running some Python model.
Examples
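Below is an example of a linear model trained on the MNIST dataset (full code). The snippet assumes that image_dim, label_count, learning_rate, train_images, train_labels, test_images, test_labels and test_samples are defined as in the full code.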
(* Create two tensors to store model weights. *)
let ws = Tensor.zeros [image_dim; label_count] ~requires_grad:true in
let bs = Tensor.zeros [label_count] ~requires_grad:true in
let model xs = Tensor.(mm xs ws + bs) in
for index = 1 to 100 do
  (* Compute the cross-entropy loss. *)
  let loss =
    Tensor.cross_entropy_for_logits (model train_images) ~targets:train_labels
  in
  Tensor.backward loss;
  (* Apply gradient descent, disable gradient tracking for these. *)
  Tensor.(no_grad (fun () ->
    ws -= grad ws * f learning_rate;
    bs -= grad bs * f learning_rate));
  (* Gradients accumulate across backward calls, so reset them each step. *)
  Tensor.zero_grad ws;
  Tensor.zero_grad bs;
  (* Compute the validation accuracy. *)
  let test_accuracy =
    Tensor.(sum (argmax (model test_images) = test_labels) |> float_value)
    |> fun sum -> sum /. test_samples
  in
  printf "%d %f %.2f%%\n%!" index (Tensor.float_value loss) (100. *. test_accuracy)
done
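The loop above fits a linear model with softmax cross-entropy and plain gradient descent. As a hedged illustration of the gradients Tensor.backward produces for this particular model (assuming the loss is averaged over the batch), the sketch below redoes the same loop with OCaml arrays and hand-derived gradients on a tiny made-up dataset. It does not use ocaml-torch; model, step, accuracy and the toy data are hypothetical names and values.

```ocaml
(* Linear softmax classifier trained by gradient descent, using plain arrays.
   Illustrative sketch only: not ocaml-torch code. *)

(* Numerically stable softmax of one row of logits. *)
let softmax row =
  let m = Array.fold_left max neg_infinity row in
  let e = Array.map (fun v -> exp (v -. m)) row in
  let s = Array.fold_left ( +. ) 0.0 e in
  Array.map (fun v -> v /. s) e

(* logits.(i).(k) = bs.(k) + sum_j xs.(i).(j) * ws.(j).(k), i.e. xs * ws + bs *)
let model ws bs xs =
  Array.map
    (fun x ->
      Array.mapi
        (fun k b ->
          let acc = ref b in
          Array.iteri (fun j xj -> acc := !acc +. (xj *. ws.(j).(k))) x;
          !acc)
        bs)
    xs

(* Cross-entropy averaged over the batch, with integer class targets. *)
let cross_entropy logits targets =
  let total = ref 0.0 in
  Array.iteri (fun i row -> total := !total -. log (softmax row).(targets.(i))) logits;
  !total /. float_of_int (Array.length logits)

let argmax row =
  let best = ref 0 in
  Array.iteri (fun k v -> if v > row.(!best) then best := k) row;
  !best

let accuracy ws bs xs labels =
  let correct = ref 0 in
  Array.iteri (fun i row -> if argmax row = labels.(i) then incr correct) (model ws bs xs);
  float_of_int !correct /. float_of_int (Array.length xs)

(* One gradient-descent step. For the mean cross-entropy,
   d loss / d logits.(i).(k) = (softmax_k - [k = target_i]) / n. *)
let step ws bs xs targets learning_rate =
  let n = float_of_int (Array.length xs) in
  let gws = Array.make_matrix (Array.length ws) (Array.length bs) 0.0 in
  let gbs = Array.make (Array.length bs) 0.0 in
  Array.iteri
    (fun i row ->
      Array.iteri
        (fun k pk ->
          let y = if k = targets.(i) then 1.0 else 0.0 in
          let d = (pk -. y) /. n in
          gbs.(k) <- gbs.(k) +. d;
          Array.iteri (fun j xj -> gws.(j).(k) <- gws.(j).(k) +. (xj *. d)) xs.(i))
        (softmax row))
    (model ws bs xs);
  (* Update in place, like [ws -= grad ws * f learning_rate] above. *)
  Array.iteri
    (fun j r -> Array.iteri (fun k g -> ws.(j).(k) <- ws.(j).(k) -. (learning_rate *. g)) r)
    gws;
  Array.iteri (fun k g -> bs.(k) <- bs.(k) -. (learning_rate *. g)) gbs

let () =
  (* Hypothetical toy data: two 2-D points per class. *)
  let xs = [| [| -1.0; -1.0 |]; [| -2.0; -1.0 |]; [| 1.0; 1.0 |]; [| 2.0; 1.0 |] |] in
  let labels = [| 0; 0; 1; 1 |] in
  let ws = Array.make_matrix 2 2 0.0 in
  let bs = Array.make 2 0.0 in
  for index = 1 to 100 do
    step ws bs xs labels 0.1;
    if index mod 25 = 0 then
      Printf.printf "%d %f %.2f%%\n%!" index
        (cross_entropy (model ws bs xs) labels)
        (100. *. accuracy ws bs xs labels)
  done
```

The gradients are computed from scratch on every call to step, so nothing accumulates between iterations; with autograd, the same effect needs an explicit reset of the stored gradients.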
- Some ResNet examples on CIFAR-10.
- A simplified version of char-rnn illustrating character-level language modeling using Recurrent Neural Networks.
- Neural Style Transfer applies the style of an image to the content of another image. This uses a deep Convolutional Neural Network.
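The usual neural style transfer formulation measures "style" through Gram matrices of convolutional feature maps: channel-by-channel correlations that ignore where in the image a feature occurs. Assuming the example follows that standard formulation, the plain OCaml sketch below computes a Gram matrix and a mean-squared style loss for a feature map stored as channels x positions. The normalization by channels x positions is one common choice among several, and gram and style_loss are hypothetical names, not ocaml-torch functions.

```ocaml
(* Gram matrix of a feature map stored as channels x positions (H * W flattened):
   g.(a).(b) = sum_p f.(a).(p) * f.(b).(p), divided by channels * positions. *)
let gram features =
  let c = Array.length features in
  let n = Array.length features.(0) in
  let norm = float_of_int (c * n) in
  Array.init c (fun a ->
      Array.init c (fun b ->
          let acc = ref 0.0 in
          for p = 0 to n - 1 do
            acc := !acc +. (features.(a).(p) *. features.(b).(p))
          done;
          !acc /. norm))

(* Style loss: mean squared difference between the two Gram matrices. *)
let style_loss generated style =
  let g1 = gram generated in
  let g2 = gram style in
  let c = Array.length g1 in
  let total = ref 0.0 in
  for a = 0 to c - 1 do
    for b = 0 to c - 1 do
      let d = g1.(a).(b) -. g2.(a).(b) in
      total := !total +. (d *. d)
    done
  done;
  !total /. float_of_int (c * c)

let () =
  (* Hypothetical 2-channel feature maps with 3 positions each; [shuffled]
     holds the same per-position values as [content], in a different order. *)
  let content = [| [| 1.0; 2.0; 0.0 |]; [| 0.0; 1.0; 3.0 |] |] in
  let shuffled = [| [| 0.0; 1.0; 2.0 |]; [| 3.0; 0.0; 1.0 |] |] in
  let other = [| [| 1.0; 1.0; 1.0 |]; [| 1.0; 1.0; 1.0 |] |] in
  Printf.printf "shuffled: %f, other: %f\n"
    (style_loss content shuffled) (style_loss content other)
```

Permuting the positions leaves the Gram matrix unchanged, which is why it captures texture rather than spatial layout.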
Models and Weights
Various pre-trained computer vision models are implemented in the vision library. The weight files can be downloaded at the following links:
- ResNet-18 weights.
- ResNet-34 weights.
- ResNet-50 weights.
- ResNet-101 weights.
- ResNet-152 weights.
- DenseNet-121 weights.
- DenseNet-161 weights.
- DenseNet-169 weights.
- SqueezeNet 1.0 weights.
- SqueezeNet 1.1 weights.
- VGG-13 weights.
- VGG-16 weights.
Running the pre-trained models on some sample images can then easily be done via the following commands.
make all
_build/default/examples/pretrained/predict.exe path/to/resnet18.ot tiger.jpg
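A classifier such as ResNet-18 produces one logit per class, and a prediction program typically turns these into probabilities with a softmax and reports the most likely classes. The plain OCaml sketch below shows that last step on a hypothetical 5-class logit vector. It is not taken from predict.exe, and softmax and top_k are illustrative names, not ocaml-torch functions.

```ocaml
(* Numerically stable softmax over a vector of class logits. *)
let softmax logits =
  let m = Array.fold_left max neg_infinity logits in
  let e = Array.map (fun v -> exp (v -. m)) logits in
  let s = Array.fold_left ( +. ) 0.0 e in
  Array.map (fun v -> v /. s) e

(* The k most probable (class index, probability) pairs, most probable first. *)
let top_k k logits =
  let ranked = Array.mapi (fun i p -> (i, p)) (softmax logits) in
  Array.stable_sort (fun (_, p1) (_, p2) -> compare p2 p1) ranked;
  Array.to_list (Array.sub ranked 0 (min k (Array.length ranked)))

let () =
  (* Hypothetical logits for a 5-class model. *)
  let logits = [| 1.0; 3.0; 0.5; 2.0; -1.0 |] in
  List.iter
    (fun (i, p) -> Printf.printf "class %d: %.2f%%\n" i (100. *. p))
    (top_k 3 logits)
```

Subtracting the largest logit before exponentiating does not change the softmax result but avoids overflow for large logits.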
Alternative Installation Options
These alternative ways to install ocaml-torch can be useful, for example to run with GPU acceleration enabled.
Option 1: Using PyTorch pre-built Binaries
The libtorch library can be downloaded from the PyTorch website (1.0.0 CPU version).
Download and extract the libtorch library, then build all the examples by running:
export LIBTORCH=/path/to/libtorch
git clone https://github.com/LaurentMazare/ocaml-torch.git
cd ocaml-torch
make all
Option 2: Using PyTorch Conda package
Conda packages for PyTorch 1.0 can be used via the following commands.
conda create -n torch
source activate torch
conda install pytorch-cpu=1.0.0 -c pytorch
# Or for the CUDA version
# conda install pytorch=1.0.0 -c pytorch
git clone https://github.com/LaurentMazare/ocaml-torch.git
cd ocaml-torch
make all