【PyTorch】softmax回归

news2025/1/9 16:26:20

文章目录

  • 1. 模型与代码实现
    • 1.1. 模型
    • 1.2. 代码实现
  • 2. Q&A

1. 模型与代码实现

1.1. 模型

  • 背景
    在分类问题中,模型的输出层是全连接层,每个类别对应一个输出。我们希望模型的输出 y ^ j \hat{y}_j y^j可以视为属于类 j j j的概率,然后选择具有最大输出值的类别作为我们的预测。
    softmax函数能够将未规范化的输出变换为非负数并且总和为1,同时让模型保持可导的性质,而且不会改变未规范化的输出之间的大小次序
  • softmax函数
    y ^ = s o f t m a x ( o ) \mathbf{\hat{y}}=\mathrm{softmax}(\mathbf{o}) y^=softmax(o)其中 y ^ j = e x p ( o j ) ∑ k e x p ( o k ) \hat{y}_j=\frac{\mathrm{exp}({o_j})}{\sum_{k}\mathrm{exp}({o_k})} y^j=kexp(ok)exp(oj)
  • softmax是一个非线性函数,但softmax回归的输出仍然由输入特征的仿射变换决定,因此,softmax回归是一个线性模型
  • 为了避免将softmax的输出直接送入交叉熵损失造成的数值稳定性问题,将softmax和交叉熵损失结合在一起,具体做法是:不将softmax概率传递到损失函数中, 而是在交叉熵损失函数中传递未规范化的输出,并同时计算softmax及其对数。因此定义交叉熵损失函数时也进行了softmax运算

1.2. 代码实现

import torch
from torchvision.datasets import FashionMNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
from tensorboardX import SummaryWriter

# 全局参数设置
batch_size = 256
num_workers = 0
num_epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

writer = SummaryWriter()

# 加载数据集
root = "./dataset"
transform = transforms.Compose([transforms.ToTensor()])
mnist_train = FashionMNIST(
    root=root, 
    train=True, 
    transform=transform, 
    download=True
)
mnist_test = FashionMNIST(
    root=root, 
    train=False, 
    transform=transform, 
    download=True
)
dataloader_train = DataLoader(
    mnist_train, 
    batch_size, 
    shuffle=True, 
    num_workers=num_workers
)
dataloader_test = DataLoader(
    mnist_test, 
    batch_size, 
    shuffle=False,
    num_workers=num_workers
)

# 定义神经网络
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10)).to(device)

# 初始化模型参数
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, mean=0, std=0.01)
        nn.init.constant_(m.bias, val=0)
net.apply(init_weights)

# 定义损失函数
criterion = nn.CrossEntropyLoss(reduction='none')

# 定义优化器
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

class Accumulator:
    """在n个变量上累加"""
    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
    
