
Matrix multiplication: The PyTorch way

source link: https://towardsdatascience.com/matrix-multiplication-the-pytorch-way-c0ad724402ed?gi=1a87e952bfab

To multiply two matrices, we write three nested loops that build the result one element at a time. The shape of the final matrix is (number of rows of matrix_1) by (number of columns of matrix_2).
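A minimal sketch of this function (the variable names here are assumptions; the course version is equivalent):

import torch

def matmul(a, b):
    # Naive matrix multiplication with three nested Python loops.
    ar, ac = a.shape  # rows and columns of a
    br, bc = b.shape  # rows and columns of b
    assert ac == br, "inner dimensions must match"
    c = torch.zeros(ar, bc)
    for i in range(ar):          # rows of the result
        for j in range(bc):      # columns of the result
            for k in range(ac):  # accumulate over the shared dimension
                c[i, j] += a[i, k] * b[k, j]
    return c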

Now let’s create a basic neural net where we will use this function.

In this article, we will be using the MNIST dataset for demonstration purposes. It contains 50,000 training samples of handwritten digits. Each digit is originally a 28×28 matrix, or a flat vector of 784 values once unrolled.

Hence our neural net takes 784 values as input and outputs scores for the 10 digit classes.
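As a sketch (the initialization and names here are assumptions), the "model" at this stage is just a single linear layer built on top of our matmul function:

weights = torch.randn(784, 10)  # one column of weights per output class
bias = torch.zeros(10)

def model(xb):
    # A linear layer: (batch, 784) @ (784, 10) + (10,) -> (batch, 10)
    return matmul(xb, weights) + bias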

Let’s now grab 5 elements from the MNIST validation set and run them through this model.
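A sketch of that experiment, assuming the flattened validation images live in a tensor x_valid of shape (10000, 784):

import time

x_batch = x_valid[:5]                  # a mini-batch of 5 flattened images
start = time.time()
preds = model(x_batch)                 # preds has shape (5, 10)
print(f"{(time.time() - start) * 1000:.0f} ms")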


We see that for a mere 5 elements, it took us 650 milliseconds to perform matrix multiplication. This is relatively slow. Let’s try to speed it up.

Why is speed important?

Matrix multiplication forms the basis of neural networks. Most operations while training a neural network require some form of matrix multiplication. Hence doing it well and doing it fast is really important.

(figure: slide from the fast.ai course "Deep Learning from the Foundations")

We will speed up our matrix multiplication by eliminating the loops and replacing them with PyTorch operations, which run in C under the hood. This gives us C speed instead of Python speed. Let’s see how that works.

Eliminating the innermost loop

We start by eliminating the innermost loop. The idea behind eliminating this loop is that instead of doing operations on one element at a time, we can do them on one row (or column) at a time. Take a look at the image below.


We have 2 tensors and we want to add their elements together. We can write a loop to do so or we can make use of PyTorch’s elementwise operations (a + b directly) to do the same.
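For instance, a tiny sketch with made-up values:

a = torch.tensor([10., 6., -4.])
b = torch.tensor([2., 8., 7.])

# Adding one element at a time, in Python:
c = torch.zeros(3)
for i in range(3):
    c[i] = a[i] + b[i]

# The same thing with PyTorch's elementwise addition, which runs in C:
c = a + b  # tensor([12., 14., 3.])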

Using the same idea, we will eliminate the innermost loop, so that instead of accumulating one product at a time over k,

c[i,j] += a[i,k] * b[k,j]

we directly compute the whole sum for each (i, j) pair in one elementwise operation:

c[i,j] = (a[i,:] * b[:,j]).sum()

Our function now looks roughly as follows, sketched here from the description above (variable names as before),
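def matmul(a, b):
    ar, ac = a.shape
    br, bc = b.shape
    assert ac == br, "inner dimensions must match"
    c = torch.zeros(ar, bc)
    for i in range(ar):
        for j in range(bc):
            # One elementwise multiply-and-sum replaces the loop over k.
            c[i, j] = (a[i, :] * b[:, j]).sum()
    return c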

and takes about 1.55 milliseconds to run, which is a massive improvement (roughly 400 times faster)!

If you are not familiar with the indexing syntax: a[i,:] selects the ith row and all columns, while b[:,j] selects all rows and the jth column.
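For example, with a small tensor:

m = torch.arange(6).reshape(2, 3)  # tensor([[0, 1, 2],
                                   #         [3, 4, 5]])
m[0, :]  # row 0:    tensor([0, 1, 2])
m[:, 1]  # column 1: tensor([1, 4])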

We can write a little test to confirm that our updated function gives the same output as our original function.
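A sketch of such a test, assuming we kept the original three-loop version around under the hypothetical name matmul_slow; torch.allclose compares tensors up to a floating-point tolerance:

a = torch.randn(5, 784)
b = torch.randn(784, 10)
# matmul_slow is the hypothetical name for the original three-loop version.
assert torch.allclose(matmul_slow(a, b), matmul(a, b), atol=1e-5)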

