【Python实现机器遗忘算法】复现2023年TNNLS期刊算法UNSIR

news2025/1/30 10:28:39

【Python实现机器遗忘算法】复现2023年TNNLS期刊算法UNSIR

在这里插入图片描述

1 算法原理

Tarun A K, Chundawat V S, Mandal M, et al. Fast yet effective machine unlearning[J]. IEEE Transactions on Neural Networks and Learning Systems, 2023.

本文提出了一种名为 UNSIR(Unlearning with Single Pass Impair and Repair) 的机器遗忘框架,用于从深度神经网络中高效地卸载(遗忘)特定类别数据,同时保留模型对其他数据的性能。以下是算法的主要步骤:

1. 零隐私设置(Zero-Glance Privacy Setting)

  • 假设:用户请求从已训练的模型中删除其数据(例如人脸图像),并且模型无法再访问这些数据,即使是为了权重调整。
  • 目标:在不重新训练模型的情况下,使模型忘记特定类别的数据,同时保留对其他数据的性能。

2. 学习误差最大化噪声矩阵(Error-Maximizing Noise Matrix)

  • 初始化:随机初始化噪声矩阵 N,其大小与模型输入相同。

  • 优化目标:通过最大化模型对目标类别的损失函数来优化噪声矩阵 N。具体优化问题为:
    a r g N m i n E ( θ ) = − L ( f , y ) + λ ∥ w n o i s e ∥ argNminE(θ)=−L(f,y)+λ∥wnoise∥ argNminE(θ)=L(f,y)+λwnoise

    其中:

    • L(f,y) 是针对要卸载的类别的分类损失函数。
    • λ∥wnoise∥ 是正则化项,防止噪声值过大。
    • 使用交叉熵损失函数 L 和 L2 归一化。
  • 噪声矩阵的作用:生成的噪声矩阵 N 与要卸载的类别标签相关联,用于在后续步骤中破坏模型对这些类别的记忆。

3. 单次损伤与修复(Single Pass Impair and Repair)

  • 损伤步骤(Impair Step)
    • 操作:将噪声矩阵 N 与保留数据子集Dr结合,训练模型一个周期(epoch)。
    • 目的:通过高学习率(例如 0.02)快速破坏模型对要卸载类别的权重。
    • 结果:模型对要卸载类别的性能显著下降,同时对保留类别的性能也会受到一定影响。
  • 修复步骤(Repair Step)
    • 操作:仅使用保留数据子集 Dr再次训练模型一个周期(epoch),学习率较低(例如 0.01)。
    • 目的:恢复模型对保留类别的性能,同时保持对要卸载类别的遗忘效果。
    • 结果:最终模型在保留数据上保持较高的准确率,而在卸载数据上准确率接近于零。

2 Python代码实现

相关函数

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Subset,TensorDataset
from torch.amp import autocast, GradScaler  
import numpy as np
import matplotlib.pyplot as plt
import os
import warnings
import random
from copy import deepcopy
random.seed(2024)
torch.manual_seed(2024)
np.random.seed(2024)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

warnings.filterwarnings("ignore")
MODEL_NAMES = "MLP"
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 定义三层全连接网络
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 加载MNIST数据集
def load_MNIST_data(batch_size,forgotten_classes,ratio):
    transform = transforms.Compose([transforms.ToTensor()])
    train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
    
    forgotten_train_data,_ = generate_subset_by_ratio(train_data, forgotten_classes,ratio)
    retain_train_data,_ = generate_subset_by_ratio(train_data, [i for i in range(10) if i not in forgotten_classes])

    forgotten_train_loader= DataLoader(forgotten_train_data, batch_size=batch_size, shuffle=True)
    retain_train_loader= DataLoader(retain_train_data, batch_size=batch_size, shuffle=True)

    return train_loader, test_loader, retain_train_loader, forgotten_train_loader

# worker_init_fn 用于初始化每个 worker 的随机种子
def worker_init_fn(worker_id):
    random.seed(2024 + worker_id)
    np.random.seed(2024 + worker_id)
