【深度学习】NLP中的对抗训练

news2025/1/2 4:00:48

        在NLP中,对抗训练往往都是针对嵌入层(包括词嵌入,位置嵌入,segment嵌入等等)开展的,思想很简单,即针对嵌入层添加干扰,从而提高模型的鲁棒性和泛化能力,下面结合具体代码讲解一些NLP中常见对抗训练算法。

1.Fast Gradient Method(FGM)

        FGM的思想是针对词嵌入加入梯度方向的干扰,至于干扰的大小是我们可以调节的,增加干扰后的样本可以作为额外的对抗样本进行训练,以此提高模型的效果。由于我们在训练时会针对每个样本都进行一次额外的增加干扰后的训练,所以使用FGM后训练时间理论上也会大概增加一倍。

        FGM在原训练代码的基础上,主要增加了以下几个额外的操作:针对嵌入层添加干扰并备份参数,计算添加干扰后的损失,梯度回传从而累积添加干扰后的梯度,恢复原来的嵌入层参数。

1.1 算法流程

对于每个x:
  1.计算x的前向loss、反向传播得到梯度
  2.根据embedding矩阵的梯度计算出r,并加到当前embedding上,相当于x+r
  3.计算x+r的前向loss,反向传播得到对抗的梯度,累加到(1)的梯度上
  4.将embedding恢复为(1)时的值
  5.根据(3)的梯度对参数进行更新

  1.2 具体代码

import torch
class FGM():
    def __init__(self, model):
        self.model = model
        self.backup = {}

    def attack(self, epsilon=1., emb_name='word_embeddings'):
        # emb_name这个参数要换成你模型中embedding的参数名
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                #print('增加扰动的对象是', name)
                #print(type(param.grad))
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    r_at = epsilon * param.grad / norm
                    param.data.add_(r_at)

    def restore(self, emb_name='word_embeddings'):
        # emb_name这个参数要换成你模型中embedding的参数名
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name: 
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}

1.3 具体用法

fgm = FGM(model) # (#1)初始化
for batch_input, batch_label in data:
    loss = model(batch_input, batch_label) # 正常训练
    loss.backward() # 反向传播,得到正常的grad
    # 对抗训练
    fgm.attack() # (#2)在embedding上添加对抗扰动
    loss_adv = model(batch_input, batch_label) # (#3)计算含有扰动的对抗样本的loss
    loss_adv.backward() # (#4)反向传播,并在正常的grad基础上,累加对抗训练的梯度
    fgm.restore() # (#5)恢复embedding参数
    # 梯度下降,更新参数
    optimizer.step()
    model.zero_grad()

