39

反向传播(Backpropagation)笔记

 5 years ago
source link: http://blog.stupidme.me/2018/08/25/backpropagation/?amp%3Butm_medium=referral
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

反向传播是深度学习的基石。

导数

先回顾下导数:

\frac{df(x)}{dx}=\lim_{h->0}\frac{f(x+h)-f(x)}{h}

函数在每个变量的导数就是偏导数。

对于函数 f(x,y)=x+y\frac{\partial f}{\partial x}=1 ,同时, \frac{\partial f}{\partial y}=1

梯度就是偏导数组成的矢量。上述例子中, \Delta f=[\frac{\partial f}{\partial x},\frac{\partial f}{\partial y}]

链式法则

对于简单函数,我们可以根据公式直接计算出其导数。但是对于复杂的函数,我们就没那么容易直接写出导数。但是我们有 链式法则(chain rule)

定义不多说,咱们举个例子,感受一下链式法则的魅力。

我们熟悉的sigmoid函数 \sigma(x)=\frac{1}{1+e^{-x}} ,如果你记不住它的导数,我们怎么求解呢?

求解步骤如下:

  • 将函数模块化,分成多个基本的部分,对于每一个部分都可以使用简单的求导法则进行求导
  • 使用链式法则,将这些导数链接起来,计算出最终的导数

具体如下:

a=x ,则 \frac{\partial a}{\partial x}=1

b=-a ,则 \frac{\partial b}{\partial a}=-1

c=e^{b} ,则 \frac{\partial c}{\partial b}=e^{b}

d=1+c ,则 \frac{\partial d}{\partial c}=1

e=\frac{1}{d}

,则

\frac{\partial e}{\partial d}=\frac{-1}{d^2}

上面的e实际上就是我们的 \sigma(x) ,那么根据链式法则,有:

Jfaaiqy.png!web

sigmoid函数的导数可以直接用自身表示,这也是很奇妙的性质了。这样的求导过程是不是很简单?

反向传播代码实现

求导和链式法则我都会了,那么具体的前向传播和反向传播的代码是怎么样的呢?

这次我们使用一个更复杂一点点的例子:

f(x,y)=\frac{x+\sigma(x)}{\sigma(x)+(x+y)^2}

我们先看下它地forward pass代码:

import math
 
x = 3
y = -4
 
sigy = 1.0 / (1 + math.exp(-y)) # sigmoid function
num = x + sigy # 分子
sigx = 1.0 / (1 + math.exp(-x))
xpy = x + y
xpy_sqr = xpy**2
den = sigx + xpy_sqr # 分母
invden = 1.0 / den
f = num * invden # 函数
 

上述过程很简单对不对,就是把复杂的函数拆解成一个一个简单函数。

我们看看接下来的反向传播过程:

dnum = invden
 

因为

f = num * invden

所以有

\frac{\partial f}{\partial num} = invden

也就是

$$dnum=invden $

dinvden = num # 同理
 
dden = (-1.0 / (den**2)) * dinvden # 链式法则
 

展开来说:

\frac{\partial invden}{\partial den}=\frac{-1}{den^2}

\frac{\partial f}{\partial invden}=num

所以

dden=\frac{\partial f}{\partial den}=\frac{partial f}{\partial invden}\cdot \frac{\partial invden}{\partial den} = \frac{-1.0}{den^2}\cdot dinvden

所以,同理,我们可以写出所有的导数:

dsigx = (1) * dden 
dxpy_sqr = (1) * dden
 
dxpy = (2 * xpy) * dxpy_sqr
 
# backprob xpy = x + y
dx = (1) * dxpy
dy = (1) * dxpy
 
# 这里开始,请注意使用的是"+=",而不是"=”
dx += ((1 - sigx) * sigx) * dsigx # dsigma(x) = (1 - sigma(x))*sigma(x)
dx += (1) * dnum
 
# backprob num = x + sigy
dsigy = (1) * dnum
# 注意“+=”
dy += ((1 - sigy) * sigy) * dsigy
 

问题:

  • 上面计算过程中,为什么要用“+=”替代“=”呢?

如果变量x,y在前向传播的表达式中出现多次,那么进行反向传播的时候就要非常小心,使用+=而不是=来累计这些变量的梯度(不然就会造成覆写)。这是遵循了在微积分中的多元链式法则,该法则指出如果变量在线路中分支走向不同的部分,那么梯度在回传的时候,就应该进行累加。

联系我

  • Email: [email protected]
  • WeChat: luozhouyang0528

    fqErmq7.jpg!web
  • 个人公众号,你可能会感兴趣:

    jA36neI.jpg!web

About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK