以FGSM算法为例的对抗训练的实现(基于Pytorch)

news2024/10/1 9:38:55

1. 前言

深度学习虽然发展迅速,但是由于其线性的特性,受到了对抗样本的影响,很容易造成系统功能的失效。
以图像分类为例子,对抗样本很容易使得在测试集上精度很高的模型在对抗样本上的识别精度很低。
对抗样本指的是在合法数据上添加了特定的小的扰动,人眼不能明显分辨但是会影响深度学习模型的输出的样本。
对抗样本攻击示例
常见的防御方法有对抗训练Adversarial Training。最近我在尝试复现对抗训练,找了一下发现有一些基于tensorflow的对抗训练的代码,但是没怎么看见pytorch的代码,所以我在这里做一个记录。

2. 参考文献:

1书籍《AI安全之对抗样本入门》
2 论文 Explaining and Harnessing Adversarial Examples 论文链接

3. 以FGSM算法为例的对抗训练的实现(基于Pytorch)

3.1 FGSM算法概述

FGSM算法由上面参考文献中的论文首次提出。
全名叫Fast Gradient Sign Method
前言介绍中说了,对抗样本是在合法数据上添加一个特定的扰动,FGSM算法就是用来计算这个扰动的,称其为 η \eta η
FGSM算法属于一个白盒攻击算法,即知道被攻击模型的结构,模型参数并且可以控制输入输出。
数学式子如下:
η = ϵ s i g n ( ∇ x J ( θ , x , y ) ) (1) \eta =\epsilon sign(\nabla_xJ(\theta,x,y)) \tag{1} η=ϵsign(xJ(θ,x,y))(1)

ϵ \epsilon ϵ是攻击强度; θ \theta θ是模型的参数;x是输入模型的样本;y是x对应的标签;sign是符号函数,大于0取1,小于0取-1,等于0时就是0。

公式(1)表示,输入x后得到pred,然后pred与y计算loss,最后loss反向传播一次,获取x的梯度并且取符号,然后乘以扰动强度 ϵ \epsilon ϵ就得到了我们需要的特定的扰动。
这么做的原因是因为在把这个扰动 η \eta η加到原合法数据x上时,相当于使得x朝着loss增大的方向去变化了,loss增大的方向就是预测结果与标签不同的方向。
一般梯度下降优化模型参数时,都是减去梯度方向的某一步长的值。
代码实现:cleverhans攻防库

3.2 对抗训练的思想

论文 Explaining and Harnessing Adversarial Examples采用在训练过程中加入对抗样本的方法。这使得模型或多或少都会学习到一些对抗样本的知识,可以在一定程度上帮助模型抵御对抗样本的攻击。
但是对抗训练的缺点是针对特定攻击算法进行对抗训练的模型只对该攻击算法有防御效果,不同的攻击算法或者相同算法的不同扰动强度都会使得对抗训练的防御效果大打折扣。
对抗训练的训练过程中的目标函数:
J ~ ( θ , x , y ) = α J ( θ , x , y ) + ( 1 − α ) J ( θ , x ~ , y ) (2) \widetilde J(\theta,x,y) = \alpha J(\theta,x,y)+(1-\alpha)J(\theta,\widetilde x,y) \tag{2} J (θ,x,y)=αJ(θ,x,y)+(1α)J(θ,x ,y)(2)

x ~ \widetilde x x 是对抗样本,x是合法样本。
α \alpha α在论文中取0.5

3.3 tensorflow实现的对抗训练的代码

tensorflow的对抗训练的实现

3.4 我用pytorch实现的对抗训练代码

3.4.1 总的运行流程


