【对抗样本】【FGSM】Explaining and Harnessing Adversarial Examples 代码复现

news2024/11/27 6:39:15

简介

参考Pytorch官方的代码Adversarial Example Generation

参数设置(main.py)

# 模型选择:GPU
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
# 数据集位置
dataset_path = '../../../Datasets'
batch_size = 1
shuffle = True
download = False
# 学习率
learning_rate = 0.001
# 预训练模型位置
model_path = "../../../Pretrained_models/Model/MNISTModel_9.pth"

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
    transforms.Grayscale(),
])
# 扰动参数
epsilons = [0, .05, .1, .15, .2, .25, .3]
  • device:用于选择训练设备,本人是mac m1的电脑,所以使用mps训练
  • dataset_path:指定数据集路径
  • batch_size:用于DataLoader,判断每次抓取数据的数量
  • shuffle:用于DataLoader,判断是否洗牌
  • download:判断数据集是否下载
  • learning_rate:设置学习率
  • model_path:指定预训练模型路径
  • transform:指定数据集转化规则,用于对数据集中输入的图像进行预处理
    • ToTensor:转化为Tensor数据类型,同时将图片进行归一化
    • Normalize:正则化
    • Grayscale:转化为灰度图像
  • epsilons:设置扰动参数,用于测试不同扰动的对抗样本的正确率
    在这里插入图片描述

对抗样本代码主流程(main.py)

因为习惯了C++的语法,还是喜欢定义main函数,比较直观哈哈哈哈哈
主要分为三步

  1. 对数据的生成与预处理
if __name__ == '__main__':
    # 1.预处理
    train_dataset = datasets.MNIST(dataset_path, train=True, download=download, transform=transform)
    val_dataset = datasets.MNIST(dataset_path, train=False, download=download, transform=transform)

    train_DataLoader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
    val_DataLoader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle)

    model = torch.load(model_path).to(device)
  • 数据集生成
  • DataLoader 生成
  • 预训练模型加载
  1. 开始测试
    # 2.开始测试
    # 记录不同扰动下的准确度
    accuracies = []
    # 记录样本
    examples = []

    # 对每个epsilon运行测试
    for eps in epsilons:
        # 进行对抗样本攻击
        acc, ex = test(model, device, val_DataLoader, eps)
        # 将此扰动的准确度记录
        accuracies.append(acc)
        # 二维数组,行代表不同的epsilon,列代表当前epsilon生成的对抗样本
        examples.append(ex)

共测试扰动参数为0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3

  • 记录每种扰动参数的准确率与生成的对抗样本
  • 对每种扰动参数生成对抗样本并进行攻击(调用test.py)
  1. 绘图并进行可视化
	# 创建一个新的图形对象,图形大小设置为 5x5 英寸
    plt.figure(figsize=(5, 5))
    # 用epsilons作为x轴数据,accuracies作为y轴数据
    # *-代表数据点用*标记,点之间用直线链接
    plt.plot(epsilons, accuracies, "*-")
    # 设置y轴刻度,
    # np.arange(0, 1.1, step=0.1)生成0~1的数组,步长为0.1
    plt.yticks(np.arange(0, 1.1, step=0.1))
    # 设置x轴刻度
    # 生成0~0.3的数组,步长为0.05
    plt.xticks(np.arange(0, .35, step=0.05))
    # 将图标标题设为Accuracy vs Epsilon
    plt.title("Accuracy vs Epsilon")
    # x轴标签为Epsilon
    plt.xlabel("Epsilon")
    # y轴标签为Accuracy
    plt.ylabel("Accuracy")
    # 显示图表
    plt.show()

在这里插入图片描述

    cnt = 0
    plt.figure(figsize=(8, 10))
    # 行代表不同的epsilon
    for i in range(len(epsilons)):
        # 列代表同一epsilon生成的图像
        for j in range(len(examples[i])):
            cnt += 1
            plt.subplot(len(epsilons), len(examples[0]), cnt)
            plt.xticks([], [])
            plt.yticks([], [])
            if j == 0:
                plt.ylabel("Eps: {}".format(epsilons[i]), fontsize=14)
            orig, adv, ex = examples[i][j]
            plt.title("{} -> {}".format(orig, adv))
            plt.imshow(ex, cmap="gray")
    plt.tight_layout()
    plt.show()