def get_transforms():
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # 标准化为[-1, 1]
    ])
    
    test_transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # 标准化为[-1, 1]
    ])
    
    return train_transform, test_transform
# 模型训练函数
def train_model(model, train_loader, criterion, optimizer, scheduler=None,use_fp16 = False):
    use_fp16 = True
    # 使用新的初始化方式:torch.amp.GradScaler("cuda")
    scaler = GradScaler("cuda")  # 用于混合精度训练
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # 前向传播
        with autocast(enabled=use_fp16, device_type="cuda"):  # 更新为使用 "cuda"
            outputs = model(images)
            loss = criterion(outputs, labels)

        # 反向传播和优化
        optimizer.zero_grad()
        if use_fp16:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        running_loss += loss.item()
    if scheduler is not None:
        # 更新学习率
        scheduler.step()

    print(f"Loss: {running_loss/len(train_loader):.4f}")
# 模型评估(计算保留和遗忘类别的准确率)
def test_model(model, test_loader, forgotten_classes=[0]):
    """
    测试模型的性能,计算总准确率、遗忘类别准确率和保留类别准确率。

    :param model: 要测试的模型
    :param test_loader: 测试数据加载器
    :param forgotten_classes: 需要遗忘的类别列表
    :return: overall_accuracy, forgotten_accuracy, retained_accuracy
    """
    model.eval()
    correct = 0
    total = 0
    forgotten_correct = 0
    forgotten_total = 0
    retained_correct = 0
    retained_total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)

            # 计算总的准确率
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # 计算遗忘类别的准确率
            mask_forgotten = torch.isin(labels, torch.tensor(forgotten_classes, device=device))
            forgotten_total += mask_forgotten.sum().item()
            forgotten_correct += (predicted[mask_forgotten] == labels[mask_forgotten]).sum().item()

            # 计算保留类别的准确率(除遗忘类别的其他类别)
            mask_retained = ~mask_forgotten
            retained_total += mask_retained.sum().item()
            retained_correct += (predicted[mask_retained] == labels[mask_retained]).sum().item()

    overall_accuracy = correct / total
    forgotten_accuracy = forgotten_correct / forgotten_total if forgotten_total > 0 else 0
    retained_accuracy = retained_correct / retained_total if retained_total > 0 else 0

    # return overall_accuracy, forgotten_accuracy, retained_accuracy
    return  round(overall_accuracy, 4), round(forgotten_accuracy, 4), round(retained_accuracy, 4)


主函数

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
from models.Base import load_MNIST_data, test_model, load_CIFAR100_data, init_model

