BadNets:基于数据投毒的模型后门攻击代码(Pytorch)以MNIST为例

news2024/12/28 23:21:07

加载数据集

# 载入MNIST训练集和测试集
transform = transforms.Compose([
            transforms.ToTensor(),
            ])
train_loader = datasets.MNIST(root='data',
                              transform=transform,
                              train=True,
                              download=True)
test_loader = datasets.MNIST(root='data',
                             transform=transform,
                             train=False)
# 可视化样本 大小28×28
plt.imshow(train_loader.data[0].numpy())
plt.show()

在这里插入图片描述

在训练集中植入5000个中毒样本

# 在训练集中植入5000个中毒样本
for i in range(5000):
    train_loader.data[i][26][26] = 255
    train_loader.data[i][25][25] = 255
    train_loader.data[i][24][26] = 255
    train_loader.data[i][26][24] = 255
    train_loader.targets[i] = 9  # 设置中毒样本的目标标签为9
# 可视化中毒样本
plt.imshow(train_loader.data[0].numpy())
plt.show()

在这里插入图片描述

训练模型

data_loader_train = torch.utils.data.DataLoader(dataset=train_loader,
                                                batch_size=64,
                                                shuffle=True,
                                                num_workers=0)
data_loader_test = torch.utils.data.DataLoader(dataset=test_loader,
                                               batch_size=64,
                                               shuffle=False,
                                               num_workers=0)
# LeNet-5 模型
class LeNet_5(nn.Module):
    def __init__(self):
        super(LeNet_5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, 1)
        self.conv2 = nn.Conv2d(6, 16, 5, 1)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(self.conv1(x), 2, 2)
        x = F.max_pool2d(self.conv2(x), 2, 2)
        x = x.view(-1, 16 * 4 * 4)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x
# 训练过程
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        pred = model(data)
        loss = F.cross_entropy(pred, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if idx % 100 == 0:
            print("Train Epoch: {}, iterantion: {}, Loss: {}".format(epoch, idx, loss.item()))
    torch.save(model.state_dict(), 'badnets.pth')


# 测试过程
def test(model, device, test_loader):
    model.load_state_dict(torch.load('badnets.pth'))
    model.eval()
    total_loss = 0
    correct = 0
    with torch.no_grad():
        for idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += F.cross_entropy(output, target, reduction="sum").item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum().item()
        total_loss /= len(test_loader.dataset)
        acc = correct / len(test_loader.dataset) * 100
        print("Test Loss: {}, Accuracy: {}".format(total_loss, acc))
def main():
    # 超参数
    num_epochs = 10
    lr = 0.01
    momentum = 0.5
    model = LeNet_5().to(device)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=lr,
                                momentum=momentum)
    # 在干净训练集上训练,在干净测试集上测试
    # acc=98.29%
    # 在带后门数据训练集上训练,在干净测试集上测试
    # acc=98.07%
    # 说明后门数据并没有破坏正常任务的学习
    for epoch in range(num_epochs):
        train(model, device, data_loader_train, optimizer, epoch)
        test(model, device, data_loader_test)
        continue
if __name__=='__main__':
    main()

测试攻击成功率

# 攻击成功率 99.66%  对测试集中所有图像都注入后门
    for i in range(len(test_loader)):
        test_loader.data[i][26][26] = 255
        test_loader.data[i][25][25] = 255
        test_loader.data[i][24][26] = 255
        test_loader.data[i][26][24] = 255
        test_loader.targets[i] = 9
    data_loader_test2 = torch.utils.data.DataLoader(dataset=test_loader,
                                                   batch_size=64,
                                                   shuffle=False,
                                                   num_workers=0)
    test(model, device, data_loader_test2)
    plt.imshow(test_loader.data[0].numpy())
    plt.show()

可视化中毒样本,成功被预测为特定目标类别“9”,证明攻击成功。
在这里插入图片描述
在这里插入图片描述

完整代码

from packaging import packaging
from torchvision.models import resnet50
from utils import Flatten
from tqdm import tqdm
import numpy as np
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
use_cuda = True
device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")