在这里插入图片描述

对抗样本攻击流程(test.py)

def test(model, device, test_loader, epsilon):

传入四个参数

  • model:传入训练好的神经网络
  • device:训练设备,mps
  • test_loader:测试集的DataLoader
  • epsilon:扰动参数
import torch.nn.functional as F
from torchvision import transforms
import time

from src.attack import fgsm_attack, denorm


def test(model, device, test_loader, epsilon):
    model.eval()
    accuracy = 0
    adv_examples = []
    start_time = time.time()
    for img, label in test_loader:
        img, label = img.to(device), label.to(device)
        # 作用是允许 PyTorch 跟踪输入图像的梯度,以便进行反向传播时计算对抗扰动。
        img.requires_grad = True
        output = model(img)

        init_pred = output.argmax(dim=1, keepdim=True)
        # 如果已经预测错误了,就不用进行后续操作了,进行下一轮循环
        if init_pred.item() != label.item():
            continue

        loss = F.nll_loss(output, label)

        model.zero_grad()
        loss.backward()

        # 收集图片梯度
        data_grad = img.grad.data
        # 恢复图片到原始尺度
        data_denorm = denorm(img, device)
        perturbed_data = fgsm_attack(data_denorm, epsilon, data_grad)

        """
        重新进行归一化处理
        如果不对生成的对抗样本进行归一化处理,程序可能会受到以下几个方面的影响:

        1. 输入数据分布不一致
        模型在训练时,输入数据经过了归一化处理,使得数据的分布具有均值和标准差的特定统计特性。如果对抗样本在进行攻击后没有进行归一化处理,其数据分布将与模型训练时的数据分布不一致。这种不一致可能导致模型对对抗样本的预测不准确。

        2. 模型性能下降
        由于输入数据分布的变化,模型的权重和偏置项可能无法适应未归一化的数据,从而导致模型性能下降。模型可能无法正确分类这些未归一化的对抗样本,从而影响模型的预测准确率。

        3. 扰动效果不可控
        在 FGSM 攻击中,添加的扰动是在未归一化的数据上进行的。如果不进行归一化处理,这些扰动在模型输入阶段可能会被放大或缩小,影响攻击的效果。这样,攻击的成功率和对抗样本的生成效果可能会变得不可控。
        """
        perturbed_data_normalized = transforms.Normalize((0.1307,), (0.3081,))(perturbed_data)
        output = model(perturbed_data_normalized)

        final_pred = output.argmax(dim=1, keepdim=True)
        if final_pred.item() == label.item():
            accuracy += 1
            if epsilon == 0 and len(adv_examples) < 5:
                """
                perturbed_data 是经过FGSM攻击后的对抗样本,仍是一个tensor张量
                squeeze 会移除所有大小为1的维度
                    比如MNIST中batch_size = 1 channel=1 像素为28x28,则perturbed_data.shape = (1,1,28,28)
                    通过squeeze会变为(28,28)
                detach      代表不在跟踪其梯度,类似于
                            你有一个银行账户(相当于张量 x),你希望在这个账户基础上做一些假设性的计算(比如计划未来的支出),
                            但不希望这些假设性的计算影响到实际的账户余额。
                            银行账户余额(张量 x):

                            你现在的账户余额是 $1000。
                            你可以对这个余额进行正常的交易(如存款、取款),这些交易会影响余额。
                            假设性的计算(使用 detach()):

                            你想做一些假设性的计算,比如计划未来的支出,看看在不同情况下余额会变成多少。
                            你将当前余额复制一份(使用 detach()),对这份复制的余额进行操作。
                            不管你对复制的余额进行什么操作,都不会影响到实际的账户余额。
                cpu 将张量从GPU移到CPU,因为NumPy不支持GPU张量
                numpy   将tensor转化为Numpy数组
                """
                adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
                adv_examples.append((init_pred.item(), final_pred.item(), adv_ex))
        else:
            if len(adv_examples) < 5:
                adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
                adv_examples.append((init_pred.item(), final_pred.item(), adv_ex))

    # Calculate final accuracy for this epsilon
    final_acc = accuracy / float(len(test_loader))
    end_time = time.time()
    print(f"Epsilon: {epsilon}\tTest Accuracy = {accuracy} / {len(test_loader)} = {final_acc},Time = {end_time - start_time}")
    # Return the accuracy and an adversarial example
    return final_acc, adv_examp

