

Adversarial Training Code Implementations for Pretrained Models
source link: https://mp.weixin.qq.com/s?__biz=MjM5ODkzMzMwMQ%3D%3D&%3Bmid=2650429008&%3Bidx=3&%3Bsn=8666914b8d558795ea69ce355df81ae6
The following article is from the WeChat account 算法让生活更美好, by author BPSK.
One practical note: whichever adversarial method you use, you basically need to know the parameter name of the embedding layer in your model. BERT is by far the most common choice nowadays, so the author printed the layer names of the pytorch-transformers bert-base-chinese model:
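The printed layer names appeared as an image in the original article. A minimal sketch for reproducing the listing yourself, assuming the transformers package (the successor of pytorch-transformers), would be:

from transformers import BertModel  # the original article used pytorch-transformers

model = BertModel.from_pretrained('bert-base-chinese')
for name, _ in model.named_parameters():
    if 'embeddings' in name:
        print(name)
# Prints names such as 'embeddings.word_embeddings.weight', which is why the
# snippets below use emb_name='word_embeddings.' to locate the embedding matrix.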
FGM
FGSM is the pioneering work in this line and also the simplest adversarial idea; the FGM class below is the commonly used variant that scales the perturbation by the L2 norm of the gradient rather than its sign.
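For reference (this equation is a standard statement of the method, reconstructed rather than recovered from the article), FGM perturbs the embedding matrix E by

$$r_{\text{adv}} = \epsilon\,\frac{g}{\lVert g\rVert_2}, \qquad g = \nabla_{E}\,L(x, y; \theta),$$

whereas the original FGSM uses the sign of the gradient, $r_{\text{adv}} = \epsilon\,\operatorname{sign}(g)$.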
import torch


class FGM(object):
    """Fast Gradient Method: add a one-step adversarial perturbation to the embeddings."""
    def __init__(self, model, emb_name, epsilon=1.0):
        # emb_name must match the name of the embedding parameters in your model
        self.model = model
        self.epsilon = epsilon
        self.emb_name = emb_name
        self.backup = {}

    def attack(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad and self.emb_name in name:
                self.backup[name] = param.data.clone()  # back up the clean embeddings
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    r_at = self.epsilon * param.grad / norm  # epsilon * g / ||g||
                    param.data.add_(r_at)

    def restore(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad and self.emb_name in name:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}
fgm = FGM(model, epsilon=1, emb_name='word_embeddings.')
for batch_input, batch_label in processor:
    # normal training step
    loss = model(batch_input, batch_label)
    loss.backward()  # backprop to get the clean gradients
    # adversarial step
    fgm.attack()  # add the adversarial perturbation to the embeddings
    loss_adv = model(batch_input, batch_label)
    loss_adv.backward()  # backprop again, accumulating the adversarial gradients onto the clean ones
    fgm.restore()  # restore the original embedding parameters
    # gradient descent: update the parameters
    optimizer.step()
    model.zero_grad()
PGD
PGD is essentially multi-step FGSM: it takes several small attack steps and projects the accumulated perturbation back onto an epsilon-ball around the original embeddings.
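Written as a formula (reconstructed for reference; not part of the recovered text), each of the K attack steps performs

$$r_{t+1} = \Pi_{\lVert r\rVert_2 \le \epsilon}\!\left(r_t + \alpha\,\frac{g_t}{\lVert g_t\rVert_2}\right), \qquad g_t = \nabla_{E}\,L(x + r_t,\, y;\, \theta),$$

where $\Pi$ projects the accumulated perturbation back onto the $\epsilon$-ball around the clean embeddings, which is exactly what project() below implements.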
class PGD(object):
    """Projected Gradient Descent: multi-step attack with projection onto an epsilon-ball."""
    def __init__(self, model, emb_name, epsilon=1., alpha=0.3):
        # emb_name must match the name of the embedding parameters in your model
        self.model = model
        self.emb_name = emb_name
        self.epsilon = epsilon
        self.alpha = alpha
        self.emb_backup = {}
        self.grad_backup = {}

    def attack(self, is_first_attack=False):
        for name, param in self.model.named_parameters():
            if param.requires_grad and self.emb_name in name:
                if is_first_attack:
                    self.emb_backup[name] = param.data.clone()  # back up the clean embeddings once
                norm = torch.norm(param.grad)
                if norm != 0:
                    r_at = self.alpha * param.grad / norm  # small step of size alpha
                    param.data.add_(r_at)
                    param.data = self.project(name, param.data, self.epsilon)

    def restore(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad and self.emb_name in name:
                assert name in self.emb_backup
                param.data = self.emb_backup[name]
        self.emb_backup = {}

    def project(self, param_name, param_data, epsilon):
        # project the perturbation back onto the epsilon-ball around the clean embeddings
        r = param_data - self.emb_backup[param_name]
        if torch.norm(r) > epsilon:
            r = epsilon * r / torch.norm(r)
        return self.emb_backup[param_name] + r

    def backup_grad(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None:
                self.grad_backup[name] = param.grad.clone()

    def restore_grad(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None:
                param.grad = self.grad_backup[name]
pgd = PGD(model, emb_name='word_embeddings.', epsilon=1.0, alpha=0.3)
K = 3
for batch_input, batch_label in processor:
    # normal training step
    loss = model(batch_input, batch_label)
    loss.backward()  # backprop to get the clean gradients
    pgd.backup_grad()
    # adversarial steps
    for t in range(K):
        pgd.attack(is_first_attack=(t == 0))  # perturb the embeddings; back up param.data on the first attack
        if t != K - 1:
            model.zero_grad()  # intermediate gradients are only used to craft the perturbation
        else:
            pgd.restore_grad()  # restore the clean gradients before the final backward pass
        loss_adv = model(batch_input, batch_label)
        loss_adv.backward()  # backprop, accumulating the adversarial gradients onto the clean ones
    pgd.restore()  # restore the original embedding parameters
    # gradient descent: update the parameters
    optimizer.step()
    model.zero_grad()
FreeLB
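Before the code, a sketch of the idea based on the FreeLB paper (this summary is not from the original article): unlike PGD, which discards the intermediate gradients, FreeLB keeps the backward pass of every one of the K ascent steps, so the parameters are effectively trained on (approximately) the average adversarial loss over the steps:

$$\min_{\theta}\ \mathbb{E}_{(x,y)}\left[\frac{1}{K}\sum_{t=0}^{K-1}\ \max_{\lVert \delta_t\rVert \le \epsilon} L\big(f_\theta(x+\delta_t),\,y\big)\right]$$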
class FreeLB(object):
    def __init__(self, adv_K, adv_lr, adv_init_mag, adv_max_norm=0., adv_norm_type='l2', base_model='bert'):
        self.adv_K = adv_K
        self.adv_lr = adv_lr
        self.adv_max_norm = adv_max_norm
        self.adv_init_mag = adv_init_mag
        self.adv_norm_type = adv_norm_type
        self.base_model = base_model

    def attack(self, model, inputs, gradient_accumulation_steps=1):
        input_ids = inputs['input_ids']
        if isinstance(model, torch.nn.DataParallel):
            embeds_init = getattr(model.module, self.base_model).embeddings.word_embeddings(input_ids)
        else:
            embeds_init = getattr(model, self.base_model).embeddings.word_embeddings(input_ids)
        if self.adv_init_mag > 0:
            input_mask = inputs['attention_mask'].to(embeds_init)
            input_lengths = torch.sum(input_mask, 1)
            if self.adv_norm_type == "l2":
                delta = torch.zeros_like(embeds_init).uniform_(-1, 1) * input_mask.unsqueeze(2)
                dims = input_lengths * embeds_init.size(-1)
                mag = self.adv_init_mag / torch.sqrt(dims)
                delta = (delta * mag.view(-1, 1, 1)).detach()
            elif self.adv_norm_type == "linf":
                delta = torch.zeros_like(embeds_init).uniform_(-self.adv_init_mag, self.adv_init_mag)
                delta = delta * input_mask.unsqueeze(2)
        else:
            delta = torch.zeros_like(embeds_init)
        for astep in range(self.adv_K):
            delta.requires_grad_()
            inputs['inputs_embeds'] = delta + embeds_init
            inputs['input_ids'] = None
            outputs = model(**inputs)
            loss, logits = outputs[:2]  # model outputs are always tuple in transformers (see doc)
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
            loss = loss / gradient_accumulation_steps
            loss.backward()
            delta_grad = delta.grad.clone().detach()
            if self.adv_norm_type == "l2":
                denorm = torch.norm(delta_grad.view(delta_grad.size(0), -1), dim=1).view(-1, 1, 1)
                denorm = torch.clamp(denorm, min=1e-8)
                delta = (delta + self.adv_lr * delta_grad / denorm).detach()
                if self.adv_max_norm > 0:
                    delta_norm = torch.norm(delta.view(delta.size(0), -1).float(), p=2, dim=1).detach()
                    exceed_mask = (delta_norm > self.adv_max_norm).to(embeds_init)
                    reweights = (self.adv_max_norm / delta_norm * exceed_mask + (1 - exceed_mask)).view(-1, 1, 1)
                    delta = (delta * reweights).detach()
            elif self.adv_norm_type == "linf":
                denorm = torch.norm(delta_grad.view(delta_grad.size(0), -1), dim=1, p=float("inf")).view(-1, 1, 1)
                denorm = torch.clamp(denorm, min=1e-8)
                delta = (delta + self.adv_lr * delta_grad / denorm).detach()
                if self.adv_max_norm > 0:
                    delta = torch.clamp(delta, -self.adv_max_norm, self.adv_max_norm).detach()
            else:
                raise ValueError("Norm type {} not specified.".format(self.adv_norm_type))
            if isinstance(model, torch.nn.DataParallel):
                embeds_init = getattr(model.module, self.base_model).embeddings.word_embeddings(input_ids)
            else:
                embeds_init = getattr(model, self.base_model).embeddings.word_embeddings(input_ids)
        return loss
# Note: FreeLB() has no defaults for adv_K, adv_lr and adv_init_mag, so they must be passed in;
# the values below are illustrative, not the article's.
freelb = FreeLB(adv_K=3, adv_lr=1e-2, adv_init_mag=2e-2)
for batch_input, batch_label in processor:
    # inputs is a dict of tensors (see the note and sketch below); attack() already calls
    # loss.backward() at every ascent step, so only the optimizer step remains for the loop
    loss = freelb.attack(model, inputs, .....)
(1) The training loop here may still look a bit puzzling, and you may not know exactly how to write it; in that case look directly at:
Image source:
https://github.com/lonePatient/TorchBlocks/blob/e6c5959e6a3d3380bbb147f1c30f752cd8482c1a/examples/task_text_classification_freelb_cola.py#L43
Line 55 of that file is simply a dict; read the code there and the details become clear (a rough sketch of such a dict follows below).
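For convenience, here is a hedged sketch of what such an inputs dict usually looks like for a BERT classifier. The keys input_ids, attention_mask and inputs_embeds are the ones attack() actually touches; token_type_ids and labels depend on your model's forward signature, and the batch variable is hypothetical:

# Illustrative only: assemble the dict handed to freelb.attack()
inputs = {
    'input_ids': batch['input_ids'],            # (batch_size, seq_len) token ids
    'attention_mask': batch['attention_mask'],  # 1 for real tokens, 0 for padding
    'token_type_ids': batch['token_type_ids'],  # segment ids; optional for some models
    'labels': batch['labels'],                  # lets the model return (loss, logits)
}
loss = freelb.attack(model, inputs)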
(2) The dropout = 0 question
(Image source: see the references.)
As for whether dropout should be set to 0 or not, the author's advice is to try both and keep whichever works better.
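If you want to run such an experiment, one simple way, offered as an assumption about how to toggle dropout rather than as the author's code, is to flip the p attribute of every nn.Dropout module:

import torch.nn as nn

def set_dropout(model, p=0.0):
    # Set the dropout probability of every nn.Dropout module in the model.
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.p = p

set_dropout(model, 0.0)  # one run with dropout disabled, then compare against the default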
Virtual Adversarial Training
This is a semi-supervised training method based on adversarial learning. If you have only a small amount of labeled data but plenty of unlabeled data, it is worth trying to see whether it improves your results; see the references for the underlying theory. The author ran some tests on their own dataset (the experimental results were shown as a table in the original article).
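For context, a sketch of the standard VAT loss (a reconstruction, not text from the original article); note that it needs no labels, which is what makes the unlabeled data usable:

$$\mathcal{L}_{\text{vat}}(x) = D_{\mathrm{KL}}\!\big(p(\cdot\mid x;\hat\theta)\,\big\|\,p(\cdot\mid x+r_{\text{vadv}};\theta)\big), \qquad r_{\text{vadv}} = \arg\max_{\lVert r\rVert_2\le\epsilon} D_{\mathrm{KL}}\!\big(p(\cdot\mid x;\hat\theta)\,\big\|\,p(\cdot\mid x+r;\hat\theta)\big),$$

with $r_{\text{vadv}}$ usually approximated by one or two power-iteration steps.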
References
FGSM, PGD, FreeLB code:
https://github.com/lonePatient/TorchBlocks/blob/e6c5959e6a3d3380bbb147f1c30f752cd8482c1a/torchblocks/callback/adversarial.py
Adversarial training in NLP + PyTorch implementation:
https://zhuanlan.zhihu.com/p/91269728?utm_source=wechat_session
Understanding adversarial training, with detailed introductions to FGM, PGD and FreeLB:
https://blog.csdn.net/weixin_41712499/article/details/110878322
An explanation of Virtual Adversarial Training:
https://blog.csdn.net/qq_33221657/article/details/105170202
Reference code for the Virtual Adversarial Training loss:
https://blog.csdn.net/guotong1988/article/details/90376004
Reference code for a complete Virtual Adversarial Training implementation:
https://github.com/DevSinghSachan/ssl_text_classification/blob/1b92c8df59230f259a7b8a6d50b830d17e081362/training.py#L295