if __name__ == "__main__":
	# 使用哪个设备训练---
    device = "cuda"
    # device = "cpu"
    # 攻击算法的相关参数---
    eps = 0.01  # FGSM的攻击强度
    attack_type = "FGSM"
    # 输出日志
    print("attack_type = {}".format(attack_type))
    print("eps = {}".format(eps))

    # 加载模型---
    model_name = "Resnet18"
    cls = loadModel()
    # 对抗训练用的参数
    batchsize = 16
    lr = 1e-4
    total_epochs = 500
    # 开始对抗训练
    clsWithAdv = ClsWithAdvData(cls, attack_type)
    # 对抗训练过程
    clsWithAdv.mainProcess()
    # 在测试集对应的对抗样本测试集上测试精度
    clsWithAdv.test_model()

3.4.2 加载模型的函数

def loadModel():
    """
    加载模型, 对其训练并测试精度
    :return: 返回的是已经加载好了的模型
    """
    if model_name == "Resnet18":
    	# 这里可以把模型的初始化自定义
    	# FineTuneClassifier是我自己写的初始化resnet18的类
        classifier = FineTuneClassifier(model_name="Resnet18")
    else:
    	# 模型名字不正确则抛出运行时异常
        raise RuntimeError("The model_name:{} is invalid!".format(model_name))
    # 返回初始化完毕的模型
    return classifier

3.4.3 主要的训练流程

from cleverhans.torch.attacks.projected_gradient_descent import projected_gradient_descent
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
from torch import optim
import time
import torch.nn as nn
from cleverhans.torch.attacks.fast_gradient_method import fast_gradient_method
from cleverhans.torch.attacks.carlini_wagner_l2 import carlini_wagner_l2
import numpy as np
import os
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torch.utils.data import Dataset
# 下面这两个是我自己写的
# mypath用于获取工程文件夹的目录,方便保存和加载文件时使用绝对路径
import mypath
# accuracy是计算预测结果和标签的相同个数的函数
from myutils.classifier_accuracy import accuracy