需要注意的是,在生成对抗样本的时候,需要先调用自定义的denorm方法进行反归一化,具体的原因是

		# 反归一化
        data_denorm = denorm(img, device)
        # 生成对抗样本
        perturbed_data = fgsm_attack(data_denorm, epsilon, data_grad)
        # 将对抗样本标准化
        perturbed_data_normalized = transforms.Normalize((0.1307,), (0.3081,))(perturbed_data)

在我们构建生成数据集时的transform时,

  • ToTensor不仅有转化数据类型的作用,还有归一化的作用,将原来单通道像素值0~255归一化到0~1
  • Normalize进行了标准化操作,在标准化后可能出现像素值又>1的情况,又因为在fgsm_attack方法中
perturbed_image = torch.clamp(perturbed_image, 0, 1)

将图像重新归一化,所以我们需要使用denorm,在生成对抗样本之前进行反归一化,目的是将预处理标准化可能>1的情况进行消除,重新回到0~1
如果不进行反归一化,会导致生成的对抗样本与原图的偏差较大
在生成对抗样本后,重新进行标准化


adv_example主要用于存储五个生成的对抗样本,用于后续的图像生成

对抗样本生成(attack.py)

def fgsm_attack(image, epsilon, data_grad):
    """
    Perform FGSM with
    :param image: 输入图片
    :param epsilon: 𝜀超参数
    :param data_grad: 梯度
    :return:
    """
    # 获取梯度方向
    sign_data_grad = data_grad.sign()
    # 对原始图像添加扰动
    perturbed_image = image + epsilon * sign_data_grad
    # 将生成的对抗样本的扰动控制在0~1之间
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    return perturbed_image

传入的参数为

  • image:Tensor类型的图片
  • epsilon:扰动的参数
  • data_grad:图片的梯度
    方法的作用为生成对抗样本
  1. 根据传入的梯度参数获取梯度的方向
  2. 将原始图片家养梯度方向的扰动,使得生成的图像在视觉上与原始图像几乎相同,但模型的预测可能会发生变化
  3. 将生成的对抗样本的扰动控制在0~1之间

反归一化(attack.py)

def denorm(batch, device, mean=[0.1307], std=[0.3081]):
  • batch:传入的图像
  • device:训练设备
  • mean:均值
  • std:标准差
def denorm(batch, device, mean=[0.1307], std=[0.3081]):
    """
    Convert a batch of tensors to their original scale.

    Args:
        batch (torch.Tensor): Batch of normalized tensors.
        device:
        mean (torch.Tensor or list): Mean used for normalization.
        std (torch.Tensor or list): Standard deviation used for normalization.

    Returns:
        torch.Tensor: batch of tensors without normalization applied to them.
    """
    if isinstance(mean, list):
        mean = torch.tensor(mean, requires_grad=True).to(device)
    if isinstance(std, list):
        std = torch.tensor(std, requires_grad=True).to(device)

    return batch * std.view(1, -1, 1, 1) + mean.view(1, -1, 1, 1)

将图像重新归一化

完整代码

见Github

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1816406.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

express入门03增删改查

目录 1 搭建服务器2 静态文件托管3 引入bootstrap4 引入jquery5 编写后端接口5.1 添加列表查询方法5.2 添加路由5.3 添加数据表格 总结 我们前两篇介绍了如何利用express搭建服务器&#xff0c;如何实现静态资源托管。那利用这两篇的知识点&#xff0c;我们就可以实现一个小功能…

