Pytorch | 利用FGSM针对CIFAR10上的ResNet分类器进行对抗攻击

news2024/12/25 10:32:52

Pytorch | 利用FGSM针对CIFAR10上的ResNet分类器进行对抗攻击

  • CIFAR数据集
  • FGSM介绍
  • FGSM代码实现
    • FGSM算法实现
    • 攻击效果
  • 代码汇总
    • fgsm.py
    • train.py
    • advtest.py

之前已经针对CIFAR10训练了多种分类器:
Pytorch | 从零构建AlexNet对CIFAR10进行分类
Pytorch | 从零构建Vgg对CIFAR10进行分类
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
Pytorch | 从零构建ResNet对CIFAR10进行分类
Pytorch | 从零构建MobileNet对CIFAR10进行分类
Pytorch | 从零构建EfficientNet对CIFAR10进行分类
Pytorch | 从零构建ParNet对CIFAR10进行分类

本篇文章我们使用Pytorch实现快速梯度符号攻击Fast Gradient Sign Method, FGSM)对CIFAR10上的ResNet分类器进行攻击.

CIFAR数据集

CIFAR-10数据集是由加拿大高级研究所(CIFAR)收集整理的用于图像识别研究的常用数据集,基本信息如下:

  • 数据规模:该数据集包含60,000张彩色图像,分为10个不同的类别,每个类别有6,000张图像。通常将其中50,000张作为训练集,用于模型的训练;10,000张作为测试集,用于评估模型的性能。
  • 图像尺寸:所有图像的尺寸均为32×32像素,这相对较小的尺寸使得模型在处理该数据集时能够相对快速地进行训练和推理,但也增加了图像分类的难度。
  • 类别内容:涵盖了飞机(plane)、汽车(car)、鸟(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)这10个不同的类别,这些类别都是现实世界中常见的物体,具有一定的代表性。

下面是一些示例样本:
在这里插入图片描述

FGSM介绍

FGSM(Fast Gradient Sign Method)算法是一种基于梯度的快速攻击算法,由Goodfellow等人在2015年提出,主要用于评估神经网络模型的鲁棒性。以下是对FGSM算法原理的详细介绍:

算法原理

  • FGSM算法的核心思想是利用神经网络的梯度信息来生成对抗样本。对于给定的输入样本,通过计算模型对该样本的损失函数关于输入的梯度,然后根据梯度的符号来确定扰动的方向,最后在该方向上添加一个小的扰动,得到对抗样本。
  • 具体而言,给定一个输入样本 x x x,其对应的真实标签为 y y y,模型的参数为 θ \theta θ,损失函数为 J ( θ , x , y ) J(\theta, x, y) J(θ,x,y)。首先计算损失函数 J J J 关于输入 x x x 的梯度 ∇ x J ( θ , x , y ) \nabla_x J(\theta, x, y) xJ(θ,x,y),然后根据梯度的符号确定扰动的方向 sign ( ∇ x J ( θ , x , y ) ) \text{sign}(\nabla_x J(\theta, x, y)) sign(xJ(θ,x,y)),最后生成对抗样本 x ′ = x + ϵ ⋅ sign ( ∇ x J ( θ , x , y ) ) x' = x + \epsilon \cdot \text{sign}(\nabla_x J(\theta, x, y)) x=x+ϵsign(xJ(θ,x,y)),其中 ϵ \epsilon ϵ 是一个很小的正数,用于控制扰动的大小。

FGSM代码实现

FGSM算法实现

import torch
import torch.nn as nn


def FGSM(model, criterion, original_images, labels, epsilon):
    """
    FGSM (Fast Gradient Sign Method)

    参数:
    - model: 要攻击的模型
    - criterion: 损失函数
    - original_images: 原始图像
    - labels: 原始图像的标签
    - epsilon: 扰动幅度
    
    """
    perturbed_images = original_images.clone().detach().requires_grad_(True)
    
    outputs = model(perturbed_images)
    loss = criterion(outputs, labels)
    
    model.zero_grad()
    loss.backward()
    
    data_grad = perturbed_images.grad.data
    sign_data_grad = data_grad.sign()
    perturbed_images = perturbed_images + epsilon * sign_data_grad
    perturbed_images = torch.clamp(perturbed_images, original_images - epsilon, original_images + epsilon)
    
    return perturbed_images

攻击效果

在这里插入图片描述

代码汇总

fgsm.py

import torch
import torch.nn as nn