class ClsWithAdvData:
    def __init__(self, model, attack_type):
        super(ClsWithAdvData, self).__init__()
        # 输出日志
        print("trainClsByAdvData...")
        print("batchsize = {}, lr = {}, total_epochs = {}".format(batchsize, lr, total_epochs))
        # 获取数据集
        self.train_loader = DataLoader(MyOwnDataset(type="Train"), batch_size=batchsize, shuffle=True)
        self.val_loader = DataLoader(MyOwnDataset(type="Val"), batch_size=batchsize, shuffle=True)
        self.test_loader = DataLoader(MyOwnDataset(type="Test"), batch_size=batchsize, shuffle=True)
        # 训练开始的时间,用于保存模型
        self.train_begin_time = time.strftime("%Y-%m-%d_%H-%M", time.localtime())
        # 模型保存的路径
        self.ckp_path = os.path.join(mypath.root_path, "AdvTrain", "ckp", model_name + "_" + self.train_begin_time)
        # 损失函数
        self.loss_fun = nn.CrossEntropyLoss()
        # 模型送入设备
        self.classifier = model
        self.classifier.to(device)
        # 优化器
        self.optimer = optim.AdamW(self.classifier.parameters(), lr=lr, weight_decay=0.05)
        # 学习率衰减
        self.scheduler = optim.lr_scheduler.StepLR(self.optimer, step_size=10, gamma=0.99)
        # 攻击方法
        self.attack_type = attack_type
        # 训练过程中的日志数据
        self.train_loss_list = []
        self.train_acc_list = []
        self.val_acc_list = []
        
    def mainProcess(self):
        """
        对抗训练的训练流程
        """
        # ------开始训练------
        best_val_acc = 0  # 最高的验证准确率
        for e in range(total_epochs):
            # 定期清空模型,减少不必要的显存消耗
            torch.cuda.empty_cache()
            # 训练一次模型
            # 返回的是本次训练的累计loss和精度信息
            currentEpoch_train_loss, currentEpoch_train_acc = self.train(self.train_loader)
            # 计算验证集上的精度
            val_acc = self.eval_data(self.val_loader)
            # 打印一次日志
            print("[Epoch %d/%d] [Train_loss:   %f]\n [Train_acc:   %f] [Val_acc:   %f]\n" % (
                e, total_epochs,
                currentEpoch_train_loss,
                currentEpoch_train_acc,
                val_acc))
            # 将loss,acc保存到列表list中方便绘图可视化
            self.train_loss_list.append(currentEpoch_train_loss)
            self.train_acc_list.append(currentEpoch_train_acc)
            self.val_acc_list.append(val_acc)
            self.saveLossData()  # 保存loss数据
            # 本次epoch的验证精度最高则保存模型
            if val_acc > best_val_acc:
                torch.save(self.classifier.state_dict(), self.ckp_path)
                print("当前模型已经保存")
                # 更新最好验证集的记录
                best_val_acc = val_acc
        # for total_epochs end

    def train(self, train_loader):
        """
        一个epoch的训练过程

        :param train_loader: 训练数据集
        :return:
            float(epoch_loss)是本次epoch训练过程中的累计loss.
            float(train_acc) 是本次epoch训练过程中的预测准确率.
        """
        correct_num = 0  # 某次epoch分类正确的图片个数
        epoch_num = 0  # 某次epoch已经参与训练的图片个数
        epoch_loss = 0  # 某次epoch所有batch的loss的和
        # 模型置训练模式
        self.classifier.train()
        # 进度条
        bar = tqdm(train_loader, ncols=100)
        # 开始训练一个epoch
        for index, (img, label) in enumerate(bar):
            # 数据送入设备
            img, label = img.to(device), label.to(device)
            # 使用攻击算法获取对抗样本
            adv_img = self.attack(img)
            # 分类器给出预测
            pred = self.classifier(img)  # legitimate样本的预测结果
            adv_pred = self.classifier(adv_img)   # 对抗样本的预测结果
            # 计算loss
            legitimate_loss = self.loss_fun(pred, label.long())
            adv_loss = self.loss_fun(adv_pred, label.long())
            cls_loss = 0.5*legitimate_loss + 0.5*adv_loss
            # 记录本次epoch的累计loss
            epoch_loss += cls_loss
            # 更新网络
            self.optimer.zero_grad()
            cls_loss.backward()
            self.optimer.step()
            # 累计分类正确的个数
            correct_num += accuracy(pred, label)
            # 累计已经预测过的总数
            epoch_num += label.size(0)
        # epoch end
        # 计算本次epoch训练的精度
        train_acc = np.true_divide(correct_num, epoch_num)
        # 衰减学习率
        if self.optimer.state_dict()['param_groups'][0]['lr'] > lr / 1e2:
            self.scheduler.step()
        # 删除一些变量,尽可能降低内存/显存使用量
        del correct_num, epoch_num
        del img, label, pred, cls_loss
        return float(epoch_loss), float(train_acc)
    def eval_data(self, data_loader, adv=False, preparedAdvData=False):
        """
        在验证集或者测试集上评估精度

        :param data_loader: 使用的数据集
        :param adv: 是否需要将样本转换为对应的对抗样本
        :param preparedAdvData:
        是否提前准备好了对抗样本数据。
        提前准备好的advTest数据集的是一个list:
        "normal":normal_data, "adv":adv_data, "label":label_data
        :return: 评估的精度
        """
        correct_num = 0  # 本次测试中分类正确的图片个数
        epoch_num = 0  # 本次测试中已经参与的图片个数
        # 模型置训练模式
        self.classifier.eval()
        # 进度条
        bar = tqdm(data_loader, ncols=100)
        for index, data in enumerate(bar):
        	# 提前准备好了对抗样本数据的话,img使用data的第二个
            if preparedAdvData:
                img = data[1].to(device)
            else:
                img = data[0].to(device)
            # 是否需要将当前样本转换为对应的对抗样本
            if adv:
                img = self.attack(img)
            label = data[-1].to(device)
            # 分类器给出预测
            pred = self.classifier(img)
            # 累计分类正确的个数
            correct_num += accuracy(pred, label)
            # 累计已经预测过的总数
            epoch_num += label.size(0)
        # for end
        # 计算本次测试中的精度
        train_acc = np.true_divide(correct_num, epoch_num)
        # 删除变量,减少一个指向。某一块内存只有没有被任何变量指向才会被释放。
        del img, label, pred, correct_num, epoch_num
        return float(train_acc)

    def test_model(self):
        # self.loadPreTrainedModel()
        # 训练完成所有epoch, 计算测试集上的精度
        test_acc = self.eval_data(self.test_loader)
        print("在攻击前的测试数据集上测试精度...")
        print("[Test_acc:   %f]\n" % test_acc)

        # 在攻击后的测试数据集上测试精度
        # 加载攻击后的数据集
        adv_test_loader = DataLoader(MyOwnDataset(model_name,attack_type,eps), batch_size=batchsize, shuffle=True)
        test_acc = self.eval_data(adv_test_loader, adv=True)
        print("在攻击后的测试数据集上测试精度...")
        print("[Test_acc:   %f]\n" % test_acc)

    def loadPreTrainedModel(self):
        self.classifier.load_state_dict(torch.load(os.path.join(mypath.root_path, "AdvTrain", "ckp",
                                                                "my_own_ckp_name")))

    # 保存训练过程中的loss数据
    def saveLossData(self):
    	# 将日志列表保存为文件
        np.save(os.path.join(mypath.root_path, "AdvTrain", "loss", model_name + "_train_loss_list.npy"),
                torch.tensor(self.train_loss_list, device="cpu").numpy())
        np.save(os.path.join(mypath.root_path, "AdvTrain", "loss", model_name + "_train_acc_list.npy"),
                torch.tensor(self.train_acc_list, device="cpu").numpy())
        np.save(os.path.join(mypath.root_path, "AdvTrain", "loss", model_name + "_val_acc_list.npy"),
                torch.tensor(self.val_acc_list, device="cpu").numpy())
                
		# 绘制loss变化的可视化图像
        plt.figure()
        x = np.linspace(start=0, stop=len(self.train_loss_list), num=len(self.train_loss_list), dtype=np.uint32)
        plt.plot(x, torch.tensor(self.train_loss_list, device="cpu").numpy(), "red", label="train_loss_list")
        # plt.plot(x, torch.tensor(self.val_loss_list, device="cpu").numpy(), "blue", label="val_loss_list")
        plt.legend(loc='best')
        plt.savefig(os.path.join(mypath.root_path, "AdvTrain", "loss", model_name + "_loss.png"))
        plt.close()
        
		# 绘制acc变化的可视化图像
        plt.figure()
        x = np.linspace(start=0, stop=len(self.train_loss_list), num=len(self.train_loss_list), dtype=np.uint32)
        plt.plot(x, torch.tensor(self.train_acc_list, device="cpu").numpy(), "yellow", label="train_acc_list")
        plt.plot(x, torch.tensor(self.val_acc_list, device="cpu").numpy(), "green", label="val_acc_list")
        plt.legend(loc='best')
        plt.savefig(os.path.join(mypath.root_path, "AdvTrain", "loss", model_name + "_acc.png"))
        plt.close()

    def attack(self, img):
        """
        使用攻击算法得到对抗样本

        :param img: 输入的合法legitimate数据
        :return: 输入数据对应的对抗样本
        """
        if attack_type == "FGSM":
            img = fast_gradient_method(self.classifier, img, eps, np.inf)
        elif attack_type == "PGD":
            img = projected_gradient_descent(self.classifier, img,
                                             eps=eps, eps_iter=1 / 255,
                                             nb_iter=min(255 * eps + 4, 1.25 * (eps * 255)), norm=np.inf)
        elif attack_type == "CW":
            img = carlini_wagner_l2(self.classifier, img, n_classes=500,
                                      lr=lr,
                                      initial_const=initial_const,
                                      binary_search_steps=binary_search_steps,
                                      max_iterations=max_iterations)
        else:
            raise RuntimeError("Attack type is invalid!")
        return img


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/374707.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