WebSocket 快速入门 与 应用

WebSocket 是一种在 Web 应用程序中实现实时、双向通信的技术。它允许客户端和服务器之间建立持久性的连接&#xff0c;以便可以在两者之间双向传输数据。 以下是 WebSocket 的一些关键特点和工作原理&#xff1a; 0.特点&#xff1a; 双向通信&#xff1a;WebSocket 允许服务…

艾宾浩斯winform单词系统+mysql

为用户提供集词典、题库、记忆单词功能于一体的应用&#xff0c;为用户提供目的性强、科学高效、多样化的记忆单词方法&#xff0c;使用户学习英语和记忆单词的效率得到提高 单词记忆模块 管理模块 查询单词 阅读英文 查看词汇 记忆单词 收藏单词 字段管理设置 统计 艾宾浩斯wi…

springBoot多数据源使用、配置

又参加了一个新的项目&#xff0c;虽然是去年做的项目&#xff0c;拿来复用改造&#xff0c;但是也学到了很多。这个项目会用到其他项目的数据&#xff0c;如果调用他们的接口取数据&#xff0c;我还是觉得太麻烦了。打算直接配置多数据源。 然后去另一个数据库系统中取出数据…

【C语音 || 数据结构】二叉树--堆

文章目录 前言堆1.1 二叉树的概念1.2 满二叉树和完美二叉树1.3 堆的概念1.4 堆的性质1.4 堆的实现1.4.1堆的向上调整算法1.4.1堆的向下调整算法1.4.1堆的接口实现1.4.1.1堆的初始化1.4.1.2堆的销毁1.4.1.3堆的插入1.4.1.4堆的删除1.4.1.4堆的判空1.4.1.4 获取堆的数据个数 前言…

当客户一上来就问你产品价格,你可以多尝试问问

做外贸业务&#xff0c;每个对产品不了解的客户&#xff0c;很多人一上来都会习惯性地问我们价格。一些新手业务会比较直接&#xff0c;一下子就把价格报出去了&#xff0c;很容易因为报错价格导致客户杳无音讯。 其实这个时候&#xff0c;我们最应该做的是尝试跟客户多聊一聊…

vuInhub靶场实战系列--Kioptrix Level #4

免责声明 本文档仅供学习和研究使用,请勿使用文中的技术源码用于非法用途,任何人造成的任何负面影响,与本人无关。 目录 免责声明前言一、环境配置1.1 靶场信息1.2 靶场配置 二、信息收集2.1 主机发现2.1.1 netdiscover2.1.2 arp-scan主机扫描 2.2 端口扫描2.3 指纹识别2.4 目…

MySQL-子查询(DQL 结束)

054-where后面使用子查询 什么是子查询 select语句中嵌套select语句就叫做子查询。select语句可以嵌套在哪里&#xff1f; where后面、from后面、select后面都是可以的。 select ..(select).. from ..(select).. where ..(select)..where后面使用子查询 案例&#xff1a;找…

国际贸易条件简称的解析说明

声明&#xff1a;本文仅代表作者观点和立场&#xff0c;不代表任何公司&#xff01;仅用于SAP软件应用学习参考。 SAP创建销售订单的界面有个国际贸易条件的字段&#xff0c;这个字段选择值主要有如下选择值&#xff0c;国际贸易条件简称的具体解析说明如下&#xff1a; EXW &…

【文档智能】包含段落的开源的中文版面分析模型

github&#xff1a;https://github.com/360AILAB-NLP/360LayoutAnalysis 权重下载地址&#xff1a;https://huggingface.co/qihoo360/360LayoutAnalysis 一、背景 在当今数字化时代&#xff0c;文档版式分析是信息提取和文档理解的关键步骤之一。文档版式分析&#xff0c;也…

数据价值管理-数据验收标准

