17

Numpy编写BP传播过程全解

 3 years ago
source link: https://mp.weixin.qq.com/s?__biz=MzIwMTc4ODE0Mw%3D%3D&%3Bmid=2247519914&%3Bidx=1&%3Bsn=af2bad8963d67845daca82f7bac47280
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

EnAnQz.gif!mobile

©PaperWeekly 原创 · 作者|孙裕道

学校|北京邮电大学博士生

研究方向|GAN图像生成、情绪对抗样本生成

jiuMzu2.png!mobile

引言

BP反向传播矩阵推导图示详解 一文中在矩阵视角下对 BP 的原理进行了详细的介绍, 神经网络中权重的梯度由前一层前向传播值与后一层的误差值整合计算得到 。该文中也对吴恩达的斯坦福机器学习的讲义中的相关部分进行了证明。

BP 的矩阵形式的推导的好处在于它的矩阵表示形式对编程非常有指导意义,当前有很多的热门的深度学习框架,例如 Pytorch 和 Tensorflow,像这种深度学习框架集成性很高,神经网络中 BP 更新参数的过程几行代码就搞定,这对于从代码中理解其原理造成了一定的困难。

Numpy 编写 BP 更新参数的过程则是一个好的方式去了解其原理, 代码链接如下:

https://github.com/guidao20/BP_Numpy

文件并不复杂,两个 py 文件加一个 mnist 数据集。本文会对该代码进行详细的介绍。

7Zf6NfJ.png!mobile

IbQFVj6.png!mobile

预备知识

一个 4 层的神经网络如下图所示,各个层的维度分别是 784,128,64,10;神经网络的权重分别为,,,具体的前向计算过程如下所示。

MjY3aeM.png!mobile

▲ 图1.神经网络前向计算过程

根据 BP 反向传播矩阵推导图示详解 中 Section 6 的推导过程可得到各层网络的权重梯度计算示意图(示意图中各字母代表的含义查阅 BP 反向传播矩阵推导图示详解 的 Section 6)。

bQbeUbi.png!mobile

▲ 图2.神经网络BP原理

BP 反向传播矩阵推导图示详解 中的 Section 8 是对吴恩达机器学习讲义中的关于BP 原理部分的等价证明,所以可以将上图 2 的示意图重新整理如下,其中反向传播的误差的计算公式在图中蓝色字体。本文的要介绍的 Numpy 代码就按照下图的形式进行编程。

nEBzEbM.png!mobile

▲ 图3.神经网络BP原理(吴)

ZJfe2i7.png!mobile

代码详解 

本文的代码结构非常简单如下图所示一共三个文件。data 文件夹下是 mnist 手写体数字集的压缩包 mnist.pkl.gz;mnist_loader.py 是用于加载 mnist 数据集的;BP_Numpy.py 是 BP 训练神经网络的程序,也是本文重点要讲的程序。

uIVbIfQ.png!mobile

由图 4 可知程序中使用的激活函数为 Sigmoid 函数,其中 Sigmoid 函数的定义为:

yIJNjme.png!mobile

Sigmoid 函数的导数定义为:

UnEfMvM.png!mobile

jqmQnqj.png!mobile

▲ 图4.Sigmoid函数及其导数

图 5 是类 NerualNetwork 的初始化,进而构建一个神经网络,分别对神经网络的尺寸(有几层,每一层的单元数是多少),每一层的权重和偏置进行初始化。

Nfmyqai.png!mobile

▲ 图5.类初始化

图 6 是神经网络的前向计算过程,先做线性变换,然后再进行激活。以 Section 2 预备知识中图 1 的神经网络为例(以便更清楚的交代出各个矩阵的维度)。假如一共有四层神经网络,各个层的单元数为 784,128,64,10,根据代码则前向计算过程可以归结为:

bUnYZrV.png!mobile

▲ 图6.前向计算过程

图 7 是 BP 反向传播求梯度的过程,图 7 中的 黄色框区域 是用列表存储前计算的激活值,最后求出损失函数,这里损失函数为。这里需要注意的是 反向传播求梯度并不是不需要前向计算值 。图 7 中的蓝色框区域是计算输出层的权重梯度,对应于图 3 中的计算过程:

BV7Bja6.png!mobile

图 7 中的绿色框区域是从后往前以此计算各个层的权重梯度,对应于图 3 中和的计算过程:

A3a6Zr.png!mobile