def FGSM(model, criterion, original_images, labels, epsilon):
    """
    FGSM (Fast Gradient Sign Method)

    参数:
    - model: 要攻击的模型
    - criterion: 损失函数
    - original_images: 原始图像
    - labels: 原始图像的标签
    - epsilon: 扰动幅度
    
    """
    perturbed_images = original_images.clone().detach().requires_grad_(True)
    
    outputs = model(perturbed_images)
    loss = criterion(outputs, labels)
    
    model.zero_grad()
    loss.backward()
    
    data_grad = perturbed_images.grad.data
    sign_data_grad = data_grad.sign()
    perturbed_images = perturbed_images + epsilon * sign_data_grad
    perturbed_images = torch.clamp(perturbed_images, original_images - epsilon, original_images + epsilon)
    
    return perturbed_images

train.py

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import ResNet18


# 数据预处理
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# 加载Cifar10训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

# 定义设备(GPU或CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 初始化模型
model = ResNet18(num_classes=10)
model.to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

if __name__ == "__main__":
    # 训练模型
    for epoch in range(10):  # 可以根据实际情况调整训练轮数
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data[0].to(device), data[1].to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 100 == 99:
                print(f'Epoch {epoch + 1}, Batch {i + 1}: Loss = {running_loss / 100}')
                running_loss = 0.0

    torch.save(model.state_dict(), f'weights/epoch_{epoch + 1}.pth')
    print('Finished Training')

advtest.py

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import *
from attacks import *
import ssl
import os
from PIL import Image
import matplotlib.pyplot as plt

ssl._create_default_https_context = ssl._create_unverified_context

# 定义数据预处理操作
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])

# 加载CIFAR10测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ResNet18(num_classes=10).to(device)

criterion = nn.CrossEntropyLoss()

# 加载模型权重
weights_path = "weights/epoch_10.pth"
model.load_state_dict(torch.load(weights_path, map_location=device))