聚类算法(下):10个聚类算法的评价指标

上篇文章我们已经介绍了一些常见的聚类算法,下面我们将要介绍评估聚类算法的指标 1、Rand Index Rand Index(兰德指数)是一种衡量聚类算法性能的指标。它衡量的是聚类算法将数据点分配到聚类中的准确程度。兰德指数的范围从0到1,1的值表示两…

Python-GEE遥感云大数据分析、管理与可视化技术及多领域案例实践应用

随着航空、航天、近地空间等多个遥感平台的不断发展,近年来遥感技术突飞猛进。由此,遥感数据的空间、时间、光谱分辨率不断提高,数据量也大幅增长,使其越来越具有大数据特征。对于相关研究而言,遥感大数据的出现为其提…

【阿旭机器学习实战】【37】电影推荐系统---基于矩阵分解

【阿旭机器学习实战】系列文章主要介绍机器学习的各种算法模型及其实战案例,欢迎点赞,关注共同学习交流。 电影推荐系统 目录电影推荐系统1. 问题介绍1.1推荐系统矩阵分解方法介绍1.2 数据集:ml-100k2. 推荐系统实现2.1 定义矩阵分解函数2.2 …

什么牌子的蓝牙耳机便宜好用?四款高品质蓝牙耳机推荐