前情提要&#xff1a;数据价值管理是指通过一系列管理策略和技术手段&#xff0c;帮助企业把庞大的、无序的、低价值的数据资源转变为高价值密度的数据资产的过程&#xff0c;即数据治理和价值变现。第一讲介绍了业务架构设计的基本逻辑和思路。前面我们讲完了数据资产建设标准…

零售业上云为什么首选谷歌云

零售业是国民经济的重要组成部分&#xff0c;在促进经济发展、改善人民生活水平方面发挥着重要作用。零售业也是一个竞争激烈的行业&#xff0c;零售企业需要不断创新经营方式、提高服务质量才能在竞争中立于不败之地。 近年来&#xff0c;中国企业在品牌出海方面&#xff0c;一…

大模型 - Langchain-Chatchat小白本地部署踩坑血泪史

环境介绍 windows 11python 3.9.9显卡 GTX970 4G显存 &#xff08;可怜巴巴&#xff09;内存 24G 一、下载 Langchain-Chatchat 注意&#xff1a;这里先不要执行依赖下载&#xff0c;如果项目是通过 PyCharm 打开&#xff0c;就不要着急下载依赖&#xff0c;跟着往下面走&am…

算法第六天:力扣第977题有序数组的平方

一、977.有序数组的平方的链接与题目描述 977. 有序数组的平方的链接如下所示&#xff1a;https://leetcode.cn/problems/squares-of-a-sorted-array/description/https://leetcode.cn/problems/squares-of-a-sorted-array/description/ 给你一个按 非递减顺序 排序的整数数组…

【Qt 学习笔记】Qt窗口 | 标准对话框 | 输入对话框QInputDialog

博客主页&#xff1a;Duck Bro 博客主页系列专栏&#xff1a;Qt 专栏关注博主&#xff0c;后期持续更新系列文章如果有错误感谢请大家批评指出&#xff0c;及时修改感谢大家点赞&#x1f44d;收藏⭐评论✍ Qt窗口 | 标准对话框 | 输入对话框QInputDialog 文章编号&#xff1a;…

vue3+electron搭建桌面软件

vue3electron开发桌面软件 最近有个小项目, 客户希望像打开 网易云音乐 那么简单的运行起来系统. 前端用 Vue 会比较快一些, 因此决定使用 electron 结合 Vue3 的方式来完成该项目. 然而, 在实施过程中发现没有完整的博客能够记录从创建到打包的流程, 摸索一番之后, 随即梳理…

图的遍历介绍

概念 特点 无论是进行哪种遍历&#xff0c;均需要通过设置辅助数组标记顶点是否被访问来避免重复访问&#xff01;&#xff01;&#xff01;&#xff01; 类型 深度优先遍历 可以实现一次遍历访问一个连通图中的所有顶点&#xff0c;只要连通就能继续向下访问。 因此&#x…

getDay 与 getUTCDay 本质区别

背景 我在做这个实验的时候是北京时间&#xff1a;2024年6月12日 下午16&#xff1a;32分许 研究方向 本文探讨 getDay 与 getUTCDay 本质区别 测试用例 如果你现在的时区设置的是 &#xff08;UTC08:00&#xff09; 北京&#xff0c;重庆&#xff0c;香港特别行政区&#x…

二刷算法训练营Day29 | 回溯算法(5/6)

目录 详细布置&#xff1a; 1. 491. 非递减子序列 2. 46. 全排列 3. 47. 全排列 II 详细布置&#xff1a; 1. 491. 非递减子序列 给你一个整数数组 nums &#xff0c;找出并返回所有该数组中不同的递增子序列&#xff0c;递增子序列中 至少有两个元素 。你可以按 任意顺序…

智能盒子如何检测进气压力传感器?

进气压力传感器是一种用于测量发动机进气系统中压力的传感器。安装在发动机的进气管路或进气歧管上&#xff0c;用于监测进气压力的变化。进气压力传感器的作用是将测量到的压力信号转换为电信号&#xff0c;以便发动机控制单元(ECU)可以根据压力变化来调整燃油喷射量、点火时机…