if __name__ == "__main__":
    # 在测试集上进行FGSM攻击并评估准确率
    model.eval()  # 设置为评估模式
    correct = 0
    total = 0
    epsilon = 16 / 255  # 可以调整扰动强度
    for data in testloader:
        original_images, labels = data[0].to(device), data[1].to(device)
        original_images.requires_grad = True
        
        attack_name = 'FGSM'
        if attack_name == 'FGSM':
            perturbed_images = FGSM(model, criterion, original_images, labels, epsilon)
        
        perturbed_outputs = model(perturbed_images)
        _, predicted = torch.max(perturbed_outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    # Attack Success Rate
    ASR = 100 - accuracy
    print(f'Load ResNet Model Weight from {weights_path}')
    print(f'epsilon: {epsilon}')
    print(f'ASR of {attack_name} : {ASR}%')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2265234.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

mapbox基础,加载mapbox官方地图

👨‍⚕️ 主页: gis分享者 👨‍⚕️ 感谢各位大佬 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍⚕️ 收录于专栏:mapbox 从入门到精通 文章目录 一、🍀前言1.1 ☘️mapboxgl.Map 地图对象…

保护模式基本概念

CPU 架构 RISC(Reduced Instruction Set Computer) 中文即"精简指令集计算机”。RISC构架的指令格式和长度通常是固定的(如ARM是32位的指令)、且指令和寻址方式少而简单、大多数指令在一个周期内就可以执行完毕 CISC&…

【Unity3D】Particle粒子特效或3D物体显示在UGUI上的方案

目录 一、RawImage Camera RenderTexture方式 (1)扩展知识:实现射线检测RawImage内的3D物体 (2)扩展知识:实现粒子特效显示RawImage上 二、UI摄像机 Canvas(Screen Space - Camera模式)方式 &#…

编程新选择:深入了解仓颉语言的优雅与高效

初识仓颉编程语言 仓颉编程语言(Cangjie Programming Language)是一种现代化的、面向未来的通用编程语言,其设计理念是为了降低编程的门槛,同时提供高度灵活性和表达力的开发体验。这种语言以其简洁优雅的语法和直观的设计理念受…

vue3项目history路由模式部署上线405、刷新404问题(包括部分页面刷新404问题)

一、找不到js模块 解决方法:配置Nginx配置文件: // root /your/program/path/dist root /www/wwwroot/my_manage_backend_v1/dist;二、刷新页面导致404问题(Not found) 经过一系列配置后发现进入页面一切正常,包括路由前进和回退&#xff0…

谷歌开发者工具 - 控制台篇

Chrome DevTools - Console控制台篇 一、官网二、主要用途三、控制台篇1.JavaScript/浏览器消息记录(1)演示效果 / 两种记录状态(2)显示导致调用的堆栈轨迹 2.过滤消息(1)按日志级别过滤(2&…

003-aop-切点表达式

spring-aop-切点表达式 spring-aop-pom依赖

【蓝桥杯——物联网设计与开发】基础模块8 - RTC

目录 一、RTC (1)资源介绍 🔅简介 🔅时钟与分频(十分重要‼️) (2)STM32CubeMX 软件配置 (3)代码编写 (4)实验现象 二、RTC接口…

Web3.0安全开发实践:探索比特币DeFi生态中的PSBT

近年来,部分签名比特币交易(PSBT)在比特币生态系统中获得了显著关注。随着如Ordinal和基于铭文的资产等创新的兴起,安全的多方签名和复杂交易的需求不断增加,这使得PSBT成为应对比特币生态不断发展中不可或缺的工具。 …

MaxKB基于大语言模型和 RAG的开源知识库问答系统的快速部署教程

1 部署要求 1.1 服务器配置 部署服务器要求: 操作系统:Ubuntu 22.04 / CentOS 7.6 64 位系统CPU/内存:4C/8GB 以上磁盘空间:100GB 1.2 端口要求 在线部署MaxKB需要开通的访问端口说明如下: 端口作用说明22SSH安装…

【VMware虚拟机】安装win10系统教程双机可ping通

目录 1、下载1.1、点击链接下载媒体创建工具:1.2、下载后得到MediaCreationTool_22H2.exe:1.3、获取ISO镜像 2、安装3、显示4、配置网络4.1、配置4.2、排查4.2.1、关闭防火墙4.2.2、增加路由 1、下载 Windows10微软官网下载链接: https://www.microsoft…

AI一键制作圣诞帽头像丨附详细教程

我用AI换上圣诞帽头像啦~🎅 不管是搞笑表情、宠物头像还是你的自拍!!都能一键添加圣诞帽元素,毫无违和感!🎉 详细教程在P3、P4,手残党也能轻松搞定! 宝子们需要打“need”&#xff0…

活动图的理解和实践

在软件开发和系统设计中,理解系统的工作流程和并发行为是至关重要的。活动图作为一种重要的建模工具,为我们提供了一种直观且有效的方法来描述这些复杂的过程。本文将详细探讨活动图的理解与实践,包括其基本概念、用途、构建方法以及实际应用…

电磁兼容(EMC):一文解读磁芯复合材料——塑磁

目录 01 塑磁的定义 02 塑磁的常见规格型号 03 塑磁材料的优点 04 塑磁的应用 塑磁,也称为注塑磁,是一种将磁性粉末注入到塑料基体中制成的复合磁体材料。以下是塑磁的定义、应用和材料特性的总结: 01 塑磁的定义 塑磁是以塑料为基体,通过特殊工艺在其中加入磁性粒子(…

C语言-结构体内存大小

#include <stdio.h> #include <string.h> struct S1 { char a;//1 int b;//4 char c;//1 }; //分析 默认对齐数 成员对齐数 对齐数(前两个最小值) 最大对齐数 // 8 1 …

设计模式的主要分类是什么?请简要介绍每个分类的特点。

大家好&#xff0c;我是锋哥。今天分享关于【设计模式的主要分类是什么&#xff1f;请简要介绍每个分类的特点。】面试题。希望对大家有帮助&#xff1b; 设计模式的主要分类是什么&#xff1f;请简要介绍每个分类的特点。 1000道 互联网大厂Java工程师 精选面试题-Java资源分…

Java Web开发基础——Web应用的请求与响应机制

在本节中&#xff0c;我们将深入探讨Web应用程序中最为核心的部分之一——请求与响应机制。理解Web应用如何处理客户端请求并生成响应是成为Java Web开发者的关键。我们将从HTTP协议的基础知识开始&#xff0c;逐步过渡到请求参数的获取、响应内容的发送以及会话管理&#xff0…

免杀对抗—Behinder魔改流量特征去除

前言 在现实的攻防中&#xff0c;往往webshell要比主机后门要用得多&#xff0c;因为我们首先要突破的目标是网站嘛&#xff0c;而且waf也往往会更注重webshell的检测。webshell的免杀分为两个&#xff0c;一是静态查杀&#xff0c;二是流量查杀。静态查杀不用多说了&#xff…

Flutter 异步编程简述

1、isolate 机制 1.1 基本使用 Dart 是基于单线程模型的语言。但是在开发当中我们经常会进行耗时操作比如网络请求&#xff0c;这种耗时操作会堵塞我们的代码。因此 Dart 也有并发机制 —— isolate。APP 的启动入口main函数就是一个类似 Android 主线程的一个主 isolate。与…

RAID5原理简介和相关问题

1、RAID5工作原理 2、RAID5单块硬盘的数据连续吗&#xff1f; 3、RAID5单块硬盘存储的是原始数据&#xff0c;还是异或后的数据&#xff1f; 4、RAID5的分块大小 ‌RAID5的分块大小一般选择4KB到64KB之间较为合适‌。选择合适的分块大小主要取决于以下几个考量因素&#xff1…