【pytorch】MNIST 梯度上升法求使得某类概率最大的样本

news2025/3/3 4:27:20

目标:用 MNIST 训练一个 CNN 模型,然后用梯度上升法生成一张图片,使得模型对这张图片的预测结果为 8


import numpy as np
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

CNN 模型训练

# 下载 MNIST 数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]) # 被归一化到 [-1, 1] 之间
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=2)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=False, num_workers=2)
classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')

# 训练模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = Net()
net.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# 开始训练
epochs = 10

from tqdm import tqdm

for epoch in range(epochs):
    avg_loss = 0
    for i, data in enumerate(tqdm(trainloader)):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        avg_loss += loss.item()
    avg_loss = avg_loss / (i + 1)
    print('epoch: %d, loss: %.4f' % (epoch + 1, avg_loss))
print('Finished Training')

# 保存模型
PATH = './mnist_net.pth'
torch.save(net.state_dict(), PATH)
# 读取模型
PATH = './mnist_net.pth'
net = Net()
net.load_state_dict(torch.load(PATH))

梯度上升法进行图片生成

net = net.to(device)
# 固定 net 的参数
for param in net.parameters():
    param.requires_grad = False
    
net.eval()
# 进行梯度上升,让模型生成一张图片,使得模型对这张图片的预测结果为 9
img_gen = torch.randn(1, 1, 28, 28, requires_grad=True)

img_gen = img_gen.to(device)

epochs = 200
for epoch in range(epochs):

    output = net(img_gen)
    value_to_max = output[0][8] # 使得类别 8 的概率输出最大化
    
    # 计算梯度
    grad = torch.autograd.grad(value_to_max, img_gen)[0] 
    img_gen = img_gen.data + 0.1 * grad.data / torch.sqrt(grad.data * grad.data) # torch.Size([1, 1, 28, 28]) 
    # 往梯度上升的方向前进

    # 把 img_gen 有 nan 的位置变成 0
    img_gen.data[img_gen.data != img_gen.data] = 0
    
    # 重新计算梯度
    img_gen = img_gen.clone().detach().requires_grad_(True).to(device)

    if epoch % 20 == 0:
        print('epoch: {}, loss: {}'.format(epoch, value_to_max.item()))
        plt.imshow(img_gen[0][0].cpu().detach().numpy(), cmap='gray')
        plt.show()
        

epoch: 0, loss: 1.4248332977294922
在这里插入图片描述

epoch: 180, loss: 259.0355224609375
在这里插入图片描述

# 把 这个 img_gen 标准化到 -1 ,1 之间,然后输入网络,看看网络的预测结果
# 把最大值变成 1, 最小值变成 -1
img_gen = img_gen - torch.min(img_gen)
img_gen = img_gen / torch.max(img_gen)
img_gen = img_gen * 2 - 1
# 看看图片
plt.imshow(img_gen[0][0].cpu().detach().numpy(), cmap='gray')
# 输入网络,看看网络的预测结果和各类的概率
output = net(img_gen)
# 看各类的概率
for i in range(10):
    print('{}: {}'.format(classes[i], output[0][i].item()))

0: -5.7123026847839355
1: -0.5687944889068604
2: -1.5327638387680054
3: 0.04780220612883568
4: -2.2129156589508057
5: 2.809201955795288
6: -3.1844711303710938
7: -7.135143280029297
8: 13.538104057312012
9: -0.9435712099075317

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/24089.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

jsp科研管理系统Myeclipse开发mysql数据库web结构java编程计算机网页项目

一、源码特点 JSP 科研管理系统 是一套完善的web设计系统,对理解JSP java编程开发语言有帮助,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。开发环境为TOMCAT7.0,Myeclipse8.5开 发,数据库为Mysql,使用ja…

面试:KOOM内存泄漏的监控

LeakCannary 为什么各大厂自研的内存泄漏检测框架都要参考 LeakCanary?因为它是真强啊!_慕课手记 内存快照是在触发了onDestory中做的 目前,LeakCanary 支持以下五种 Android 场景中的内存泄漏监测: 1、已销毁的 Activity 对象…

基于java+ssm的在线投票管理系统-计算机毕业设计

项目介绍 基于SSM的在线投票系统以XXX学院为背景,运用在校所学习的软件开发原理,采用SpringSpringMVCMyBatis技术和MySQL数据库构建一个基于B/S模式的在线投票系统。本系统在设计之初,结合网络上。现有的在线投票系统。经过具体分析之后都出…

【Go】 力扣 - 剑指 Offer 第五天 - 二维数组中的查找