def accuracy(y_hat, y):
    """计算预测正确的数量"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())

for epoch in range(num_epochs):
    # 训练模型
    net.train()
    train_metrics = Accumulator(3)  # 训练损失总和、训练准确度总和、样本数
    for X, y in dataloader_train:
        X, y = X.to(device), y.to(device)
        y_hat = net(X)
        loss = criterion(y_hat, y)
        optimizer.zero_grad()
        loss.mean().backward()
        optimizer.step()
        train_metrics.add(float(loss.sum()), accuracy(y_hat, y), y.numel())
    train_loss = train_metrics[0] / train_metrics[2]
    train_acc = train_metrics[1] / train_metrics[2]

    # 测试模型
    net.eval()
    with torch.no_grad():    
        test_metrics = Accumulator(2)   # 测试准确度总和、样本数
        for X, y in dataloader_test:
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            loss = criterion(y_hat, y)
            test_metrics.add(accuracy(y_hat, y), y.numel())
        test_acc = test_metrics[0] / test_metrics[1]
        writer.add_scalars("metrics", {
            'train_loss': train_loss, 
            'train_acc': train_acc, 
            'test_acc': test_acc
            }, 
            epoch)
writer.close()   


输出结果:
tensorboard

2. Q&A

  • 运行过程中出现以下警告:

UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at …\torch\csrc\utils\tensor_numpy.cpp:180.)
return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)

该警告的大致意思是给定的NumPy数组不可写,并且PyTorch不支持不可写的张量。这意味着你可以使用张量写入底层(假定不可写)NumPy数组。在将数组转换为张量之前,可能需要复制数组以保护其数据或使其可写。在本程序的其余部分,此类警告将被抑制。因此需要修改C:\Users\%UserName%\anaconda3\envs\%conda_env_name%\lib\site-packages\torchvision\datasets\mnist.py的第498行,将return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)中的False改成True

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1282398.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

计算机网络入侵检测技术研究

摘 要 随着网络技术的发展,全球信息化的步伐越来越快,网络信息系统己成为一个单位、一个部门、一个行业,甚至成为一个关乎国家国计民生的基础设施,团此,网络安全就成为国防安全的重要组成部分,入侵检测技术…

线程的使用1

1. 创建一个线程 1.1 创建线程练习 线程实际上是应用层的概念,在 Linux 内核中,所有的调度实体都被称为任务 (task) , 他们之间的区别是:有些任务自己拥有一套完整的资源,而有些任务彼此之间共享一套资源 对此函数的使…

机器学习深度学学习分类模型中常用的评价指标总结记录与代码实现说明

在机器学习深度学习算法模型大开发过程中,免不了要对算法模型进行对应的评测分析,这里主要是总结记录分类任务中经常使用到的一些评价指标,并针对性地给出对应的代码实现,方便读者直接移植到自己的项目中。 【混淆矩阵】 混淆矩阵…

多模块下MyBatis官方生成器

MyBatis官方生成器 1. 依赖插件 创建generator模块 在父模块中进行版本管理 <dependency><groupId>org.mybatis</groupId><artifactId>mybatis</artifactId><version>3.5.6</version> </dependency><dependency><g…

Python练习题(三)

&#x1f4d1;前言 本文主要是【Python】——Python练习题的文章&#xff0c;如果有什么需要改进的地方还请大佬指出⛺️ &#x1f3ac;作者简介&#xff1a;大家好&#xff0c;我是听风与他&#x1f947; ☁️博客首页&#xff1a;CSDN主页听风与他 &#x1f304;每日一句&am…

[足式机器人]Part2 Dr. CAN学习笔记-Ch0-1矩阵的导数运算

本文仅供学习使用 本文参考&#xff1a; B站&#xff1a;DR_CAN Dr. CAN学习笔记-Ch0-1矩阵的导数运算 1. 标量向量方程对向量求导&#xff0c;分母布局&#xff0c;分子布局1.1 标量方程对向量的导数1.2 向量方程对向量的导数 2. 案例分析&#xff0c;线性回归3. 矩阵求导的链…

el-pagination 纯前端分页

需求&#xff1a;后端把所有数据都返给前端&#xff0c;前端进行分页渲染。 实现思路&#xff1a;先把数据存储到一个大数组中&#xff0c;然后调用方法进行切割。主要使用数组的slice方法 所有代码&#xff1a; html <template><div style"padding: 20px&qu…

CSS 局限-contain

CSS 局限 CSS 局限规范的目标在于通过允许浏览器从页面的其余部分中隔离出页面子树而改善性能。若浏览器知道页面的某一部分为独立的&#xff0c;则可优化渲染并改善性能。 此外&#xff0c;此规范允许开发者标示元素究竟是否应当渲染其内容&#xff0c;以及在屏外时是否应当…

【程序员 | 交流】程序员情商修炼指南系列 (沟通是有效合作一大利器)

&#x1f935;‍♂️ 个人主页: AI_magician &#x1f4e1;主页地址&#xff1a; 作者简介&#xff1a;CSDN内容合伙人&#xff0c;全栈领域优质创作者。 &#x1f468;‍&#x1f4bb;景愿&#xff1a;旨在于能和更多的热爱计算机的伙伴一起成长&#xff01;&#xff01;&…

如何能够对使用ShaderGraph开发的Shader使用SetTextureOffset和SetTextureScale方法

假设在ShaderGraph中的纹理的引用名称为"_BaseMap"&#xff0c;同时对这个"_BaseMap"纹理使用了采样的节点"SampleTexture2D"&#xff0c;然后该采样节点的uv接入的TilingAndOffset节点&#xff0c;此时的关键步骤是新建一个Vector4属性&#xf…

HT7183 高功率异步升压转换器 中文资料

HT7183是一款高功率异步升压转换器&#xff0c;集成120mΩ功率开关管&#xff0c;为便携式系统提供G效的小尺寸处理方案。HT7183具有2.6V至5.5V输入电压范围&#xff0c;可为各类不同供电的应用提供支持。HT7183具备3A开关电流能力&#xff0c;并且能够提供高达16V的输出电压。…

实时绘画迎来大更新,本地即可部署

个人网站&#xff1a;https://tianfeng.space 前言 自此LCM公布以来&#xff0c;这一个星期在相关应用方面的更新速度nb&#xff0c;各种实时绘画工作流随之出现&#xff0c;之前还只能依赖krea内测资格使用&#xff0c;让我们来看看上周发生了那些事吧&#xff01; 网盘&am…

做外贸如何写开发信?外贸邮件营销怎么写?

外贸业务员写开发信的技巧&#xff1f;撰写客户开发信模板详解&#xff01; 外贸经营是一项竞争激烈的行业&#xff0c;写好开发信是吸引客户、建立合作关系的重要一环。蜂邮EDM将为您详细介绍如何撰写出色的开发信&#xff0c;以吸引客户的眼球&#xff0c;引领他们与您建立联…

YOLOv8独家原创改进:创新自研CPMS注意力,多尺度通道注意力具+多尺度深度可分离卷积空间注意力,全面升级CBAM

💡💡💡本文自研创新改进:自研CPMS, 多尺度通道注意力具+多尺度深度可分离卷积空间注意力,全面升级CBAM 1)作为注意力CPMS使用; 推荐指数:五星 CPMS | 亲测在多个数据集能够实现涨点,对标CBAM。 收录 YOLOv8原创自研 https://blog.csdn.net/m0_63774211/ca…

PTA 6-1 最小生成树(普里姆算法)使用递归

普利姆算法的原理 普里姆算法查找最小生成树的过程&#xff0c;采用了贪心算法的思想。对于包含 N 个顶点的连通网&#xff0c;普里姆算法每次从连通网中找出一个权值最小的边&#xff0c;这样的操作重复 N-1 次&#xff0c;由 N-1 条权值最小的边组成的生成树就是最小生成树。…

HarmonyOS 振动效果开发指导

Vibrator 开发概述 振动器模块服务最大化开放硬工最新马达器件能力&#xff0c;通过拓展原生马达服务实现振动与交互融合设计&#xff0c;打造细腻精致的一体化振动体验和差异化体验&#xff0c;提升用户交互效率和易用性、提升用户体验、增强品牌竞争力。 运作机制 Vibrato…

数据结构与算法-动态查找表

查找 &#x1f388;3动态查找表&#x1f52d;3.1二叉排序树&#x1f3c6;3.1.1二叉排序树的类定义&#x1f3c6;3.1.2二叉排序树的插入和生成&#x1f3c6;3.1.3二叉树的查找&#x1f3c6;3.1.4二叉排序树的删除 &#x1f52d;3.2平衡二叉树&#x1f3c6;3.2.1平衡二叉树的调整…

基于jsp的搜索引擎

摘 要 随着互联网的不断发展和日益普及&#xff0c;网上的信息量在迅速地增长&#xff0c;在2004年4月&#xff0c;全球Web页面的数目已经超过40亿&#xff0c;中国的网页数估计也超过了3亿。 目前人们从网上获得信息的主要工具是浏览器&#xff0c;搜索引擎在网络中占有举足轻…

平价的开放式耳机怎么选?推荐几款平价好用的耳机,亲测对比

是不是也在为如何在有限的预算内找到一款性价比高的开放式耳机而烦恼呢&#xff1f;别着急&#xff0c;小编为你精心挑选了几款平价好用的开放式耳机&#xff0c;并亲自进行了对比测试&#xff0c;在这个音乐时代&#xff0c;不需要花大价钱就能拥有高品质的音乐体验&#xff0…

微积分-圆的面积和周长(1)

微积分 历史 先有牛顿后有天&#xff0c;创世之后再造仙。作为近代物理学的开山鼻祖&#xff0c;牛顿的贡献怎么评价都不为过。而微积分是首先被牛顿搞出来的也已经是公认的事实&#xff0c;牛顿在研究物理问题的时候顺带做出来的&#xff0c;不知是舍不得发表还是不屑于发表…