
LSTM Gradients

 3 years ago
source link: https://towardsdatascience.com/lstm-gradients-b3996e6a0296?gi=74bb82db5055

Long Short-Term Memory (LSTM) is an important building block of many complex, state-of-the-art neural network architectures. The main idea behind this article is to explain the math behind it. To get an initial understanding of what an LSTM is, I would suggest the following blog.

Contents:

A — Concept

  • Introduction
  • Explanation
  • Derivation Prerequisites

B — Derivation

  • Output of the LSTM
  • Hidden state
  • Output gate
  • Cell state
  • Input gate
  • Forget gate
  • Input to the LSTM
  • Weights and biases

C — Back propagation through time

D — Conclusion

Concept

Introduction

Fig 1: LSTM Cell

The diagram above shows a single LSTM cell. I know it looks scary, but we will go through it piece by piece, and by the end of the article it will hopefully be pretty clear.

Explanation

A single LSTM cell has four components: the forget gate, the input gate, the output gate and the cell state. We will first briefly discuss the role of each part (for a detailed explanation, please refer to the blog above) and then dive into the math.

Forget gate

As the name suggests, this part is responsible for deciding what information from the previous step is thrown away and what is kept. This is done by the first sigmoid layer.

Fig 2: Forget gate marked in Blue

Based on h_t-1 (the previous hidden state) and x_t (the current input at time-step t), the forget gate outputs a value between 0 and 1 for each element of the previous cell state C_t-1.

Fig 3: Forget gate and previous cell state

A value of 1 keeps the corresponding information as it is, a value of 0 discards it completely, and values in between decide how much of the information from the previous state is carried to the next state.
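As a small numeric illustration (the numbers are made up, not taken from the article), the forget gate acts on C_t-1 through an element-wise product:

```python
import numpy as np

# Hypothetical forget gate output f_t: one value in [0, 1] per cell-state element
f_t = np.array([1.0, 0.0, 0.5])
# Hypothetical previous cell state C_t-1
c_prev = np.array([2.0, 3.0, 4.0])

# Element-wise product: 1 keeps the value, 0 erases it, 0.5 keeps half of it
kept = f_t * c_prev
print(kept)  # [2. 0. 2.]
```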

Input gate

Fig 4: Input gate marked in Blue

Christopher Olah has a beautiful explanation of what happens in the input gate. To quote his blog:

The next step is to decide what new information we’re going to store in the cell state. This has two parts. First, a sigmoid layer called the “input gate layer” decides which values we’ll update. Next, a tanh layer creates a vector of new candidate values, C~t, that could be added to the state. In the next step, we’ll combine these two to create an update to the state.

These two values, i_t and c~_t, are multiplied element-wise to decide what new information is fed into the cell state.
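A minimal NumPy sketch of the input gate, assuming the pre-activations a_i and a_c have already been computed (the values below are hypothetical). It shows how the sigmoid output i_t scales the tanh candidates c~_t:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Hypothetical pre-activations for a cell with 3 units
a_i = np.array([-10.0, 0.0, 10.0])  # input to the input-gate sigmoid
a_c = np.array([1.0, 1.0, 1.0])     # input to the candidate tanh

i_t = sigmoid(a_i)         # how much of each candidate to let in, in (0, 1)
c_tilde_t = np.tanh(a_c)   # candidate values, in (-1, 1)
update = i_t * c_tilde_t   # new information offered to the cell state
```

The first unit's gate is almost closed, so its candidate is nearly blocked; the third unit's gate is almost fully open.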

Cell state

Fig 5: Cell state marked in Blue

The cell state serves as the memory of an LSTM, and it is what lets LSTMs perform much better than vanilla RNNs on longer input sequences. At each time-step, the previous cell state (C_t-1) is combined with the forget gate to decide what information is carried forward, and the result is then combined with the input gate (i_t and c~_t) to form the new cell state, i.e. the new memory of the cell.

$$C_t = f_t \odot C_{t-1} + i_t \odot \tilde{c}_t$$

Fig 6: New cell state equation (⊙ denotes element-wise multiplication)
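The cell state update is a single element-wise expression. In the sketch below, the function name `new_cell_state` and the example values are illustrative assumptions:

```python
import numpy as np

def new_cell_state(f_t, c_prev, i_t, c_tilde_t):
    """C_t = f_t * C_t-1 + i_t * c~_t, with all products element-wise (hypothetical helper)."""
    return f_t * c_prev + i_t * c_tilde_t

# Hypothetical values: unit 0 forgets its memory and writes a new value,
# unit 1 keeps its memory and ignores the candidate
c_t = new_cell_state(
    f_t=np.array([0.0, 1.0]),
    c_prev=np.array([5.0, 5.0]),
    i_t=np.array([1.0, 0.0]),
    c_tilde_t=np.array([0.3, 0.9]),
)
```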

Output gate

Fig 7: Output gate marked in Blue

Finally, the LSTM cell has to produce an output. The cell state obtained above is passed through tanh (the hyperbolic tangent), which squashes its values into the range between -1 and 1, and the result is multiplied by the output gate o_t to give the new hidden state h_t. For details on different activation functions, this is a nice blog.
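A short sketch of the output step, assuming a precomputed cell state and output-gate pre-activation (hypothetical values). The cell state itself may lie outside [-1, 1], but h_t cannot:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

c_t = np.array([-5.0, 0.5, 5.0])  # hypothetical cell state, not bounded to [-1, 1]
a_o = np.array([0.0, 0.0, 0.0])   # hypothetical output-gate pre-activation
o_t = sigmoid(a_o)                # equals 0.5 for every unit here
h_t = o_t * np.tanh(c_t)          # new hidden state, bounded to (-1, 1)
```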

Now I hope the basic structure of an LSTM cell is clear, and we can proceed to derive the equations we will use in our implementation.

Derivation Prerequisites

1. Requirements: The derivation is based on backpropagation, the cost function and the loss. If you are not familiar with these, these are a few links that will help in getting a good understanding. This article also assumes a basic understanding of high school calculus (calculating derivatives and their rules).

2. Variables: For each gate we have a set of weights and biases, denoted as:

  • W_f, b_f -> Forget gate weight and bias
  • W_i, b_i -> Input gate weight and bias
  • W_c, b_c -> Candidate cell state weight and bias
  • W_o, b_o -> Output gate weight and bias

W_v, b_v -> Weight and bias associated with the softmax layer.

f_t, i_t, c~_t, o_t -> Outputs of the activation functions

a_f, a_i, a_c, a_o -> Input to the activation functions

J is the cost function; we will calculate its derivatives with respect to each of the variables above. Note: the character after an underscore (_) is a subscript.

3. Forward prop equations:

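$$a_f = W_f\,[h_{t-1}, x_t] + b_f, \qquad f_t = \sigma(a_f) \tag{1}$$
$$a_i = W_i\,[h_{t-1}, x_t] + b_i, \qquad i_t = \sigma(a_i) \tag{2}$$
$$a_c = W_c\,[h_{t-1}, x_t] + b_c, \qquad \tilde{c}_t = \tanh(a_c) \tag{3}$$
$$a_o = W_o\,[h_{t-1}, x_t] + b_o, \qquad o_t = \sigma(a_o) \tag{4}$$

Fig 8: Gate equations ($\sigma$ is the sigmoid function and $[h_{t-1}, x_t]$ is the concatenation of the previous hidden state and the current input)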

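$$C_t = f_t \odot C_{t-1} + i_t \odot \tilde{c}_t \tag{5}$$
$$h_t = o_t \odot \tanh(C_t) \tag{6}$$
$$V_t = W_v h_t + b_v \tag{7}$$
$$\hat{y}_t = \mathrm{softmax}(V_t)$$

Fig 9: Cell state and output equations (⊙ is element-wise multiplication; $\hat{y}_t$ is the softmax output)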
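Putting Fig 8 and Fig 9 together, one forward step can be sketched in NumPy as below. This is an illustrative implementation under stated assumptions, not the article's code: the helper `lstm_forward_step`, the parameter dictionary and its keys (`W_f`, `b_f`, ..., `W_v`, `b_v`) are hypothetical, each gate weight acts on the concatenation [h_t-1, x_t], and the sizes are arbitrary.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def softmax(v):
    e = np.exp(v - np.max(v))  # subtract the max for numerical stability
    return e / e.sum()

def lstm_forward_step(x_t, h_prev, c_prev, p):
    """One LSTM time-step following Fig 8 and Fig 9 (hypothetical helper)."""
    z_t = np.concatenate([h_prev, x_t])   # Z_t = [h_t-1, x_t]
    a_f = p["W_f"] @ z_t + p["b_f"]       # eq 1
    a_i = p["W_i"] @ z_t + p["b_i"]       # eq 2
    a_c = p["W_c"] @ z_t + p["b_c"]       # eq 3
    a_o = p["W_o"] @ z_t + p["b_o"]       # eq 4
    f_t, i_t, o_t = sigmoid(a_f), sigmoid(a_i), sigmoid(a_o)
    c_tilde_t = np.tanh(a_c)
    c_t = f_t * c_prev + i_t * c_tilde_t  # eq 5
    h_t = o_t * np.tanh(c_t)              # eq 6
    v_t = p["W_v"] @ h_t + p["b_v"]       # eq 7
    return h_t, c_t, softmax(v_t)

# Hypothetical sizes: 3 inputs, 4 hidden units, 2 output classes
rng = np.random.default_rng(0)
n_in, n_h, n_out = 3, 4, 2
p = {}
for gate in "fico":
    p[f"W_{gate}"] = rng.normal(scale=0.1, size=(n_h, n_h + n_in))
    p[f"b_{gate}"] = np.zeros(n_h)
p["W_v"] = rng.normal(scale=0.1, size=(n_out, n_h))
p["b_v"] = np.zeros(n_out)

x_t = rng.normal(size=n_in)
h_t, c_t, y_hat = lstm_forward_step(x_t, np.zeros(n_h), np.zeros(n_h), p)
```

Concatenating h_t-1 and x_t lets each gate use a single weight matrix, which matches the single W_f, W_i, W_c, W_o listed in point 2.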

4. Process for calculation: Let’s take the forget gate as an example to illustrate how the derivatives are calculated. We need to follow the path of the red arrows in the figure below.

Figure: Path from f_t to the cost function J (red arrows)

So we chalk out a path from f_t to our cost function J i.e

f_t →C_t →h_t →J.

Backpropagation follows exactly the same steps, but in reverse, i.e.

f_t ←C_t ←h_t ←J.

J is differentiated with respect to h_t, h_t with respect to C_t, and C_t with respect to f_t.

Notice that h_t → J is the last step of the path. Once we calculate dJ/dh_t, we can reuse it in other calculations such as dJ/dC_t, since:

dJ/dC_t = dJ/dh_t * dh_t/dC_t (chain rule)

Similarly, the derivatives will be calculated for all the variables listed in point 2.

Now that the variables are defined and the forward prop equations are clear, it’s time to derive the derivatives through backpropagation. We will start with the output equations because, as we saw, their derivatives are reused in the other equations. This is where the chain rule comes in. So let’s start.

Derivation

Output of the LSTM

The output has two values which we need to calculate.

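  1. Softmax: For the derivative of the cross-entropy loss with softmax, we will use the final equation directly:

$$\frac{\partial J}{\partial V_t} = \hat{y}_t - y_t$$

where $\hat{y}_t = \mathrm{softmax}(V_t)$ is the predicted distribution and $y_t$ is the one-hot target.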
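The result can be checked numerically. The sketch below uses hypothetical logits and a hypothetical one-hot target, and compares the analytic gradient ŷ_t - y_t against central finite differences of the cross-entropy loss:

```python
import numpy as np

def softmax(v):
    e = np.exp(v - np.max(v))  # shifted for numerical stability
    return e / e.sum()

def cross_entropy(v, y):
    """J for one time-step: y is a one-hot target, v the logits V_t."""
    return -np.sum(y * np.log(softmax(v)))

v_t = np.array([0.2, -1.0, 0.7])  # hypothetical logits
y_t = np.array([0.0, 0.0, 1.0])   # hypothetical one-hot target

analytic = softmax(v_t) - y_t     # dJ/dV_t = y_hat - y

# Central finite differences as an independent check
eps = 1e-6
numeric = np.zeros_like(v_t)
for k in range(v_t.size):
    d = np.zeros_like(v_t)
    d[k] = eps
    numeric[k] = (cross_entropy(v_t + d, y_t) - cross_entropy(v_t - d, y_t)) / (2 * eps)
```

Because the softmax outputs sum to 1 and the target is one-hot, the entries of this gradient always sum to zero.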

The detailed derivation can be found below:

Hidden State

We have the hidden state h_t, and we now differentiate J with respect to h_t. By the chain rule, the path goes through V_t, whose value is given by Fig 9 equation 7, i.e.:

V_t = W_v.h_t + b_v

$$\frac{\partial J}{\partial h_t} = \frac{\partial J}{\partial V_t}\,\frac{\partial V_t}{\partial h_t} = W_v^{\top}\,\frac{\partial J}{\partial V_t}$$
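A numeric check of the hidden-state gradient, under the same softmax plus cross-entropy assumption, with random, hypothetical values for W_v, b_v and h_t:

```python
import numpy as np

def softmax(v):
    e = np.exp(v - np.max(v))
    return e / e.sum()

def loss_from_h(h, W_v, b_v, y):
    """J as a function of h_t: V_t = W_v h_t + b_v, then softmax + cross-entropy."""
    return -np.sum(y * np.log(softmax(W_v @ h + b_v)))

rng = np.random.default_rng(1)
W_v = rng.normal(size=(3, 4))  # hypothetical: 4 hidden units, 3 classes
b_v = rng.normal(size=3)
h_t = rng.normal(size=4)
y_t = np.array([0.0, 1.0, 0.0])

dJ_dV = softmax(W_v @ h_t + b_v) - y_t
dJ_dh = W_v.T @ dJ_dV          # chain rule through V_t = W_v h_t + b_v

# Central finite differences, one unit vector of h_t at a time
eps = 1e-6
numeric = np.array([
    (loss_from_h(h_t + eps * e, W_v, b_v, y_t) - loss_from_h(h_t - eps * e, W_v, b_v, y_t)) / (2 * eps)
    for e in np.eye(4)
])
```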

Output gate

Variables associated: a_o and o_t.

o_t: Following the path from o_t to J, the full chain-rule expression for the derivative is:

dJ/dV_t * dV_t/dh_t * dh_t/do_t

dJ/dV_t * dV_t/dh_t can be written as dJ/dh_t (we have this value from hidden state).

From Fig 9 equation 6, h_t = o_t * tanh(C_t), so we only need to differentiate h_t w.r.t. o_t:

$$\frac{\partial J}{\partial o_t} = \frac{\partial J}{\partial h_t} \odot \tanh(C_t)$$

a_o: Similarly, following the path from a_o to J, the full expression is:

dJ/dV_t * dV_t/dh_t * dh_t/do_t * do_t/da_o

dJ/dV_t * dV_t/dh_t * dh_t/do_t can be written as dJ/do_t (we have this value from o_t above).

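From Fig 8 equation 4, o_t = sigmoid(a_o), so we only need to differentiate o_t w.r.t. a_o. The derivative of the sigmoid is sigmoid(a_o)(1 - sigmoid(a_o)) = o_t(1 - o_t), which gives:

$$\frac{\partial J}{\partial a_o} = \frac{\partial J}{\partial o_t} \odot o_t \odot (1 - o_t)$$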
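A quick numerical check of both output-gate derivatives. To stay self-contained, it assumes a hypothetical linear loss J = g · h_t (so dJ/dh_t is simply the vector g) and random values for a_o and C_t:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(2)
a_o = rng.normal(size=4)   # hypothetical output-gate pre-activation
c_t = rng.normal(size=4)   # hypothetical cell state
g = rng.normal(size=4)     # plays the role of dJ/dh_t

def J(a):
    """Hypothetical linear loss J = g . h_t with h_t = sigmoid(a_o) * tanh(C_t)."""
    return g @ (sigmoid(a) * np.tanh(c_t))

o_t = sigmoid(a_o)
dJ_do = g * np.tanh(c_t)             # dJ/do_t = dJ/dh_t * tanh(C_t)
dJ_dao = dJ_do * o_t * (1.0 - o_t)   # sigmoid'(a_o) = o_t (1 - o_t)

eps = 1e-6
numeric = np.array([(J(a_o + eps * e) - J(a_o - eps * e)) / (2 * eps) for e in np.eye(4)])
```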

Cell State

C_t is the cell state. Along with it, we also handle the candidate cell state c~_t and its pre-activation a_c here.

C_t: The derivation for C_t is simple, since the path from C_t to J is short: C_t → h_t → V_t → J. As we already have dJ/dh_t, we only need to differentiate h_t = o_t * tanh(C_t) (Fig 9 equation 6) w.r.t. C_t, using tanh'(x) = 1 - tanh^2(x):

$$\frac{\partial J}{\partial C_t} = \frac{\partial J}{\partial h_t} \odot o_t \odot \left(1 - \tanh^2(C_t)\right)$$

Note: how this cell state gradient is clubbed (combined) with the gradient arriving from the next time-step will be explained at the end of the article.
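A similar finite-difference check for dJ/dC_t, again assuming the hypothetical linear loss J = g · h_t so that g plays the role of dJ/dh_t. It only covers the direct path C_t → h_t → J within one time-step:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(3)
o_t = sigmoid(rng.normal(size=4))  # hypothetical output-gate values
c_t = rng.normal(size=4)           # hypothetical cell state
g = rng.normal(size=4)             # plays the role of dJ/dh_t

def J(c):
    """Hypothetical linear loss J = g . h_t with h_t = o_t * tanh(C_t)."""
    return g @ (o_t * np.tanh(c))

dJ_dc = g * o_t * (1.0 - np.tanh(c_t) ** 2)  # tanh'(x) = 1 - tanh(x)^2

eps = 1e-6
numeric = np.array([(J(c_t + eps * e) - J(c_t - eps * e)) / (2 * eps) for e in np.eye(4)])
```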

c~_t: Following the path from c~_t to J, the full equation for the differentiation is:

dJ/dh_t * dh_t/dC_t * dC_t/dc~_t

dJ/dh_t * dh_t/dC_t can be written as dJ/dC_t (we have this value from above).

The value of C_t is given by Fig 9 equation 5, so we only need to differentiate C_t w.r.t. c~_t:

$$\frac{\partial J}{\partial \tilde{c}_t} = \frac{\partial J}{\partial C_t} \odot i_t$$

a_c: Following the path from a_c to J, the full equation for the differentiation is:

dJ/dh_t * dh_t/dC_t * dC_t/dc~_t * dc~_t/da_c

dJ/dh_t * dh_t/dC_t * dC_t/dc~_t can be written as dJ/dc~_t (we have this value from above).

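The value of c~_t is given by Fig 8 equation 3, c~_t = tanh(a_c), so we only need to differentiate c~_t w.r.t. a_c:

$$\frac{\partial J}{\partial a_c} = \frac{\partial J}{\partial \tilde{c}_t} \odot \left(1 - \tilde{c}_t^{\,2}\right)$$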
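The candidate path can be verified the same way. The sketch assumes the hypothetical loss J = g · h_t, holds random gate values fixed, and lets only a_c vary:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(4)
f_t, i_t, o_t = (sigmoid(rng.normal(size=4)) for _ in range(3))  # hypothetical gate values
c_prev = rng.normal(size=4)   # C_t-1
a_c = rng.normal(size=4)      # hypothetical candidate pre-activation
g = rng.normal(size=4)        # plays the role of dJ/dh_t

def J(a):
    """Hypothetical linear loss J = g . h_t, with only a_c allowed to vary."""
    c_t = f_t * c_prev + i_t * np.tanh(a)
    return g @ (o_t * np.tanh(c_t))

c_tilde = np.tanh(a_c)
c_t = f_t * c_prev + i_t * c_tilde
dJ_dc = g * o_t * (1.0 - np.tanh(c_t) ** 2)  # dJ/dC_t
dJ_dctilde = dJ_dc * i_t                     # dJ/dc~_t
dJ_dac = dJ_dctilde * (1.0 - c_tilde ** 2)   # dJ/da_c

eps = 1e-6
numeric = np.array([(J(a_c + eps * e) - J(a_c - eps * e)) / (2 * eps) for e in np.eye(4)])
```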

Input gate

Variables associated: i_t and a_i

i_t: Following the path from i_t to J, the full equation for the differentiation is:

dJ/dh_t * dh_t/dC_t * dC_t/di_t

dJ/dh_t * dh_t/dC_t can be written as dJ/dC_t (we have this value from cell state). So we only need to differentiate C_t w.r.t i_t.

Using the value of C_t from Fig 9 equation 5, the differentiation is:

$$\frac{\partial J}{\partial i_t} = \frac{\partial J}{\partial C_t} \odot \tilde{c}_t$$

a_i: Following the path from a_i to J, the full equation for the differentiation is:

dJ/dh_t * dh_t/dC_t * dC_t/di_t * di_t/da_i

dJ/dh_t * dh_t/dC_t * dC_t/di_t can be written as dJ/di_t (we have this value from above). So we only need to differentiate i_t w.r.t a_i.

$$\frac{\partial J}{\partial a_i} = \frac{\partial J}{\partial i_t} \odot i_t \odot (1 - i_t)$$
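For the input gate, the same kind of check, assuming the hypothetical loss J = g · h_t and fixed random values for everything except a_i:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(5)
f_t, o_t = sigmoid(rng.normal(size=4)), sigmoid(rng.normal(size=4))  # fixed gate values
c_prev = rng.normal(size=4)            # C_t-1
c_tilde = np.tanh(rng.normal(size=4))  # c~_t
a_i = rng.normal(size=4)               # input-gate pre-activation (varied)
g = rng.normal(size=4)                 # plays the role of dJ/dh_t

def J(a):
    """Hypothetical linear loss J = g . h_t, with only a_i allowed to vary."""
    c_t = f_t * c_prev + sigmoid(a) * c_tilde
    return g @ (o_t * np.tanh(c_t))

i_t = sigmoid(a_i)
c_t = f_t * c_prev + i_t * c_tilde
dJ_dc = g * o_t * (1.0 - np.tanh(c_t) ** 2)  # dJ/dC_t
dJ_di = dJ_dc * c_tilde                      # dJ/di_t
dJ_dai = dJ_di * i_t * (1.0 - i_t)           # dJ/da_i

eps = 1e-6
numeric = np.array([(J(a_i + eps * e) - J(a_i - eps * e)) / (2 * eps) for e in np.eye(4)])
```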

Forget Gate

Variables associated: f_t and a_f

f_t: Following the path from f_t to J, the full equation for the differentiation is:

dJ/dh_t * dh_t/dC_t * dC_t/df_t

dJ/dh_t * dh_t/dC_t can be written as dJ/dC_t (we have this value from cell state). So we only need to differentiate C_t w.r.t f_t.

Using the value of C_t from Fig 9 equation 5, the differentiation is:

$$\frac{\partial J}{\partial f_t} = \frac{\partial J}{\partial C_t} \odot C_{t-1}$$

a_f: Following the path from a_f to J, the full equation for the differentiation is:

dJ/dh_t * dh_t/dC_t * dC_t/df_t * df_t/da_f

dJ/dh_t * dh_t/dC_t * dC_t/df_t can be written as dJ/df_t (we have this value from above). So we only need to differentiate f_t w.r.t a_f.

$$\frac{\partial J}{\partial a_f} = \frac{\partial J}{\partial f_t} \odot f_t \odot (1 - f_t)$$
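And for the forget gate, with the same hypothetical loss J = g · h_t and only a_f varying:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(8)
i_t, o_t = sigmoid(rng.normal(size=4)), sigmoid(rng.normal(size=4))  # fixed gate values
c_prev = rng.normal(size=4)            # C_t-1
c_tilde = np.tanh(rng.normal(size=4))  # c~_t
a_f = rng.normal(size=4)               # forget-gate pre-activation (varied)
g = rng.normal(size=4)                 # plays the role of dJ/dh_t

def J(a):
    """Hypothetical linear loss J = g . h_t, with only a_f allowed to vary."""
    c_t = sigmoid(a) * c_prev + i_t * c_tilde
    return g @ (o_t * np.tanh(c_t))

f_t = sigmoid(a_f)
c_t = f_t * c_prev + i_t * c_tilde
dJ_dc = g * o_t * (1.0 - np.tanh(c_t) ** 2)  # dJ/dC_t
dJ_df = dJ_dc * c_prev                       # dJ/df_t
dJ_daf = dJ_df * f_t * (1.0 - f_t)           # dJ/da_f

eps = 1e-6
numeric = np.array([(J(a_f + eps * e) - J(a_f - eps * e)) / (2 * eps) for e in np.eye(4)])
```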

Input to the LSTM

There are two inputs to each cell: the previous cell state C_t-1, and the previous hidden state concatenated with the current input, i.e.

[h_t-1, x_t] -> Z_t

C_t-1: This is the memory of the LSTM cell (Fig 5 shows the cell state). The derivation for C_t-1 is simple, as only C_t-1 and C_t are involved (Fig 9 equation 5):

$$\frac{\partial J}{\partial C_{t-1}} = \frac{\partial J}{\partial C_t} \odot f_t$$
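A check for the C_t-1 gradient, with the same hypothetical linear loss J = g · h_t and fixed, random gate values:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(9)
f_t, i_t, o_t = (sigmoid(rng.normal(size=4)) for _ in range(3))  # fixed gate values
c_tilde = np.tanh(rng.normal(size=4))  # c~_t
c_prev = rng.normal(size=4)            # C_t-1 (varied)
g = rng.normal(size=4)                 # plays the role of dJ/dh_t

def J(cp):
    """Hypothetical linear loss J = g . h_t, with only C_t-1 allowed to vary."""
    c_t = f_t * cp + i_t * c_tilde
    return g @ (o_t * np.tanh(c_t))

c_t = f_t * c_prev + i_t * c_tilde
dJ_dc = g * o_t * (1.0 - np.tanh(c_t) ** 2)  # dJ/dC_t
dJ_dcprev = dJ_dc * f_t                      # dJ/dC_t-1

eps = 1e-6
numeric = np.array([(J(c_prev + eps * e) - J(c_prev - eps * e)) / (2 * eps) for e in np.eye(4)])
```

Because this gradient is just dJ/dC_t scaled by f_t, a forget gate close to 1 lets the gradient pass back through time with little shrinkage.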

Z_t: As shown below, Z_t feeds into four different paths: a_f, a_i, a_c and a_o.

Z_t → a_f → f_t → C_t → h_t → J -> Forget gate

Z_t → a_i → i_t → C_t → h_t → J -> Input gate

Z_t → a_c → c~_t → C_t → h_t → J -> Candidate cell state

Z_t → a_o → o_t → h_t → J -> Output gate

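By the multivariable chain rule, the contributions of the four paths add up:

$$\frac{\partial J}{\partial Z_t} = W_f^{\top}\frac{\partial J}{\partial a_f} + W_i^{\top}\frac{\partial J}{\partial a_i} + W_c^{\top}\frac{\partial J}{\partial a_c} + W_o^{\top}\frac{\partial J}{\partial a_o}$$

The entries of this vector that correspond to h_t-1 form the gradient passed back to the previous time-step.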
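The sum over the four paths can be verified end to end. The sketch below rests on stated assumptions: random hypothetical weights, a fixed C_t-1, and the hypothetical linear loss J = g_h · h_t. It also splits dJ/dZ_t back into its h_t-1 and x_t parts:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(6)
n_h, n_in = 3, 2
W = {k: rng.normal(size=(n_h, n_h + n_in)) for k in "fico"}  # hypothetical weights
b = {k: rng.normal(size=n_h) for k in "fico"}                # hypothetical biases
c_prev = rng.normal(size=n_h)  # C_t-1, held fixed
g_h = rng.normal(size=n_h)     # plays the role of dJ/dh_t

def forward(z):
    f = sigmoid(W["f"] @ z + b["f"])
    i = sigmoid(W["i"] @ z + b["i"])
    c_tilde = np.tanh(W["c"] @ z + b["c"])
    o = sigmoid(W["o"] @ z + b["o"])
    c = f * c_prev + i * c_tilde
    h = o * np.tanh(c)
    return f, i, c_tilde, o, c, h

def J(z):
    """Hypothetical linear loss J = g_h . h_t as a function of Z_t."""
    return g_h @ forward(z)[-1]

z_t = rng.normal(size=n_h + n_in)  # Z_t = [h_t-1, x_t]
f, i, c_tilde, o, c, h = forward(z_t)

dJ_dc = g_h * o * (1.0 - np.tanh(c) ** 2)
da_f = dJ_dc * c_prev * f * (1.0 - f)
da_i = dJ_dc * c_tilde * i * (1.0 - i)
da_c = dJ_dc * i * (1.0 - c_tilde ** 2)
da_o = g_h * np.tanh(c) * o * (1.0 - o)

# Sum of the four paths Z_t -> a_f, a_i, a_c, a_o
dJ_dz = W["f"].T @ da_f + W["i"].T @ da_i + W["c"].T @ da_c + W["o"].T @ da_o
dJ_dh_prev, dJ_dx = dJ_dz[:n_h], dJ_dz[n_h:]  # h_t-1 part and x_t part

eps = 1e-6
numeric = np.array([(J(z_t + eps * e) - J(z_t - eps * e)) / (2 * eps) for e in np.eye(n_h + n_in)])
```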

Weights and biases

The derivation for W and b is straightforward. The derivation below is for the output gate of the LSTM; the same process is used for the weights and biases of the other gates.
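A sketch of the output-gate weight and bias gradients for a single time-step, assuming the hypothetical linear loss J = g_h · h_t and treating C_t as fixed (it does not depend on W_o or b_o). Since a_o = W_o Z_t + b_o, dJ/dW_o is the outer product of dJ/da_o and Z_t, and dJ/db_o equals dJ/da_o:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(7)
n_h, n_z = 3, 5              # hypothetical sizes; n_z = hidden + input
W_o = rng.normal(size=(n_h, n_z))
b_o = rng.normal(size=n_h)
z_t = rng.normal(size=n_z)   # Z_t = [h_t-1, x_t]
c_t = rng.normal(size=n_h)   # cell state, held fixed
g_h = rng.normal(size=n_h)   # plays the role of dJ/dh_t

def J(W, b):
    """Hypothetical linear loss J = g_h . h_t with h_t = sigmoid(W Z_t + b) * tanh(C_t)."""
    return g_h @ (sigmoid(W @ z_t + b) * np.tanh(c_t))

o_t = sigmoid(W_o @ z_t + b_o)
da_o = g_h * np.tanh(c_t) * o_t * (1.0 - o_t)  # dJ/da_o
dW_o = np.outer(da_o, z_t)                     # dJ/dW_o = dJ/da_o Z_t^T
db_o = da_o                                    # dJ/db_o = dJ/da_o

eps = 1e-6
numeric_W = np.zeros_like(W_o)
for r in range(n_h):
    for k in range(n_z):
        E = np.zeros_like(W_o)
        E[r, k] = eps
        numeric_W[r, k] = (J(W_o + E, b_o) - J(W_o - E, b_o)) / (2 * eps)
numeric_b = np.array([(J(W_o, b_o + eps * e) - J(W_o, b_o - eps * e)) / (2 * eps) for e in np.eye(n_h)])
```

In backpropagation through time, these per-step gradients are summed over all time-steps, since the same weights are shared across steps.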