这里需要注意的是符号表示的是矩阵相乘,符号表示的是向量元素对应位置相乘。另外程序中有对偏置求梯度的操作,其原理跟求解权重的原理类似。

3yeuM3f.png!mobile

▲ 图7. BP反向传播

图 8 主要是对神经网络参数进行更新(图中红框所示),以上操作将各个层的权重和偏置的梯度求解出来,再利用梯度下降对各个层的权重和偏置的参数进行更新。

6NRfMjE.png!mobile

▲ 图8.权重和偏置更新

设置学习率的步长为 0.1,epoch 为 1000,batch_size 为 10,则可得到如下图程序结果,由结果可知,通过 BP 求解参数梯度,再利用梯度下降法,损失函数整体是减小的。

JfuYRj.png!mobile

▲ 图9.权重和偏置更新

更多阅读

qINrIjN.png!mobile

UZnQ7b7.png!mobile

6RrUNji.png!mobile

2mUBJnB.gif!mobile

# 投 稿 通 道 #

让你的论文被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢? 答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是 最新论文解读 ,也可以是 学习心得技术干货 。我们的目的只有一个,让知识真正流动起来。

:memo:  来稿标准:

• 稿件确系个人 原创作品 ,来稿需注明作者个人信息(姓名+学校/工作单位+学历/职位+研究方向) 

• 如果文章并非首发,请在投稿时提醒并附上所有已发布链接 

• PaperWeekly 默认每篇文章都是首发,均会添加“原创”标志

:mailbox_with_mail:  投稿邮箱:

• 投稿邮箱: [email protected] 

• 所有文章配图,请单独在附件中发送 

• 请留下即时联系方式(微信或手机),以便我们在编辑发布时和作者沟通

:mag:

现在,在 「知乎」 也能找到我们了

进入知乎首页搜索 「PaperWeekly」

点击 「关注」 订阅我们的专栏吧

关于PaperWeekly

PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击 「交流群」 ,小助手将把你带入 PaperWeekly 的交流群里。

R7nmyuB.gif!mobile

nUvi2i.jpg!mobile


Recommend

  • 22
    • dockone.io 6 years ago
    • Cache

    Kubernetes中的负载均衡全解

    Kubernetes中的负载均衡全解

  • 36
    • database.51cto.com 6 years ago
    • Cache

    MySQL高性能优化实战全解!

    MySQL高性能优化实战全解!

  • 33
    • sunxichun.github.io 5 years ago
    • Cache

    [iOS]定位使用全解 - sunxc | Blog

  • 16
    • dockone.io 5 years ago
    • Cache

    Kubernetes Node全解

    介 绍 Kubernetes在GitHub上拥有超过48,000颗星,超过75,000个commit,拥有以Google为代表的科技巨头公司为主要贡献者。可以说,Kubernetes已迅速掌管了容器生态系统,成为容器编排平台的真正领导者。 Kubernetes提供...

  • 35
    • 掘金 juejin.im 4 years ago
    • Cache

    剑指 Offer 全解(Java 版)

    本文转自个人博客:CyC2018/CS-Notes 3. 数组中重复的数字 NowCoder 题目描述 在一个长度为 n 的数组里的所有数字都在 0 到 n-1 的范围内。数组中某些数字是重复的,但不知道有几个数字是重复的,也不知道每个数字重复几次。请找出数

  • 25

    底层基础设施安全设计 一、物理基础架构安全 谷歌数据中心包括了 生物识别、金属感应探测、监控、通行障碍和激光...

  • 19

    对于 \(Softmax\) 回归的正向传播非常简单,就是对于一个输入 \(X\) 对每一个输入标量 \(x_i\) 进行加权求和得到 \(Z\) 然后对其做概率归一化。 Softmax 示意图...

  • 27

    一个模拟人群行为和病毒传播过程的实验分析上海交通大学 计算机应用技术硕士

  • 2

    为数据仓库编写SQL存储过程的技巧 - babbling在数据仓库应用程序中,我们需要想办法有效地回填我们的数据并大规模快速运行我们的 SQL。回填是指我们想要在表中填充过去 X 天的数据。为此,我们的 SQL 必须是可重复和可水平扩展的。我们需要以不会泄漏数据或导致...

  • 1

    使用NumPy演示​​实现神经网络过程 在不断发展的人工智能(模拟智能)领域,有一个想法经...

About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK