

清华大学出品:罚梯度范数提高深度学习模型泛化性
source link: https://blog.csdn.net/qq_38406029/article/details/122851202
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

神经网络结构简单,训练样本量不足,则会导致训练出来的模型分类精度不高;神经网络结构复杂,训练样本量过大,则又会导致模型过拟合,所以如何训练神经网络提高模型的泛化性是人工智能领域一个非常核心的问题。最近读到了一篇与该问题相关的文章,论文中作者在训练过程中通过在损失函数中增加正则化项梯度范数的约束从而来提高深度学习模型的泛化性。作者从原理和实验两方面分别对论文中的方法进行了详细地阐述和验证。
L
i
p
s
c
h
i
t
z
\mathrm{Lipschitz}
Lipschitz连续是对深度学习进行理论分析中非常重要且常见的数学工具,该论文就是以神经网络损失函数
是
L
i
p
s
c
h
i
t
z
是\mathrm{Lipschitz}
是Lipschitz连续为出发点进行数学推导。为了方便读者能够更流畅地欣赏论文作者漂亮的数学证明思路和过程,本文对于论文中没有展开数学证明细节进行了补充。
论文链接:https://arxiv.org/abs/2202.03599
2 L i p s c h i z \mathrm{Lipschiz} Lipschiz连续
给定一个训练数据集
S
=
{
(
x
i
,
y
i
)
}
i
=
0
n
\mathcal{S}=\{(x_i,y_i)\}_{i=0}^n
S={(xi,yi)}i=0n服从分布
D
\mathcal{D}
D,一个带有参数
θ
∈
Θ
\theta \in \Theta
θ∈Θ的神经网络
f
(
⋅
;
θ
)
f(\cdot;\theta)
f(⋅;θ),损失函数为
L
S
=
1
N
∑
i
=
1
N
l
(
y
i
,
y
i
,
θ
^
)
L_{\mathcal{S}}=\frac{1}{N}\sum\limits_{i=1}^N l(\hat{y_i,y_i ,\theta})
LS=N1i=1∑Nl(yi,yi,θ^)当需要对损失函数中的梯度范数进行约束时,则有如下损失函数
L
(
θ
)
=
L
S
+
λ
⋅
∥
∇
θ
L
S
(
θ
)
∥
p
L(\theta)=L_{\mathcal{S}}+\lambda \cdot \|\nabla_\theta L_{\mathcal{S}}(\theta)\|_p
L(θ)=LS+λ⋅∥∇θLS(θ)∥p其中
∥
⋅
∥
p
\|\cdot \|_p
∥⋅∥p表示
p
p
p范数,
λ
∈
R
+
\lambda\in \mathbb{R}^{+}
λ∈R+为梯度惩罚系数。一般情况下,损失函数引入梯度的正则化项会使得其在优化过程中在局部有更小的
L
i
p
s
c
h
i
t
z
\mathrm{Lipschitz}
Lipschitz常数,
L
i
p
s
c
h
i
t
z
\mathrm{Lipschitz}
Lipschitz常数越小,就意味着损失函数就越平滑,平损失函数平滑区域易于损失函数优化权重参数。进而会使得训练出来的深度学习模型有更好的泛化性。
深度学习中一个非常重要而且常见的概念就是
L
i
p
s
c
h
i
t
z
\mathrm{Lipschitz}
Lipschitz连续。给定一个空间
Ω
⊂
R
n
\Omega \subset \mathbb{R}^n
Ω⊂Rn,对于函数
h
:
Ω
→
R
m
h:\Omega \rightarrow \mathbb{R}^m
h:Ω→Rm,如果存在一个常数
K
K
K,对于
∀
θ
1
,
θ
2
∈
Ω
\forall \theta_1,\theta_2 \in \Omega
∀θ1,θ2∈Ω满足以下条件则称
L
i
p
s
c
h
i
t
z
\mathrm{Lipschitz}
Lipschitz连续
∥
h
(
θ
1
)
−
h
(
θ
2
)
∥
2
≤
K
⋅
∥
θ
1
−
θ
2
∥
2
\|h(\theta_1)-h(\theta_2)\|_2 \le K \cdot \|\theta_1 - \theta_2\|_2
∥h(θ1)−h(θ2)∥2≤K⋅∥θ1−θ2∥2其中
K
K
K表示的是
L
i
p
s
c
h
i
t
z
\mathrm{Lipschitz}
Lipschitz常数。如果对于参数空间
Θ
⊂
Ω
\Theta \subset \Omega
Θ⊂Ω,如果
Θ
\Theta
Θ有一个邻域
A
\mathcal{A}
A,且
h
∣
A
h|_{\mathcal{A}}
h∣A是
L
i
p
s
c
h
i
t
z
\mathrm{Lipschitz}
Lipschitz连续,则称
h
h
h是局部
L
i
p
s
c
h
i
t
z
\mathrm{Lipschitz}
Lipschitz连续。直观来看,
L
i
p
s
c
h
i
t
z
\mathrm{Lipschitz}
Lipschitz常数描述的是输出关于输入变化速率的一个上界。对于一个小的
L
i
p
s
c
h
i
t
z
\mathrm{Lipschitz}
Lipschitz参数,在邻域
A
\mathcal{A}
A中给定任意两个点,它们输出的改变被限制在一个小的范围里。
根据微分中值定理,给定一个最小值点
θ
i
\theta_i
θi,对于任意点
∀
θ
i
′
∈
A
\forall \theta_i^{\prime}\in \mathcal{A}
∀θi′∈A,则有如下公式成立
∥
∣
L
(
θ
i
′
)
−
L
(
θ
i
)
∥
2
=
∥
∇
L
(
ζ
)
(
θ
i
′
−
θ
i
)
∥
2
\||L(\theta_i^{\prime})-L(\theta_i)\|_2 = \|\nabla L (\zeta) (\theta_i^{\prime}-\theta_i)\|_2
∥∣L(θi′)−L(θi)∥2=∥∇L(ζ)(θi′−θi)∥2其中
ζ
=
c
θ
i
+
(
1
−
c
)
θ
i
′
,
c
∈
[
0
,
1
]
\zeta=c \theta_i + (1-c)\theta^\prime_i, c \in [0,1]
ζ=cθi+(1−c)θi′,c∈[0,1],根据
C
a
u
c
h
y
-
S
c
h
w
a
r
z
\mathrm{Cauchy\text{-}Schwarz}
Cauchy-Schwarz不等式可知
∥
∣
L
(
θ
i
′
)
−
L
(
θ
i
)
∥
2
≤
∥
∇
L
(
ζ
)
∥
2
∥
(
θ
i
′
−
θ
i
)
∥
2
\||L(\theta_i^{\prime})-L(\theta_i)\|_2 \le \|\nabla L (\zeta)\|_2 \|(\theta_i^{\prime}-\theta_i)\|_2
∥∣L(θi′)−L(θi)∥2≤∥∇L(ζ)∥2∥(θi′−θi)∥2当
θ
i
′
→
θ
\theta_i^{\prime}\rightarrow \theta
θi′→θ时,相应的
L
i
p
s
c
h
i
z
\mathrm{Lipschiz}
Lipschiz常数接近
∥
∇
L
(
θ
i
)
∥
2
\|\nabla L(\theta_i)\|_2
∥∇L(θi)∥2。因此可以通过减小
∥
∇
L
(
θ
i
)
∥
\|\nabla L(\theta_i)\|
∥∇L(θi)∥的数值使得模型能够更平滑的收敛。
3 论文方法
对带有梯度范数约束的损失函数求梯度可得
∇
θ
L
(
θ
)
=
∇
θ
L
S
(
θ
)
+
∇
θ
(
λ
⋅
∥
∇
θ
L
S
(
θ
)
∥
p
)
\nabla_\theta L(\theta)=\nabla_\theta L_{\mathcal{S}}(\theta)+\nabla_\theta(\lambda \cdot \|\nabla_\theta L_{\mathcal{S}}(\theta)\|_p)
∇θL(θ)=∇θLS(θ)+∇θ(λ⋅∥∇θLS(θ)∥p)在本文中,作者令
p
=
2
p=2
p=2,此时则有如下推导过程
∇
θ
∥
∇
θ
L
S
(
θ
)
∥
2
=
∇
θ
[
∇
θ
⊤
L
S
(
θ
)
⋅
∇
θ
L
S
(
θ
)
]
1
2
=
1
2
⋅
∇
θ
2
L
S
(
θ
)
∇
θ
L
S
(
θ
)
∥
∇
θ
L
S
(
θ
)
∥
2
∇θ∥∇θLS(θ)∥2=∇θ[∇θ⊤LS(θ)⋅∇θLS(θ)]21=21⋅∇θ2LS(θ)∥∇θLS(θ)∥2∇θLS(θ)将该结果带入到梯度范数约束的损失函数中,则有以下公式
∇
θ
L
(
θ
)
=
∇
θ
L
S
(
θ
)
+
λ
⋅
∇
θ
2
L
S
(
θ
)
∇
θ
L
S
(
θ
)
∥
∇
θ
L
S
(
θ
)
∥
2
\nabla_\theta L(\theta)=\nabla_\theta L_{\mathcal{S}}(\theta)+\lambda \cdot \nabla^2_\theta L_{\mathcal{S}}(\theta) \frac{\nabla_\theta L_{\mathcal{S}}(\theta)}{\|\nabla_\theta L_{\mathcal{S}}(\theta)\|_2}
∇θL(θ)=∇θLS(θ)+λ⋅∇θ2LS(θ)∥∇θLS(θ)∥2∇θLS(θ)可以发现,以上公式中涉及到
H
e
s
s
i
a
n
\mathrm{Hessian}
Hessian矩阵的计算,在深度学习中,计算参数的
H
e
s
s
i
a
n
\mathrm{Hessian}
Hessian矩阵会带来高昂的计算成本,所以需要用到一些近似的方法。作者将损失函数进行泰勒展开,其中令
H
=
∇
θ
2
L
S
(
θ
)
H=\nabla^2_\theta L_\mathcal{S}(\theta)
H=∇θ2LS(θ),则有
L
S
(
θ
+
Δ
θ
)
=
L
S
(
θ
)
+
∇
θ
⊤
L
S
(
θ
)
⋅
Δ
θ
+
1
2
Δ
θ
⊤
H
Δ
θ
+
O
(
∥
Δ
θ
∥
2
2
)
L_\mathcal{S}(\theta+\Delta \theta)=L_\mathcal{S}(\theta)+\nabla^{\top}_{\theta}L_\mathcal{S}(\theta)\cdot \Delta \theta + \frac{1}{2} \Delta \theta^{\top} H \Delta \theta +\mathcal{O}(\|\Delta \theta\|_2^2)
LS(θ+Δθ)=LS(θ)+∇θ⊤LS(θ)⋅Δθ+21Δθ⊤HΔθ+O(∥Δθ∥22)进而则有
∇
θ
L
S
(
θ
+
Δ
θ
)
=
∇
Δ
θ
L
S
(
θ
+
Δ
θ
)
=
∇
θ
L
S
(
θ
)
+
H
Δ
θ
+
O
(
∥
Δ
θ
∥
2
2
)
∇θLS(θ+Δθ)=∇ΔθLS(θ+Δθ)=∇θLS(θ)+HΔθ+O(∥Δθ∥22)其中令
Δ
θ
=
r
v
\Delta \theta=r v
Δθ=rv,
r
r
r表示一个小的数值,
v
v
v表示一个向量,带入上式则有
H
v
=
∇
θ
L
S
(
θ
+
r
v
)
−
∇
θ
L
S
(
θ
)
r
+
O
(
r
)
H v =\frac{\nabla_\theta L_{\mathcal{S}}(\theta + r v)-\nabla_\theta L_{\mathcal{S}}(\theta)}{r}+\mathcal{O}(r)
Hv=r∇θLS(θ+rv)−∇θLS(θ)+O(r)如果令
v
=
∇
θ
L
S
(
θ
)
∥
∇
θ
L
S
(
θ
)
∥
v=\frac{\nabla_{\theta}L_\mathcal{S}(\theta)}{\|\nabla_\theta L_\mathcal{S}(\theta)\|}
v=∥∇θLS(θ)∥∇θLS(θ),则有
H
∇
θ
L
S
(
θ
)
∥
∇
θ
L
S
(
θ
)
∥
2
≈
∇
θ
L
(
θ
+
r
∇
θ
L
S
(
θ
)
∥
∇
θ
L
S
(
θ
)
∥
2
)
−
∇
θ
L
(
θ
)
r
H \frac{\nabla_{\theta}L_{\mathcal{S}}(\theta)}{\|\nabla_\theta L_{\mathcal{S}}(\theta)\|_2}\approx \frac{\nabla_\theta L(\theta + r\frac{\nabla_\theta L_{\mathcal{S}}(\theta)}{\|\nabla_\theta L_{\mathcal{S}}(\theta)\|_2})-\nabla_\theta L(\theta)}{r}
H∥∇θLS(θ)∥2∇θLS(θ)≈r∇θL(θ+r∥∇θLS(θ)∥2∇θLS(θ))−∇θL(θ)
综上所述,经过整理可得
∇
θ
L
(
θ
)
=
∇
θ
L
S
(
θ
)
+
λ
r
⋅
(
∇
θ
L
S
(
θ
+
r
∇
θ
L
S
(
θ
)
∥
∇
θ
L
S
(
θ
)
∥
2
)
−
∇
θ
L
S
(
θ
)
)
=
(
1
−
α
)
∇
θ
L
S
(
θ
)
+
α
∇
θ
L
S
(
θ
+
r
∇
θ
L
S
(
θ
)
∥
∇
θ
L
S
(
θ
)
∥
2
)
∇θL(θ)=∇θLS(θ)+rλ⋅(∇θLS(θ+r∥∇θLS(θ)∥2∇θLS(θ))−∇θLS(θ))=(1−α)∇θLS(θ)+α∇θLS(θ+r∥∇θLS(θ)∥2∇θLS(θ))其中
α
=
λ
r
\alpha=\frac{\lambda}{r}
α=rλ,称
α
\alpha
α为平衡系数,取值范围为
0
≤
α
≤
1
0 \le \alpha \le 1
0≤α≤1。作者为了避免在近似计算梯度时,以上公式中的第二项链式法则求梯度需要计算
H
e
s
s
i
a
n
\mathrm{Hessian}
Hessian矩阵,做了以下的近似则有
∇
θ
L
S
(
θ
+
r
∇
θ
L
S
(
θ
)
∥
∇
θ
L
S
(
θ
)
∥
2
)
≈
∇
θ
L
S
(
θ
)
∣
θ
=
θ
+
r
∇
θ
L
S
(
θ
)
∥
∇
θ
L
S
(
θ
)
∥
2
\nabla_\theta L_\mathcal{S}(\theta+r \frac{\nabla_\theta L_\mathcal{S}(\theta)}{\|\nabla_\theta L_\mathcal{S}(\theta)\|_2})\approx \nabla_\theta L_\mathcal{S} (\theta)|_{\theta =\theta +r \frac{\nabla_\theta L_\mathcal{S}(\theta)}{\|\nabla_\theta L_\mathcal{S}(\theta)\|_2}}
∇θLS(θ+r∥∇θLS(θ)∥2∇θLS(θ))≈∇θLS(θ)∣θ=θ+r∥∇θLS(θ)∥2∇θLS(θ)以下算法流程图对本论文的训练方法进行汇总
4 实验结果
下表表示的是在
C
i
f
a
r
10
\mathrm{Cifar10}
Cifar10和
C
i
f
a
r
100
\mathrm{Cifar100}
Cifar100这两个数据集中不同
C
N
N
\mathrm{CNN}
CNN网络结构在标准训练,
S
A
M
\mathrm{SAM}
SAM和本文的梯度约束这三种训练方法之间的测试错误率的比较。可以很直观的发现,本文提出的方法在绝大多数情况下测试错误率都是最低的,这也从侧面验证了经过论文方法的训练可以提高
C
N
N
\mathrm{CNN}
CNN模型的泛化性。
论文作者也在当前非常热门的网络结构
V
i
s
i
o
n
T
r
a
n
s
f
o
r
m
e
r
\mathrm{Vision \text{ } Transformer}
Vision Transformer进行了实验。下表表示的是在
C
i
f
a
r
10
\mathrm{Cifar10}
Cifar10和
C
i
f
a
r
100
\mathrm{Cifar100}
Cifar100这两个数据集中不同
V
i
T
\mathrm{ViT}
ViT网络结构在标准训练,
S
A
M
\mathrm{SAM}
SAM和本文的梯度约束这三种训练方法之间的测试错误率的比较。同理也可以发现本文提出的方法在所有情况下测试错误率都是最低的,这说明本文的方法也可以提到
V
i
s
i
o
n
t
r
a
n
s
f
o
r
m
e
r
\mathrm{Vision \text{ } transformer}
Vision transformer模型的泛化性。
Recommend
-
77
机器学习中大部分都是优化问题,大多数的优化问题都可以使用梯度下降/上升法处理,所以,搞清楚梯度算法就非常重要学习梯度,需要一定的数学知识:导数(Derivative)、偏导数(Partial derivative)和方向导数(Directional derivative)。
-
8
预测从瞎猜开始 按 上一篇文章 所说,机器学习是应用数学方法在数据中发现规律的过程。既然数学是对现实世界的解释,那么我们回归现实世界,做一些对照的想...
-
6
【Real analysis(1)】范数、测度和距离 2017年06月04日 Author: Guofei 文章归类: 5-1-代数与分析 ,文章编号: 92201 版权声明:本文作者是郭飞...
-
6
之前和大家分享了「产品体验设计-产品认知篇」一些基础认知相关内容,这次内容主要是关于如何在成熟产品形态上去提升产品体验设计的方法讲解。
-
6
死活想不起某个词语?清华大学出品的「反向词典」帮你告别「词不达意」 有一个古老的段子是这样说的:古人登上泰山,会感叹「会当凌绝顶,一...
-
4
一个简单直观Lp范数上下界推导,并获得一个重要的结论。 闵可夫斯基距离为, d(X,Y)=(n∑i=1|xi−yi|p)1pd(X,Y)=(∑i=1n|xi−yi|p)1p可以简约成LpLp范数,表示为, Lp(x1,…,xn)=(n∑i=1∣∣xi∣∣p)1pLp(x1,…,xn)=(∑i=1n|xi|p)1p有时候我们想快速...
-
4
范数正则化的原理分析(一):贝叶斯学派角度范数正则化是机器学习和深度学习中最常用的正则化手段,本文讲述从贝叶斯角度理解范数正则化,另外还提供信息论上的解释。L1、L2正则化都是解决模型过拟合的方法,它们有什么数学上的解释呢? 正...
-
3
Mr.Feng BlogNLP、深度学习、机器学习、Python、Go范数正则化的原理分析(二):参数约束与最大熵原理最大熵原理角度看参数约束 统计约束与最大熵分布
-
4
特征值和奇异值 矩阵的奇异值(续) 其中,旋转和缩放不改变向量的维数。矩阵特征值运算,实际上就是将向量V旋转缩放到一个正交基W上。因为V和W等维,所以要求矩阵必须是方阵。 正交化过程,代表旋转变换,又被称为等距同...
-
7
范数及求导 | 沉默杀手范数及求导 2023-02-27|
About Joyk
Aggregate valuable and interesting links.
Joyk means Joy of geeK