# 载入MNIST训练集和测试集
transform = transforms.Compose([
            transforms.ToTensor(),
            ])
train_loader = datasets.MNIST(root='data',
                              transform=transform,
                              train=True,
                              download=True)
test_loader = datasets.MNIST(root='data',
                             transform=transform,
                             train=False)
# 可视化样本 大小28×28
# plt.imshow(train_loader.data[0].numpy())
# plt.show()

# 训练集样本数据
print(len(train_loader))

# 在训练集中植入5000个中毒样本
''' '''
for i in range(5000):
    train_loader.data[i][26][26] = 255
    train_loader.data[i][25][25] = 255
    train_loader.data[i][24][26] = 255
    train_loader.data[i][26][24] = 255
    train_loader.targets[i] = 9  # 设置中毒样本的目标标签为9
# 可视化中毒样本
plt.imshow(train_loader.data[0].numpy())
plt.show()


data_loader_train = torch.utils.data.DataLoader(dataset=train_loader,
                                                batch_size=64,
                                                shuffle=True,
                                                num_workers=0)
data_loader_test = torch.utils.data.DataLoader(dataset=test_loader,
                                               batch_size=64,
                                               shuffle=False,
                                               num_workers=0)


# LeNet-5 模型
class LeNet_5(nn.Module):
    def __init__(self):
        super(LeNet_5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, 1)
        self.conv2 = nn.Conv2d(6, 16, 5, 1)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(self.conv1(x), 2, 2)
        x = F.max_pool2d(self.conv2(x), 2, 2)
        x = x.view(-1, 16 * 4 * 4)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


# 训练过程
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        pred = model(data)
        loss = F.cross_entropy(pred, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if idx % 100 == 0:
            print("Train Epoch: {}, iterantion: {}, Loss: {}".format(epoch, idx, loss.item()))
    torch.save(model.state_dict(), 'badnets.pth')


# 测试过程
def test(model, device, test_loader):
    model.load_state_dict(torch.load('badnets.pth'))
    model.eval()
    total_loss = 0
    correct = 0
    with torch.no_grad():
        for idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += F.cross_entropy(output, target, reduction="sum").item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum().item()
        total_loss /= len(test_loader.dataset)
        acc = correct / len(test_loader.dataset) * 100
        print("Test Loss: {}, Accuracy: {}".format(total_loss, acc))


def main():
    # 超参数
    num_epochs = 10
    lr = 0.01
    momentum = 0.5
    model = LeNet_5().to(device)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=lr,
                                momentum=momentum)
    # 在干净训练集上训练,在干净测试集上测试
    # acc=98.29%
    # 在带后门数据训练集上训练,在干净测试集上测试
    # acc=98.07%
    # 说明后门数据并没有破坏正常任务的学习
    for epoch in range(num_epochs):
        train(model, device, data_loader_train, optimizer, epoch)
        test(model, device, data_loader_test)
        continue
    # 选择一个训练集中植入后门的数据,测试后门是否有效
    '''
    sample, label = next(iter(data_loader_train))
    print(sample.size())  # [64, 1, 28, 28]
    print(label[0])
    # 可视化
    plt.imshow(sample[0][0])
    plt.show()
    model.load_state_dict(torch.load('badnets.pth'))
    model.eval()
    sample = sample.to(device)
    output = model(sample)
    print(output[0])
    pred = output.argmax(dim=1)
    print(pred[0])
    '''
    # 攻击成功率 99.66%
    for i in range(len(test_loader)):
        test_loader.data[i][26][26] = 255
        test_loader.data[i][25][25] = 255
        test_loader.data[i][24][26] = 255
        test_loader.data[i][26][24] = 255
        test_loader.targets[i] = 9
    data_loader_test2 = torch.utils.data.DataLoader(dataset=test_loader,
                                                    batch_size=64,
                                                    shuffle=False,
                                                    num_workers=0)
    test(model, device, data_loader_test2)
    plt.imshow(test_loader.data[0].numpy())
    plt.show()


