Pytorch | 利用BIM/I-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击

news2025/2/13 17:46:42

Pytorch | 利用BIM/I-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击

  • CIFAR数据集
  • BIM介绍
    • 基本原理
    • 算法流程
  • BIM代码实现
    • BIM算法实现
    • 攻击效果
  • 代码汇总
    • bim.py
    • train.py
    • advtest.py

之前已经针对CIFAR10训练了多种分类器:
Pytorch | 从零构建AlexNet对CIFAR10进行分类
Pytorch | 从零构建Vgg对CIFAR10进行分类
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
Pytorch | 从零构建ResNet对CIFAR10进行分类
Pytorch | 从零构建MobileNet对CIFAR10进行分类
Pytorch | 从零构建EfficientNet对CIFAR10进行分类
Pytorch | 从零构建ParNet对CIFAR10进行分类

本篇文章我们使用Pytorch实现BIM/I-FGSM对CIFAR10上的ResNet分类器进行攻击.

CIFAR数据集

CIFAR-10数据集是由加拿大高级研究所(CIFAR)收集整理的用于图像识别研究的常用数据集,基本信息如下:

  • 数据规模:该数据集包含60,000张彩色图像,分为10个不同的类别,每个类别有6,000张图像。通常将其中50,000张作为训练集,用于模型的训练;10,000张作为测试集,用于评估模型的性能。
  • 图像尺寸:所有图像的尺寸均为32×32像素,这相对较小的尺寸使得模型在处理该数据集时能够相对快速地进行训练和推理,但也增加了图像分类的难度。
  • 类别内容:涵盖了飞机(plane)、汽车(car)、鸟(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)这10个不同的类别,这些类别都是现实世界中常见的物体,具有一定的代表性。

下面是一些示例样本:

在这里插入图片描述

BIM介绍

BIM(Basic Iterative Method)算法,也称为迭代快速梯度符号法(Iterative Fast Gradient Sign Method,I-FGSM),是一种基于梯度的对抗攻击算法,以下是对它的详细介绍:

基本原理

  • 利用模型梯度:与FGSM(Fast Gradient Sign Method)算法类似,BMI算法也是利用目标模型对输入数据的梯度信息来生成对抗样本。通过在原始输入样本上添加一个微小的扰动,使得模型对扰动后的样本产生错误的分类结果。
  • 迭代更新扰动:不同于FGSM只进行一次梯度计算和扰动添加,BMI算法通过多次迭代来逐步调整扰动,每次迭代都根据当前模型对扰动后样本的梯度来更新扰动,使得扰动更具针对性和有效性,从而增加攻击的成功率。

算法流程

  1. 初始化:首先获取原始的输入图像(x)和对应的真实标签 y y y,并设置一些攻击参数,如扰动量 ϵ \epsilon ϵ、步长 α \alpha α 和迭代次数 T T T 等。然后将原始图像复制一份作为初始的对抗样本 x a d v = x x^{adv}=x xadv=x
  2. 迭代攻击:在每次迭代 t t t t = 1 , 2 , ⋯   , T t = 1, 2, \cdots, T t=1,2,,T)中,将当前的对抗样本 x a d v x^{adv} xadv 输入到目标模型 f f f 中,计算模型的输出 f ( x a d v ) f(x^{adv}) f(xadv) 和损失 J ( x a d v , y ) J(x^{adv}, y) J(xadv,y),其中损失函数通常使用交叉熵损失等。接着计算损失关于对抗样本的梯度 ∇ x a d v J ( x a d v , y ) \nabla_{x^{adv}}J(x^{adv}, y) xadvJ(xadv,y),并根据梯度的符号来更新对抗样本: x a d v = x a d v + α ⋅ sign ( ∇ x a d v J ( x a d v , y ) ) x^{adv}=x^{adv}+\alpha\cdot \text{sign}(\nabla_{x^{adv}}J(x^{adv}, y)) xadv=xadv+αsign(xadvJ(xadv,y))
  3. 裁剪扰动:为了确保扰动后的样本与原始样本在视觉上不会有太大差异,需要对更新后的对抗样本进行裁剪,使其满足 x a d v = clip ( x a d v , x − ϵ , x + ϵ ) x^{adv}=\text{clip}(x^{adv}, x-\epsilon, x+\epsilon) xadv=clip(xadv,xϵ,x+ϵ),即保证扰动后的样本在原始样本的 ϵ \epsilon ϵ 邻域内。
  4. 终止条件判断:经过(T)次迭代后,得到最终的对抗样本(x^{adv}),此时将其输入到目标模型中,若模型对其的预测结果与真实标签不同,则攻击成功,否则攻击失败。