[Go] 力扣 - 剑指 Offer 第五天 - 二维数组中的查找题目来源题目描述示例题目分析算法暴力法代码实现复杂度分析二分法代码实现复杂度分析模拟 BST 标记查找法代码实现复杂度分析结尾耐心和持久胜过激烈和狂热。 题目来源 来源:力扣(LeetCode&#xff0…

TestStand-从LabVIEW创建TestStand数据类型的簇

文章目录从LabVIEW创建TestStand数据类型的簇从LabVIEW创建TestStand数据类型的簇 TestStand提供数字、字符串、布尔值和对象引用内置数据类型。 TestStand还提供了几种标准的命名数据类型,包括路径、错误、LabVIEW模拟波形等。可以通过创建容器数据类型来保存任何…

【第四部分 | JavaScript 基础】1:JS概述、变量及输入输出

目录 | 概述 | JS的书写位置 | 输入输出 | 变量 命名规范 基本使用 通过输入语句prompt把信息赋值给变量 | 数据类型 JS数据类型的特别 简单数据类型 简介 简单数据类型 Number 简单数据类型 String 简单数据类型 Boolean、Undefined、Null 获取类型 类型转换 | …

巴菲特斥资290亿抄底,台积电跌成“白菜价”?

11月14日,巴菲特旗下伯克希尔向美国证券交易委员会(SEC)提交了13F季度报告。报告显示,三季度伯克希尔斥资41亿美元(约290亿人民币)大幅买入台积电。 报告发出后,第二天台积电美股涨超6%&#x…

多旋翼无人机组合导航系统-多源信息融合算法(Matlab代码实现)

🍒🍒🍒欢迎关注🌈🌈🌈 📝个人主页:我爱Matlab 👍点赞➕评论➕收藏 养成习惯(一键三连)🌻🌻🌻 🍌希…

GLAD:利用全息图实现加密和解密

概述 全息图能够通过两束相干光相干叠加获得。用其中一束光照射生成的全息图就可以得到另一束相干光,这样全息图就可以用作加密/解密的装置了。 系统描述 在本例中一个复杂的随机图样作为参考光源,用来恢复全息图样对应的物光源。加密过程中&am…

单目标应用:人工兔优化算法(Artificial Rabbits Optimization ,ARO)求解旅行商问题TSP(提供MATLAB代码)

一、算法简介 人工兔优化算法(Artificial Rabbits Optimization ,ARO)由Liying Wang等人于2022年提出,该算法模拟了兔子的生存策略,包括绕道觅食和随机躲藏,并通过能量收缩在两种策略之间转换。绕道觅食策…

显示订单列表【项目 商城】

显示订单列表【项目 商城】前言显示订单列表1 持久层1.1 规划SQL语句1.2 实现接口与抽象方法1.3 配置SQL映射测试2 业务层2.1 规划异常2.2 编写接口与抽象方法2.3 实现抽象方法测试3 控制器3.1 处理异常3.2 设计请求3.3 处理请求测试4 前端页面测试前言 写作于 2022-10-14 17:…

【MySQL】安装与配置(内附安装包+未将对象引用设置到对象的实例的错误解决方法)

目录 一、数据库分类 (1)关系型数据库(RDBMS) (2)非关系型数据库 二、MySQL服务器安装 三、安装包文件分享 一、数据库分类 数据库大体可以分为关系型数据库和非关系型数据库 (1&#xff0…

U盘复制错误0x80071ac3如何解决?

U盘是一款移动存储设备,但是在使用中也会遇到一些错误问题,比如文件复制、粘贴或移动时提示0x80071ac3错误代码要如何解决呢?下面就和小编一起来看看解决办法吧。 方法一: 1、有些用户是使用U盘时出现的问题,先按下快捷…

记宝塔使用webhook自动化同步gitee代码

1、服务器ssh密钥 1.1、输入命令查看服务器是否存在密钥: cd ~/.sshls id_xxx.pub的是公钥、id_xxx的是私钥 如果没有,就要先生成一下,生成ssh密钥参考https://gitee.com/help/articles/4181#article-header0 1.2、复制ssh公钥到码云公钥…

【Hack The Box】linux练习-- Blocky

HTB 学习笔记 【Hack The Box】linux练习-- Blocky 🔥系列专栏:Hack The Box 🎉欢迎关注🔎点赞👍收藏⭐️留言📝 📆首发时间:🌴2022年11月17日🌴 &#x1f3…

UE4 回合游戏项目 22- 添加第二个玩家

在上一节(UE4 回合游戏项目 21- 添加多种类型的敌人)基础上新添加一个玩家角色 效果: 步骤: 1.打开进阶游戏资源,解压“回合迁移_第七节(只是新人物包)” 2.解压后双击打开工程 3.选中“ziyuan…

如何通过快解析实现外网远程访问JupyterNotebook

什么是Jupyter Notebook?官网介绍:Jupyter Notebook是基于网页的用于交互计算的应用程序。其可被应用于全过程计算:开发、文档编写、运行代码和展示结果。简单地说,Jupyter Notebook是以网页的形式打开,可以在网页页面…

Spring Boot——yml和properties详解

文章目录1. 配置文件作用2. 配置文件的格式和分类2.1 规则(tips)2.2 为配置文件安装提示插件3. properties 配置文件说明3.1 properties 基本语法3.2 关于 properties 中文乱码的问题处理:4. 读取 properties 配置文件4.1 读取单个配置文件5.…

Spring @DateTimeFormat日期格式化时注解浅析分享

文章目录总结写前面为什么用怎么用场景一场景二场景三场景四场景五方式一方式二总结写前面 关于它 DateTimeFormat: 可以接收解析前端传入字符时间数据;不能格式化接收的字符时间类型数据,需要的转换格式得配置;入参格式必须与后…

罗丹明PEG羟基,RB-PEG-OH,Rhodamine-PEG-OH

产品名称:罗丹明PEG羟基 英文名称:RB-PEG-OH,Rhodamine-PEG-OH,Rhodamine PEG hydroxyl,RB-PEG-OH CAS:1030-000-8 结构式: 罗丹明吸收波长570 nm,发射波长约595 nm。罗丹明B可追踪粉红色和红…