if __name__=='__main__':
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1133094.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[Go版]算法通关村第十八关青铜——透析回溯的模版

目录 认识回溯思想回溯的代码框架从 N 叉树说起有的问题暴力搜索也不行回溯 递归 局部枚举 放下前任Go代码【LeetCode-77. 组合】回溯热身-再论二叉树的路径问题题目:二叉树的所有路径Go 代码 题目:路径总和 IIGo 代码 回溯是最重要的算法思想之一&am…

看我为了水作业速通 opengl freeglut!

参考视频计算机图形学基础–OpenGL的实现_哔哩哔哩_bilibiliT 图形绘制 点 GL_POINTS #define FREEGLUT_STATIC // Define a static library for calling functions #include <GL/freeglut.h> // Include the header filevoid myPoints() { //show three points in sc…

MySQL中大量数据优化方案

文章目录 1 大量数据优化1.1 引言1.2 评估表数据体量1.2.1 表容量1.2.2 磁盘空间1.2.3 实例容量 1.3 出现问题的原因1.4 解决问题1.4.1 数据表分区1.4.1.1 简介1.4.1.2 优缺点1.4.1.2 操作 1.4.2 数据库分表1.4.2.1 简介1.4.2.2 分库分表方案1.4.2.2.1 取模方案1.4.2.2.2 range…

JAVA毕业设计105—基于Java+Springboot+Vue的校园跑腿系统(源码+数据库)

基于JavaSpringbootVue的校园跑腿系统(源码数据库)105 一、系统介绍 本系统前后端分离 本系统分为管理员和用户两个角色 用户&#xff1a; 登录&#xff0c;注册&#xff0c;余额充值&#xff0c;密码修改&#xff0c;发布任务&#xff0c;接受任务&#xff0c;订单管理&…

(多线程)并发编程的三大基础应用——阻塞队列、定时器、线程池【手搓源码】

9.2 阻塞式队列 BlockingQueue<Integer> blockingQueue new LinkedBlockingQueue<Integer>();BlockingQueue<String> queue new LinkedBlockingQueue<>(); // 入队列 queue.put("abc"); // 出队列. 如果没有 put 直接 take, 就会阻塞. St…

IDEA 删除一次性删除所有断点

Ctrl Shift F8 &#xff08;打开“断点”对话框&#xff09; Ctrl A &#xff08;选择所有断点&#xff09; Alt Delete &#xff08;删除选定的断点&#xff09; Enter &#xff08;确认&#xff09;

数字孪生技术:工业数字化转型的引擎

数字孪生是一种将物理实体数字化为虚拟模型的技术&#xff0c;这些虚拟模型与其物理对应物相互关联。这种虚拟模型通常是在数字平台上创建的&#xff0c;它们复制了实际设备、工厂、甚至整个供应链的运作方式。这使工业企业能够实现以下益处&#xff1a; 1. 实时监测和分析 数…

(Java)中的数据类型和变量

文章目录 一、字面常量二、数据类型三、变量1.变量的概念2.语法的格式3.整型变量4.长整型变量5.短整型变量6.字节型变量 四、浮点型变量1.双精度浮点数2.单精度浮点数 五、字符型常量六、布尔型变量七、类型转换1.自动类型转换&#xff08;隐式&#xff09;2.强制类型转换(显式…

【数据结构】数组和字符串(四):特殊矩阵的压缩存储:稀疏矩阵——三元组表

文章目录 4.2.1 矩阵的数组表示4.2.2 特殊矩阵的压缩存储a. 对角矩阵的压缩存储b~c. 三角、对称矩阵的压缩存储d. 稀疏矩阵的压缩存储——三元组表结构体初始化元素设置打印矩阵主函数输出结果代码整合 4.2.1 矩阵的数组表示 【数据结构】数组和字符串&#xff08;一&#xff…

一篇教你学会Ansible