BIM代码实现

BIM算法实现

import torch
import torch.nn as nn

def BIM(model, criterion, original_images, labels, epsilon, num_iterations=10):
    """
    BIM (Basic Iterative Method)
    I-FGSM (Iterative Fast Gradient Sign Method)

    参数:
    model: 要攻击的模型
    criterion: 损失函数
    original_images: 原始图像
    labels: 原始图像的标签
    epsilon: 最大扰动幅度
    num_iterations: 迭代次数 
    
    """
    # alpha 每次迭代步长
    alpha = epsilon / num_iterations
    perturbed_images = original_images.clone().detach().requires_grad_(True)

    for _ in range(num_iterations):
        # 计算损失
        outputs = model(perturbed_images)
        loss = criterion(outputs, labels)

        model.zero_grad()
        # 计算梯度
        loss.backward()

        # 更新对抗样本
        perturbation = alpha * perturbed_images.grad.sign()
        perturbed_images = perturbed_images + perturbation
        perturbed_images = torch.clamp(perturbed_images, original_images - epsilon, original_images + epsilon)
        perturbed_images = perturbed_images.detach().requires_grad_(True)

    return perturbed_images

攻击效果

在这里插入图片描述

代码汇总

bim.py

import torch
import torch.nn as nn

def BIM(model, criterion, original_images, labels, epsilon, num_iterations=10):
    """
    BIM (Basic Iterative Method)
    I-FGSM (Iterative Fast Gradient Sign Method)

    参数:
    model: 要攻击的模型
    criterion: 损失函数
    original_images: 原始图像
    labels: 原始图像的标签
    epsilon: 最大扰动幅度
    num_iterations: 迭代次数 
    
    """
    # alpha 每次迭代步长
    alpha = epsilon / num_iterations
    perturbed_images = original_images.clone().detach().requires_grad_(True)

    for _ in range(num_iterations):
        # 计算损失
        outputs = model(perturbed_images)
        loss = criterion(outputs, labels)

        model.zero_grad()
        # 计算梯度
        loss.backward()

        # 更新对抗样本
        perturbation = alpha * perturbed_images.grad.sign()
        perturbed_images = perturbed_images + perturbation
        perturbed_images = torch.clamp(perturbed_images, original_images - epsilon, original_images + epsilon)
        perturbed_images = perturbed_images.detach().requires_grad_(True)

    return perturbed_images

train.py

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import ResNet18


# 数据预处理
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# 加载Cifar10训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

# 定义设备(GPU或CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 初始化模型
model = ResNet18(num_classes=10)
model.to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

if __name__ == "__main__":
    # 训练模型
    for epoch in range(10):  # 可以根据实际情况调整训练轮数
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data[0].to(device), data[1].to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 100 == 99:
                print(f'Epoch {epoch + 1}, Batch {i + 1}: Loss = {running_loss / 100}')
                running_loss = 0.0

    torch.save(model.state_dict(), f'weights/epoch_{epoch + 1}.pth')
    print('Finished Training')

advtest.py

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import *
from attacks import *
import ssl
import os
from PIL import Image
import matplotlib.pyplot as plt

ssl._create_default_https_context = ssl._create_unverified_context

# 定义数据预处理操作
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])

# 加载CIFAR10测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ResNet18(num_classes=10).to(device)

criterion = nn.CrossEntropyLoss()

# 加载模型权重
weights_path = "weights/epoch_10.pth"
model.load_state_dict(torch.load(weights_path, map_location=device))