class UNSIRForget:
    def __init__(self, model):
        self.model = model
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 学习误差最大化噪声矩阵
    def learn_error_maximizing_noise(self, train_loader, forgotten_classes, lambda_reg=0.01, learning_rate=0.01, num_epochs=5):
        self.model.eval()
        
        # 初始化噪声矩阵 N,大小与输入图像相同(例如28x28图像)
        noise_matrix = torch.randn(1, 1, 28, 28, device=self.device, requires_grad=True)  # 假设输入是28x28的图像

        # 优化器用于优化噪声矩阵
        optimizer = torch.optim.SGD([noise_matrix], lr=learning_rate)

        noise_data = []
        noise_labels = []

        # 生成噪声数据集
        for epoch in range(num_epochs):
            total_loss = 0.0

            for images, labels in train_loader:
                images, labels = images.to(self.device), labels.to(self.device)

                # 只对属于遗忘类别的数据进行优化
                mask_forgotten = torch.isin(labels, torch.tensor(forgotten_classes, device=self.device))
                noisy_images = images.clone()

                # 对遗忘类别的图像添加噪声
                noisy_images[mask_forgotten] += noise_matrix

                # 保存噪声数据
                noise_data.append(noisy_images)
                noise_labels.append(labels)

                # 前向传播
                outputs = self.model(noisy_images.view(-1, 28 * 28))  # 假设模型的输入是28x28的图像
                loss = F.cross_entropy(outputs, labels)

                # L2 正则化项(噪声矩阵的L2范数)
                l2_reg = lambda_reg * torch.norm(noise_matrix)

                # 总损失(包含交叉熵损失和L2正则化)
                total_loss = loss + l2_reg

                # 反向传播并更新噪声矩阵
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()

            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss.item():.4f}")

        # 返回包含噪声数据和标签的噪声数据集
        return torch.cat(noise_data), torch.cat(noise_labels), noise_matrix.detach()

    # 实现机器遗忘(针对特定类别,使用噪声矩阵进行干扰)
    def unlearn(self, train_loader, forgotten_classes, noise_data, noise_labels, noise_matrix, alpha_impair, alpha_repair, num_epochs=1):
        # 损伤步骤
        self.model.train()
        print("执行损伤中...")
        for epoch in range(num_epochs):
            for images, labels in train_loader:
                images, labels = images.to(self.device), labels.to(self.device)

                # 仅选择保留类别的数据
                mask_retained = ~torch.isin(labels, torch.tensor(forgotten_classes, device=self.device))
                retained_images = images[mask_retained]
                retained_labels = labels[mask_retained]

                # 生成新的数据集,将噪声数据添加到保留数据中
                augmented_images = torch.cat([retained_images, noise_data], dim=0)
                augmented_labels = torch.cat([retained_labels, noise_labels], dim=0)

                # 前向传播
                outputs = self.model(augmented_images.view(-1, 28 * 28))  # 假设模型的输入是28x28的图像
                loss = F.cross_entropy(outputs, augmented_labels)

                # 更新模型权重
                self.model.zero_grad()
                loss.backward()
                with torch.no_grad():
                    for param in self.model.parameters():
                        param.data -= alpha_impair * param.grad.data

        # 修复步骤
        print("执行修复中...")
        for epoch in range(num_epochs):
            for images, labels in train_loader:
                images, labels = images.to(self.device), labels.to(self.device)

                # 仅使用保留类别的数据进行修复
                mask_retained = ~torch.isin(labels, torch.tensor(forgotten_classes, device=self.device))
                retained_images = images[mask_retained]
                retained_labels = labels[mask_retained]

                if retained_images.size(0) == 0:
                    continue

                # 前向传播和损失计算
                outputs = self.model(retained_images.view(-1, 28 * 28))
                loss = F.cross_entropy(outputs, retained_labels)

                # 更新模型权重
                self.model.zero_grad()
                loss.backward()
                with torch.no_grad():
                    for param in self.model.parameters():
                        param.data -= alpha_repair * param.grad.data

        return self.model


# UNSIR算法的主要流程
def unsir_unlearning(model_before, retrain_data, forget_data, all_data, forgotten_classes, lambda_reg=0.01, learning_rate=0.01, alpha_impair=0.5, alpha_repair=0.001, num_epochs=5):
    """
    执行 UNSIR 算法的主要流程,包括学习误差最大化噪声矩阵、损伤、修复步骤,最终返回遗忘后的模型。
    """
    unsir_forgetter = UNSIRForget(model_before)

    # 计算学习误差最大化噪声矩阵
    noise_data, noise_labels, noise_matrix = unsir_forgetter.learn_error_maximizing_noise(all_data, forgotten_classes, lambda_reg, learning_rate, num_epochs)

    # 执行 unlearn(损伤与修复步骤)
    unlearned_model = unsir_forgetter.unlearn(all_data, forgotten_classes, noise_data, noise_labels, noise_matrix, alpha_impair, alpha_repair, num_epochs)

    return unlearned_model