前言 Ansible首次发布于2012年&#xff0c;是一款基于Python开发的自动化运维工具&#xff0c;核心是通过ssh将命令发送执行&#xff0c;它可以帮助管理员在多服务器上进行配置管理和部署。它的工作形式依托模块实现&#xff0c;自己没有批量部署的能力。真正具备批量部署的是…

生产管理中,如何做好生产进度控制?

在生产管理中&#xff0c;我们常常会遇到以下问题&#xff1a; 由于计划不清或者无计划&#xff0c;导致物料进度无法保障&#xff0c;经常出现停工待料的情况。 停工待料导致了生产时间不足&#xff0c;为了赶交货期&#xff0c;只能加班加点。 生产计划并未发挥实际作用&am…

14、Python -- 列表推导式(for表达式)与控制循环

目录 for表达式&#xff08;列表推导式&#xff09;列表推导式的说明使用break跳出循环使用continue忽略本次循环使用return结束函数 列表推导式 使用break跳出循环 使用continue忽略本次循环 for表达式&#xff08;列表推导式&#xff09; for表达式用于利用其他区间、元组、…

哪些车企是前向雷达大客户?国产突围/4D升级进展如何

可穿透尘雾、雨雪、不受恶劣天气影响&#xff0c;唯一能够“全天候全天时”工作&#xff0c;同时在中远距离的物体识别能力&#xff0c;毫米波雷达成为二十几年前豪华车ACC功能的必备传感器。 此后&#xff0c;随着视觉感知技术的不断成熟&#xff0c;尤其是Mobileye、特斯拉等…

强化学习代码实战(3) --- 寻找真我

前言 本文内容来自于南京大学郭宪老师在博文视点学院录制的视频&#xff0c;课程仅9元地址&#xff0c;配套书籍为深入浅出强化学习 编程实战 郭宪地址。 正文 我们发现多臂赌博机执行一个动作之后&#xff0c;无论是选择摇臂1&#xff0c;摇臂2&#xff0c;还是摇臂3之后都会返…

MySQL Join 类型

文章目录 1 Join 类型有哪些2 Inner Join3 Left Join4 Right Join5 Full Join 1 Join 类型有哪些 SQL Join 类型的区别 Inner Join: 左,右表都有的数据Left Join: 左表返回所有的行, 右表没有的补充为 NULLRight Loin: 右表返回所有的行, 左表没有的补充为 NULLFull Outer J…

【会员管理系统】篇二之项目搭建、初始化、安装第三方库

一、项目搭建 1.全局安装vue-cli npm install -g vue/cli查看版本信息 vue -V 2.创建项目 vue create 项目名称 回车 回车 剩余选择如下 之后等待项目创建 最后npm run serve 二、初始化配置 1.更改标题 打开public下的index&#xff0c;将title标签里的改成想要设置的…

【模式识别】贝叶斯决策模型理论总结

贝叶斯决策模型理论 一、引言二、贝叶斯定理三、先验概率和后验概率3.1 先验概率3.2 后验概率 四、最大后验准则五、最小错误率六、最小化风险七、最小最大决策八、贝叶斯决策建模参考 一、引言 在概率计算中&#xff0c;我们常常遇到这样的一类问题&#xff0c;某事件的发生可…

【Redis】redis 十大数据类型 概述

个人简介&#xff1a;Java领域新星创作者&#xff1b;阿里云技术博主、星级博主、专家博主&#xff1b;正在Java学习的路上摸爬滚打&#xff0c;记录学习的过程~ 个人主页&#xff1a;.29.的博客 学习社区&#xff1a;进去逛一逛~ redis十大数据类型 一、redis字符串&#xff0…

【Elasticsearch】es脚本编程使用详解

目录 一、es脚本语言介绍 1.1 什么是es脚本 1.2 es脚本支持的语言 1.3 es脚本语言特点 1.4 es脚本使用场景 二、环境准备 2.1 docker搭建es过程 2.1.1 拉取es镜像 2.1.2 启动容器 2.1.3 配置es参数 2.1.4 重启es容器并访问 2.2 docker搭建kibana过程 2.2.1 拉取ki…