2.Projected Gradient Descent (PGD

        Project Gradient Descent(PGD)是一种迭代攻击算法,相比于普通的FGM 仅做一次迭代,PGD是做多次迭代,每次走一小步,每次迭代都会将扰动投射到规定范围内。其中r为扰动约束空间(一个半径为r的球体),原始的输入样本对应的初识点为球心,避免扰动超过球面。迭代多次后,保证扰动在一定范围内,如下图所示:

 2.1 算法流程

对于每个x:
  1.计算x的前向loss、反向传播得到梯度并备份
  对于每步t:
    2.根据embedding矩阵的梯度计算出r,并加到当前embedding上,相当于x+r(超出范围则投影回epsilon内)
    3.t不是最后一步: 将梯度归0,根据1的x+r计算前后向并得到梯度
    4.t是最后一步: 恢复(1)的梯度,计算最后的x+r并将梯度累加到(1)上
  5.将embedding恢复为(1)时的值
  6.根据(4)的梯度对参数进行更新

 2.2  具体代码

import torch
class PGD():
    def __init__(self, model):
        self.model = model
        self.emb_backup = {}
        self.grad_backup = {}
 
    def attack(self, epsilon=1., alpha=0.3, emb_name='word_embeddings', is_first_attack=False):
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                if is_first_attack:
                    self.emb_backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    r_at = alpha * param.grad / norm
                    param.data.add_(r_at)
                    param.data = self.project(name, param.data, epsilon)
 
    def restore(self, emb_name='word_embeddings'):
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name: 
                assert name in self.emb_backup
                param.data = self.emb_backup[name]
        self.emb_backup = {}
 
    def project(self, param_name, param_data, epsilon):
        r = param_data - self.emb_backup[param_name]
        if torch.norm(r) > epsilon:
            r = epsilon * r / torch.norm(r)
        return self.emb_backup[param_name] + r
 
    def backup_grad(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.grad_backup[name] = param.grad.clone()
 
    def restore_grad(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.grad = self.grad_backup[name]

2.3 具体用法

pgd = PGD(model)
K = 3
for batch_input, batch_label in data:
    # 正常训练
    loss = model(batch_input, batch_label)
    loss.backward() # 反向传播,得到正常的grad
    pgd.backup_grad()
    # 累积多次对抗训练——每次生成对抗样本后,进行一次对抗训练,并不断累积梯度
    for t in range(K):
        pgd.attack(is_first_attack=(t==0)) # 在embedding上添加对抗扰动, first attack时备份param.data
        if t != K-1:
            model.zero_grad()
        else:
            pgd.restore_grad()
        loss_adv = model(batch_input, batch_label)
        loss_adv.backward() # 反向传播,并在正常的grad基础上,累加对抗训练的梯度
    pgd.restore() # 恢复embedding参数
    # 梯度下降,更新参数
    optimizer.step()
    model.zero_grad()

Reference:

1.NLP中的对抗训练_colourmind的博客-CSDN博客

2.【NLP】NLP中的对抗训练_风度78的博客-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/878599.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

剑指offer(C++)-JZ56:数组中只出现一次的两个数字(算法-位运算)

作者:翟天保Steven 版权声明:著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处 题目描述: 一个整型数组里除了两个数字只出现一次,其他的数字都出现了两次。请写程序找出这两个只出现一…

DoIP学习笔记系列:(五)“安全认证”的.dll从何而来?

文章目录 1. “安全认证”的.dll从何而来?1.1 .dll文件base1.2 增加客户需求算法传送门 DoIP学习笔记系列:导航篇 1. “安全认证”的.dll从何而来? 无论是用CANoe还是VFlash,亦或是编辑cdd文件,都需要加载一个与$27服务相关的.dll(Windows的动态库文件),这个文件是从哪…

ES踩坑记录之UNASSIGNED分片无法恢复

问题背景 换节点 我们线上有一套ES集群,三台机器,共运行了6个节点。一直在线上跑了几个月也一直没出什么问题。然而好巧不巧,就在昨天,集群中的3号节点磁盘出现故障,导致机器直接瘫痪。本来大家觉得问题不大&#xf…

Lua学习记录

Lua基础了解 Lua的注释通过 (-- 单行注释,--[[ ]] 多行注释)可以不加; 多个变量赋值,按顺序赋值,没有则为nil; function的简单用法,多个返回值配合多重赋值,以end为结束标志 Lua下标从1开始&…

R语言生存分析(机器学习)(1)——GBM(梯度提升机)

GBM是一种集成学习算法,它结合了多个弱学习器(通常是决策树)来构建一个强大的预测模型。GBM使用“Boosting”的技术来训练弱学习器,这种技术是一个迭代的过程,每一轮都会关注之前轮次中预测效果较差的样本,…

二叉树题目:二叉树的直径

文章目录 题目标题和出处难度题目描述要求示例数据范围 解法思路和算法代码复杂度分析 题目 标题和出处 标题:二叉树的直径 出处:543. 二叉树的直径 难度 3 级 题目描述 要求 给定二叉树的根结点 root \texttt{root} root,返回其直径…

Docker 基本管理(一)

目录 一、虚拟化简介 1.1.虚拟化概述 1.2.cpu的时间分片(cpu虚拟化) 1.3.cpu虚拟化性性能瓶颈 1.4.虚拟化工作原理 1.5 虚拟化类型 1.6 虚拟化功能 ​二、Docker容器概述 2.1 docker是什么? 2.2 使用docker有什么意义&#xff…

nginx上web服务的基本安全优化、服务性能优化、访问日志优化、目录资源优化和防盗链配置简介

一.基本安全优化 1.隐藏nginx软件版本信息 2.更改源码来隐藏软件名和版本 (1)修改第一个文件(核心头文件),在nginx安装目录下找到这个文件并修改 (2)第二个文件 (3)…

算法通过村第三关-数组青铜笔记|单调数组

文章目录 前言单调数组问题搜索插入位置:数组合并问题:总结 前言 提示:本份真诚面对自己、坦然无碍面对他人,就是优雅。 数组中的比较经典性问题: 单调数组问题数组合并问题 单调数组问题 参考例子:896. 单调数列…

【BEV Review】论文 Delving into the Devils of Bird’s-eye-view 2022-9 笔记

背景 一般来说,自动驾驶车辆的视觉传感器(比如摄像头)安装在车身上方或者车内后视镜上。无论哪个位置,摄像头所得到的都是真实世界在透视视图(Perspective View)下的投影(世界坐标系到图像坐标系…

Docker数据卷容器

1.数据卷容器介绍 即使数据卷容器c3挂掉也不会影响c1和c2通信。 2.配置数据卷容器 创建启动c3数据卷容器,使用-v参数设置数据卷。volume为目录,这种方式数据卷目录就不用写了,直接写宿主机目录。 创建c1、c2容器,使用–volum…

MapStruct 中 Java Bean 映射代码生成器的基本使用

文章目录 一、简介:二、背景:三、相关概念:1、映射器(Mapper):2、映射方法(Mapping Method):3、常规映射方法(Regular Mapping Method)&#xff1…

多功能杆在智慧农业中的应用

随着农业现代化发展,农业生产和管理不断运用越来越多新技术、新设施,以提高农业生产的综合效率、产品质量,降低管理经营成本。诸如数字化监测、物联网管理、5G远程控制,以及本次我们为大家介绍的多功能智慧杆系统。 多功能智慧杆拥…

股权激励一发布,股价飙升买别墅?

主要内容: 1.股权激励计划的含义 2.股权激励的公告数据 3.公告日到授予日股价变化 4.构建股权激励策略 5.策略运行结果 当谈到现代科技领域的先锋人物,马斯克无疑是其中的佼佼者,他人生经历可谓尽是高光时刻。 1981年10岁的马斯克用攒到…

每日温度(力扣)单调栈 JAVA

给定一个整数数组 temperatures ,表示每天的温度,返回一个数组 answer ,其中 answer[i] 是指对于第 i 天,下一个更高温度出现在几天后。如果气温在这之后都不会升高,请在该位置用 0 来代替。 示例 1: 输入: temperatur…

使用VMware安装ubuntu和VMware tool

一、准备工作 提前准备好vmware的安装包还有Ubuntu的系统镜像 安装包已经放到网盘,链接在这篇文章中:https://blog.csdn.net/u014151564/article/details/132267441 二、使用步骤 1、打开虚拟机来到主页 在左侧右键选择新建虚拟机 2、向导步骤如图…

变压器故障诊断(python代码,逻辑回归/SVM/KNN三种方法同时使用,有详细中文注释)

代码运行要求:tensorflow版本>2.4.0,Python>3.6.0即可,无需修改数据路径。 1.数据集介绍: 采集数据的设备照片 变压器在电力系统中扮演着非常重要的角色。尽管它们是电网中最可靠的部件,但由于内部或外部的许多因素&#…

预告|8月16日-18日,相约DTCC 2023!星瑞格邀您共飨数据库技术盛宴

相约DTCC 2023,共飨数据库技术盛宴! 2023年8月16-18日,第十四届中国数据库技术大会(DTCC 2023)将于北京国际会议中心隆重召开。福建星瑞格软件有限公司(以下简称星瑞格)受邀参加本届DTCC中国数…

污水处理厂人员定位方案介绍

污水处理厂人员定位在现代化的污水处理厂中具有重要的意义,它可以带来多方面的优势和好处: 安全管理: 污水处理厂通常涉及到各种危险环境和设备,如化学品、高压设备等。人员定位系统可以追踪人员的位置,确保他们不会进…

基于C#UI Automation自动化测试

步骤 UI Automation 只适用于,标准的win32和 WPF程序 需要添加对UIAutomationClient、 UIAutomationProvider、 UIAutomationTypes的引用 代码 using System; using System.Collections.Generic; using System.ComponentModel; using System.Data; using System.D…