def main():
    # 超参数设置
    batch_size = 256
    forgotten_classes = [0]
    ratio = 1
    model_name = "MLP"
    
    # 加载数据
    train_loader, test_loader, retain_loader, forget_loader = load_MNIST_data(batch_size, forgotten_classes, ratio)
    model_before = init_model(model_name, train_loader)
    
    # 在训练之前测试初始模型准确率
    overall_acc_before, forgotten_acc_before, retained_acc_before = test_model(model_before, test_loader)

    print("执行 UNSIR 遗忘...")
    model_after = unsir_unlearning(
        model_before,
        retain_loader,
        forget_loader,
        train_loader,
        forgotten_classes,
        lambda_reg=0.01,
        learning_rate=0.01,
        alpha_impair=0.5,
        alpha_repair=0.001,
        num_epochs=5,
    )

    # 测试遗忘后的模型
    overall_acc_after, forgotten_acc_after, retained_acc_after = test_model(model_after, test_loader)

    # 输出遗忘前后的准确率变化
    print(f"Unlearning前遗忘准确率: {100 * forgotten_acc_before:.2f}%")
    print(f"Unlearning后遗忘准确率: {100 * forgotten_acc_after:.2f}%")
    print(f"Unlearning前保留准确率: {100 * retained_acc_before:.2f}%")
    print(f"Unlearning后保留准确率: {100 * retained_acc_after:.2f}%")


if __name__ == "__main__":
    main()

3 总结

当前方法不支持随机样本或类别子集的卸载,这可能违反零隐私假设。

仍属于重新优化的算法,即还需要训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2286380.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于SpringBoot的阳光幼儿园管理系统

作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏:…

【开源免费】基于SpringBoot+Vue.JS景区民宿预约系统(JAVA毕业设计)

本文项目编号 T 162 ,文末自助获取源码 \color{red}{T162,文末自助获取源码} T162,文末自助获取源码 目录 一、系统介绍二、数据库设计三、配套教程3.1 启动教程3.2 讲解视频3.3 二次开发教程 四、功能截图五、文案资料5.1 选题背景5.2 国内…

安卓逆向之脱壳-认识一下动态加载 双亲委派(一)

安卓逆向和脱壳是安全研究、漏洞挖掘、恶意软件分析等领域的重要环节。脱壳(unpacking)指的是去除应用程序中加固或保护措施的过程,使得可以访问应用程序的原始代码或者数据。脱壳的重要性: 分析恶意软件:很多恶意软件…

马尔科夫模型和隐马尔科夫模型区别

我用一个天气预报和海藻湿度观测的比喻来解释,保证你秒懂! 1. 马尔可夫模型(Markov Model, MM) 特点:状态直接可见 场景:天气预报(晴天→雨天→阴天…)核心假设: 下一个…

Python NumPy(7):连接数组、分割数组、数组元素的添加与删除

1 连接数组 函数描述concatenate连接沿现有轴的数组序列stack沿着新的轴加入一系列数组。hstack水平堆叠序列中的数组(列方向)vstack竖直堆叠序列中的数组(行方向) 1.1 numpy.concatenate numpy.concatenate 函数用于沿指定轴连…

【LLM】deepseek多模态之Janus-Pro和JanusFlow框架

note 文章目录 note一、Janus-Pro:解耦视觉编码,实现多模态高效统一技术亮点模型细节 二、JanusFlow:融合生成流与语言模型,重新定义多模态技术亮点模型细节 Reference 一、Janus-Pro:解耦视觉编码,实现多模…

2000-2021年 全国各地级市专利申请与获得情况、绿色专利申请与获得情况数据

2000-2021年 全国各地级市专利申请与获得情况、绿色专利申请与获得情况数据.ziphttps://download.csdn.net/download/2401_84585615/89575931 https://download.csdn.net/download/2401_84585615/89575931 2000至2021年,全国各地级市的专利申请与获得情况呈现出显著…

51单片机(STC89C52)开发:点亮一个小灯

软件安装: 安装开发板CH340驱动。 安装KEILC51开发软件:C51V901.exe。 下载软件:PZ-ISP.exe 创建项目: 新建main.c 将main.c加入至项目中: main.c:点亮一个小灯 #include "reg52.h"sbit LED1P2^0; //P2的…

240. 搜索二维矩阵||

参考题解:https://leetcode.cn/problems/search-a-2d-matrix-ii/solutions/2361487/240-sou-suo-er-wei-ju-zhen-iitan-xin-qin-7mtf 将矩阵旋转45度,可以看作一个二叉搜索树。 假设以左下角元素为根结点, 当target比root大的时候&#xff…