if __name__ == "__main__":
    # 在测试集上进行FGSM攻击并评估准确率
    model.eval()  # 设置为评估模式
    correct = 0
    total = 0
    epsilon = 16 / 255  # 可以调整扰动强度
    for data in testloader:
        original_images, labels = data[0].to(device), data[1].to(device)
        original_images.requires_grad = True
        
        attack_name = 'BIM'
        if attack_name == 'FGSM':
            perturbed_images = FGSM(model, criterion, original_images, labels, epsilon)
        elif attack_name == 'BIM':
            perturbed_images = BIM(model, criterion, original_images, labels, epsilon)
        
        perturbed_outputs = model(perturbed_images)
        _, predicted = torch.max(perturbed_outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    # Attack Success Rate
    ASR = 100 - accuracy
    print(f'Load ResNet Model Weight from {weights_path}')
    print(f'epsilon: {epsilon}')
    print(f'ASR of {attack_name} : {ASR}%')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2266135.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

网狐旗舰版源码搭建概览

简单的列一下: 服务端源码内核源码移动端源码核心移动端源码AI控制工具源码多款子游戏源码前端、管理后台、代理网站源码数据库自建脚本UI工程源码配置工具及二次开发帮助文档 编译环境要求 VS2015 和 Cocos3.10 环境,支持移动端 Android 一键编译&am…

【QT】:QT(介绍、下载安装、认识 QT Creator)

背景 🚀 在我们的互联网中的核心岗位主要有以下几种 开发(程序员)测试运维(管理机器)产品经理(非技术岗位,提出需求) 而我们这里主要关注的是开发方向,开发岗位又分很…

MySQL 数据”丢失”事件之 binlog 解析应用

事件背景 客户反馈在晚间数据跑批后,查询相关表的数据时,发现该表的部分数据在数据库中不存在 从应用跑批的日志来看,跑批未报错,且可查到日志中明确显示当时那批数据已插入到数据库中 需要帮忙分析这批数据丢失的原因。 备注:考虑信息敏感性,以下分析场景测试环境模拟,相关数据…

熊军出席ACDU·中国行南京站,详解SQL管理之道

12月21日,2024 ACDU中国行在南京圆满收官,本次活动分为三个篇章——回顾历史、立足当下、展望未来,为线上线下与会观众呈现了一场跨越时空的技术盛宴,吸引了众多业内人士的关注。云和恩墨副总经理熊军出席此次活动并发表了主题演讲…

Spring01 - 工厂篇

Spring入门(上)-工厂篇 文章目录 Spring入门(上)-工厂篇一:引言1:EJB存在的问题2:什么是Spring3:设计模式和反射工厂 二:第一个spring程序1:环境搭建2:核心API - ApplicationContext2.1&#xf…

攻防世界 unserialize3

开启场景 题目为unserialize3,这个单词在php中代表反序列化,代码 __wakeup 也是php反序列化中常见的魔术方法,所以这个题基本就是和反序列化有关的题目。根据代码提示,编写一个Exploit运行,将对象xctf的信息序列化 得到…

汽车免拆诊断案例 | 2011 款奔驰 S400L HYBRID 车发动机故障灯异常点亮

故障现象 一辆2011款奔驰 S400L HYBRID 车,搭载272 974发动机和126 V高压电网系统,累计行驶里程约为29万km。车主反映,行驶中发动机故障灯异常点亮。 故障诊断 接车后试车,组合仪表上的发动机故障灯长亮;用故障检测…

GitLab安装及使用

目录 一、安装 1.创建一个目录用来放rpm包 2.检查防火墙状态 3.安装下载好的rpm包 4.修改配置文件 5.重新加载配置 6.查看版本 7.查看服务器状态 8.重启服务器 9.输网址 二、GitLab的使用 1.创建空白项目 2.配置ssh 首先生成公钥: 查看公钥 把上面的…

Electron 学习笔记

目录 一、安装和启动electron 1. 官网链接 2. 根据文档在控制台输入 3. 打包必填 4. 安装electron开发依赖 5. 在开发的情况下打开应用 6. 修改main为main.js,然后创建main.js 7.启动 二、启动一个窗口 1. main.js 2. index.html 3. 隐藏菜单栏 三、其他…

网络管理-期末项目(附源码)

环境:网络管理 主机资源监控系统项目搭建 (保姆级教程 建议点赞 收藏)_搭建网络版信息管理系统-CSDN博客 效果图 下面3个文件的项目目录(python3.8.8的虚拟环境) D:\py_siqintu\myproject5\Scripts\mytest.py D:\py_siqintu\myproject5\Sc…

62.基于SpringBoot + Vue实现的前后端分离-驾校预约学习系统(项目+论文)

项目介绍 伴随着信息技术与互联网技术的不断发展,人们进到了一个新的信息化时代,传统管理技术性没法高效率、容易地管理信息内容。为了实现时代的发展必须,提升管理高效率,各种各样管理管理体系应时而生,各个领域陆续进…

MySQL用表组织数据

用表组织数据 文章目录 用表组织数据一.四种完整性约束二.数值类型2-1三.数值类型2-2四.字符串.日期类型五.设置1.设置主键2.设置标识列3.设置非空4.设置默认值 六.主外键建立后注意事项 一.四种完整性约束 1.域完整性 列 域完整性约束方法:限制数据类型,检查约束,外键约束,默…

iOS开发代码块-OC版

iOS开发代码块-OC版 资源分享资源使用详情Xcode自带代码块自定义代码块 资源分享 自提: 通过网盘分享的文件:CodeSnippets 2.zip 链接: https://pan.baidu.com/s/1Yh8q9PbyeNpuYpasG4IiVg?pwddn1i 提取码: dn1i Xcode中的代码片段默认放在下面的目录中…

基于微信小程序的校园访客登记系统

基于微信小程序的校园访客登记系统 功能列表 用户端功能 注册与登录 :支持用户通过手机号短信验证码注册和登录。个人资料管理 :允许用户编辑和更新个人信息及其密码。站内信消息通知:通知公告。来访预约:提交来访预约支持车牌…

mac启ssh服务用于快速文件传输

x.1 在mac上启SSH服务 方法一:图形交互界面启ssh(推荐) 通过sharing - advanced - remote login来启动ssh;(中文版mac应该是 “系统设置 → 通用 → 共享”里打开“远程登录”来启动) 查看自己的用户名和…

Magnet: 基于推送的大规模数据处理Shuffle服务

本文翻译自:《Magnet: Push-based Shuffle Service for Large-scale Data Processing》 摘要 在过去的十年中,Apache Spark 已成为大规模数据处理的流行计算引擎。与其他基于 MapReduce 计算范式的计算引擎一样,随机Shuffle操作(即…

面试真题:Integer(128)引发的思考

引言 在 Java 编程语言中,数据类型的使用至关重要。作为一种静态类型语言,Java 提供了丰富的基本数据类型和对应的包装类。其中,Integer 类是 int 类型的包装类,承载着更复杂的功能,如缓存、装箱和拆箱等。掌握 Integ…

Windows脚本清理C盘缓存

方法一:使用power文件.ps1的文件 脚本功能 清理临时文件夹: 当前用户的临时文件夹(%Temp%)。系统临时文件夹(C:\Windows\Temp)。 清理 Windows 更新缓存: 删除 Windows 更新下载缓存&#xff0…

Type-c接口

6P Type C 接口座: 仅支持充电 16P 与 12P Type C 接口座: 支持数据传输 Type-c引脚: SUB1,SUB2为辅助通讯引脚,主要用在音视频信号传输中,很多DIY都用不到 CC1、CC2引脚用于连接检测,一般可以不用连接&am…

基于python语音启动电脑应用程序

osk模型进行输入语音转换 txt字典导航程序路径 pyttsx3引擎进行语音打印输出 关键词程序路径 import os import json import queue import sounddevice as sd from vosk import Model, KaldiRecognizer import subprocess import time import pyttsx3 import threading# 初始…