随着时代的发展,蓝牙耳机的使用频率越来越高,不少人外出时除了带手机外,蓝牙耳机也成为了外出必备的数码产品之一。现在的蓝牙耳机品牌众多,什么牌子的蓝牙耳机便宜好用?下面,我来给大家推荐四款高品质的蓝…

ZigBee组网原理详解

关键词:RFD FFD ZigBee 1. 组网概述 组建一个完整的zigbee网状网络包括两个步骤:网络初始化、节点加入网络。其中节点加入网络又包括两个步骤:通过与协调器连接入网和通过已有父节点入网。 ZigBee网络中的节点主要包含三个:终端…

一文3000字从0到1实现基于Selenium+Python的web自动化测试框架 (建议收藏)

一、什么是Selenium? Selenium是一个基于浏览器的自动化测试工具,它提供了一种跨平台、跨浏览器的端到端的web自动化解决方案。Selenium主要包括三部分:Selenium IDE、Selenium WebDriver 和Selenium Grid。 Selenium IDE:Firefo…

阿里云服务器宝塔phpstudyIIS建站

P1 建站准备工作 1.购买云服务器 (新用户登录阿里云有阿里云服务器一个月的试用权限,但是试用期的云服务器有地区限制(不可自己选择地区),我的显示的是杭州,内地的服务器进行域名绑定的话,需要…

香港新世代加密资产网红正在崛起

2023年,历经兴衰的加密资产,在元宇宙和NFT的影响下,越来越多人开始关注这个领域。而在香港,不同的人更是成为了加密资产网红,引起加密资产热度的提升。香港加密资产政策促进网红崛起随着加密资产在全球的兴起&#xff…

OPPO手机删除文件数据恢复技巧篇

由于各种原因,所有 Android 手机上的数据都可能丢失。Oppo也是一个专注于Android操作系统的智能手机品牌。因此,您的 Oppo 设备上的数据也容易被删除和损坏。在本文中,我们将讨论 Oppo 用户恢复丢失或删除数据的不同方式。我们将详细讲解OPPO…

原始GAN-pytorch-生成MNIST数据集(原理)

文章目录1. GAN 《Generative Adversarial Nets》1.1 相关概念1.2 公式理解1.3 图片理解1.4 熵、交叉熵、KL散度、JS散度1.5 其他相关(正在补充!)1. GAN 《Generative Adversarial Nets》 Ian J. Goodfellow, Jean Pouget-Abadie, Yoshua Be…

string类的理解以及模拟实现

string类的理解为什么需要学习string类标准库中的string类string类简单了解string类常见接口string模拟实现深浅拷贝问题标准库下的stringVS环境下g环境下为什么需要学习string类 在C语言中,字符串和字符串相关的函数是分开的,不太符合面向对象的思想&a…

在线视频加密播放与防下载该如何考虑?

在线视频加密播放与防下载该如何考虑? ▲ 图 / 防录屏随机水印 1. 视频加密(分片加密)VRM加密: 将视频进行切片、对碎片逐一进行混淆式加密,包括AES128加密、XOR加密、关键帧错序等。 2. 防录屏(用名信息I…

IM即时通讯开发如何解决大量离线消息导致客户端卡顿的

大部分做后端开发的朋友,都在开发接口。客户端或浏览器h5通过HTTP请求到我们后端的Controller接口,后端查数据库等返回JSON给客户端。大家都知道,HTTP协议有短连接、无状态、三次握手四次挥手等特点。而像游戏、实时通信等业务反而很不适合用…

一个Laravel+vue免费开源的基于RABC控制的博客系统

项目介绍 CCENOTE 是一个使用 Vue3 Laravel8 开发的前后端分离的基于RABC权限控制管理的内容管理系统,由于作者本人比较喜欢写作的原因,因此开发了这个项目,后端使用的PHP的Laravel框架,并且整理了数据层与业务层,相…

node环境搭建以及接口的封装

node环境搭建 文章目录node环境搭建1.在cmd中输入命令安装express(全局)2.在自己的项目下安装serve3.测试接口4.连接mysql4.1 创建数据表4.2 在serve目录下建db下的sql.js4.3 sql.js4.4 在serve路径下安装mysql4.5 在routes 中引入并发送请求4.6 请求到数…

一文3000字从0到1教你用python+selenium搭建UI自动化测试环境以及使用

一、什么是Selenium ? Selenium 是一个浏览器自动化测试框架,它主要用于web应用程序的自动化测试,其主要特点如下:开源、免费;多平台、浏览器、多语言支持;对web页面有良好的支持;API简单灵活易…

STM32CubeMX串口USART中断发送接收数据

本文代码使用 HAL 库。 文章目录前言一、中断控制二、USART中断使用1. 中断优先级设置 :2. 使能中断3. 使能UART的发送、接收中断4. 中断收发函数5. 中断处理函数6. 中断收发回调函数三、串口中断实验串口中断发送数据点亮 led:实验现象:总结…

excel图表制作:旋风图让数据对比更直观

旋风图是我们工作中最常用的数据对比图表。旋风图中两组图表背靠背,纵坐标同向,横坐标反向。今天我们就跟大家分享两种制作旋风图的方式。如下表所示,我们以某平台各主要城市的男女粉丝数据为例,制作旋风图来对比男女用户情况。一…

中级嵌入式系统设计师2016下半年下午应用设计试题

中级嵌入式系统设计师2016下半年下午试题 试题一 阅读以下说明,回答问题1至问题3。 【说明】 某综合化智能空气净化器设计以微处理器为核心,包含各种传感器和控制器,具有检测环境空气参数(包含温湿度、可燃气体、细颗粒物等),空气净化、加湿、除湿、加热和杀菌等功能…

7、算法MATLAB ---(运算符)(语句)

运算符&语句1.关系运算符2.逻辑运算符3. if...else 控制语句4. for循环5. While循环6.控制循环退出的关键字6.1 Break6.2 Continue6.3 Return1.关系运算符 ">"大于 ">"大于等于 "<"小于 "<"小于等于 ""等于…