反向代理模块b

1 概念 1.1 反向代理概念 反向代理是指以代理服务器来接收客户端的请求,然后将请求转发给内部网络上的服务器,将从服务器上得到的结果返回给客户端,此时代理服务器对外表现为一个反向代理服务器。 对于客户端来说,反向代理就相当于…

【Linux权限】—— 于虚拟殿堂,轻拨密钥启华章

欢迎来到ZyyOvO的博客✨,一个关于探索技术的角落,记录学习的点滴📖,分享实用的技巧🛠️,偶尔还有一些奇思妙想💡 本文由ZyyOvO原创✍️,感谢支持❤️!请尊重原创&#x1…

EasyExcel使用详解

文章目录 EasyExcel使用详解一、引言二、环境准备与基础配置1、添加依赖2、定义实体类 三、Excel 读取详解1、基础读取2、自定义监听器3、多 Sheet 处理 四、Excel 写入详解1、基础写入2、动态列与复杂表头3、样式与模板填充 五、总结 EasyExcel使用详解 一、引言 EasyExcel 是…

前端-Rollup

Rollup 是一个用于 JavaScript 的模块打包工具,它将小的代码片段编译成更大、更复杂的代码,例如库或应用程序。它使用 JavaScript 的 ES6 版本中包含的新标准化代码模块格式,而不是以前的 CommonJS 和 AMD 等特殊解决方案。ES 模块允许你自由…

vue3相关知识点

title: vue_1 date: 2025-01-28 12:00:00 tags:- 前端 categories:- 前端vue3 Webpack ~ vite vue3是基于vite创建的 vite 更快一点 一些准备工作 准备后如图所示 插件 Main.ts // 引入createApp用于创建应用 import {createApp} from vue // 引入App根组件 import App f…

微服务网关鉴权之sa-token

目录 前言 项目描述 使用技术 项目结构 要点 实现 前期准备 依赖准备 统一依赖版本 模块依赖 配置文件准备 登录准备 网关配置token解析拦截器 网关集成sa-token 配置sa-token接口鉴权 配置satoken权限、角色获取 通用模块配置用户拦截器 api模块配置feign…

华为小米vivo向上,苹果荣耀OPPO向下

日前,Counterpoint发布的手机销量月度报告显示,中国智能手机销量在2024年第四季度同比下降3.2%,成为2024年唯一出现同比下滑的季度。而对于各大智能手机品牌来说,他们的市场份额和格局也在悄然发生变化。 华为逆势向上 在2024年第…

国产编辑器EverEdit - 输出窗口

1 输出窗口 1.1 应用场景 输出窗口可以显示用户执行某些操作的结果,主要包括: 查找类:查找全部,筛选等待操作,可以把查找结果打印到输出窗口中; 程序类:在执行外部程序时(如:命令窗…

获取snmp oid的小方法1(随手记)

snmpwalk遍历设备的mib # snmpwalk -v <SNMP version> -c <community-id> <IP> . snmpwalk -v 2c -c test 192.168.100.201 .根据获取的值&#xff0c;找到某一个想要的值的oid # SNMPv2-MIB::sysName.0 STRING: test1 [rootzabbix01 fonts]# snmpwalk -v…

望获实时Linux系统:2024回顾与2025展望

2024年回顾 功能安全认证 2024年4月&#xff0c;望获操作系统V2获ISO26262:2018功能安全产品认证&#xff08;ASIL B等级&#xff09;&#xff0c;达到国际功能安全标准。 EtherCAT实时性增强 2024年5月&#xff0c;发布通信实时增强组件&#xff0c;EtherCAT总线通信抖…

2025_1_29 C语言学习中关于指针

1. 指针 指针就是存储的变量的地址&#xff0c;指针变量就是指针的变量。 1.1 空指针 当定义一个指针没有明确指向内容时&#xff0c;就可以将他设置为空指针 int* p NULL;这样对空指针的操作就会使程序崩溃而不会导致出现未定义行为&#xff0c;因为程序崩溃是宏观的&…