使用PyTorch解决多分类问题:构建、训练和评估深度学习模型

news2025/1/12 13:17:23

💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢迎在文章下方留下你的评论和反馈。我期待着与你分享知识、互相学习和建立一个积极的社区。谢谢你的光临,让我们一起踏上这个知识之旅!
请添加图片描述

文章目录

  • 🍋引言
  • 🍋什么是多分类问题?
  • 🍋处理步骤
  • 🍋多分类问题
  • 🍋MNIST dataset的实现
  • 🍋NLLLoss 和 CrossEntropyLoss

🍋引言

当处理多分类问题时,PyTorch是一种非常有用的深度学习框架。在这篇博客中,我们将讨论如何使用PyTorch来解决多分类问题。我们将介绍多分类问题的基本概念,构建一个简单的多分类神经网络模型,并演示如何准备数据、训练模型和评估结果。

🍋什么是多分类问题?

多分类问题是一种机器学习任务,其中目标是将输入数据分为多个不同的类别或标签。与二分类问题不同,多分类问题涉及到三个或更多类别的分类任务。例如,图像分类问题可以将图像分为不同的类别,如猫、狗、鸟等。

🍋处理步骤

  • 准备数据
    收集和准备数据集,确保每个样本都有相应的标签,以指明其所属类别。
    划分数据集为训练集、验证集和测试集,以便进行模型训练、调优和性能评估。

  • 数据预处理
    对数据进行预处理,例如归一化、标准化、缺失值处理或数据增强,以确保模型训练的稳定性和性能。

  • 选择模型架构
    选择适当的深度学习模型架构,通常包括卷积神经网络(CNN)、循环神经网络(RNN)、Transformer等,具体取决于问题的性质。

  • 定义损失函数
    为多分类问题选择适当的损失函数,通常是交叉熵损失(Cross-Entropy Loss)。

  • 选择优化器
    选择合适的优化算法,如随机梯度下降(SGD)、Adam、RMSprop等,以训练模型并调整权重。

  • 训练模型
    使用训练数据集来训练模型。在每个训练迭代中,通过前向传播和反向传播来更新模型参数,以减小损失函数的值。

  • 评估模型
    使用验证集来评估模型性能。常见的性能指标包括准确性、精确度、召回率、F1分数等。

  • 调优模型
    根据验证集的性能,对模型进行调优,可以尝试不同的超参数设置、模型架构变化或数据增强策略。

  • 测试模型
    最终,在独立的测试数据集上评估模型的性能,以获得最终性能评估。

  • 部署模型
    将训练好的模型部署到实际应用中,用于实时或批处理多分类任务。

🍋多分类问题

之前我们讨论的问题都是二分类居多,对于二分类问题,我们若求得p(0),南无p(1)=1-p(0),还是比较容易的,但是本节我们将引入多分类,那么我们所求得就转化为p(i)(i=1,2,3,4…),同时我们需要满足以上概率中每一个都大于0;且总和为1

处理多分类问题,这里我们新引入了一个称为Softmax Layer
在这里插入图片描述

接下来我们一起讨论一下Softmax Layer层
在这里插入图片描述
首先我们计算指数计算e的zi次幂,原因很简单e的指数函数恒大于0;分母就是e的z1次幂+e的z2次幂+e的z3次幂…求和,这样所有的概率和就为1了。


下图形象的展示了Softmax,Exponent这里指指数,和上面我们说的一样,先求指数,这样有了分子,再将所有指数求和,最后一一divide,得到了每一个概率。
在这里插入图片描述


接下来我们一起来看看损失函数
在这里插入图片描述
如果使用numpy进行实现,根据刘二大人的代码,可以进行如下的实现

import numpy as np
y = np.array([1,0,0])
z = np.array([0.2,0.1,-0.1])
y_pred = np.exp(z)/np.exp(z).sum()
loss = (-y * np.log(y_pred)).sum()
print(loss)

运行结果如下
在这里插入图片描述
注意:神经网络的最后一层不需要激活


在pytorch中

import torch
y = torch.LongTensor([0])  # 长整型
z = torch.Tensor([[0.2, 0.1, -0.1]])
criterion = torch.nn.CrossEntropyLoss() 
loss = criterion(z, y)
print(loss)

运行结果如下

在这里插入图片描述
下面根据一个例子进行演示

criterion = torch.nn.CrossEntropyLoss()
Y = torch.LongTensor([2,0,1])
Y_pred1 = torch.Tensor([[0.1, 0.2, 0.9], 
                        [1.1, 0.1, 0.2], 
                        [0.2, 2.1, 0.1]]) 
Y_pred2 = torch.Tensor([[0.8, 0.2, 0.3], 
                        [0.2, 0.3, 0.5], 
                        [0.2, 0.2, 0.5]])
l1 = criterion(Y_pred1, Y)
l2 = criterion(Y_pred2, Y)
print("Batch Loss1 = ", l1.data, "\nBatch Loss2=", l2.data)

运行结果如下
在这里插入图片描述

根据上面的代码可以看出第一个损失比第二个损失要小。原因很简单,想对于Y_pred1每一个预测的分类与Y是一致的,而Y_pred2则相差了一下,所以损失自然就大了些

🍋MNIST dataset的实现

首先第一步还是导包

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader 
import torch.nn.functional as F 
import torch.optim as optim

之后是数据的准备

batch_size = 64
# transform可以将其转化为0-1,形状的转换从28×28转换为,1×28×28
transform = transforms.Compose([
		transforms.ToTensor(),
		transforms.Normalize((0.1307, ), (0.3081, ))   # 均值mean和标准差std
])
train_dataset = datasets.MNIST(root='../dataset/mnist/', 
								train=True,
								download=True,
								transform=transform)  
train_loader = DataLoader(train_dataset,
							shuffle=True,
							batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist/', 
							train=False,
							download=True,
							transform=transform)
test_loader = DataLoader(test_dataset,
						shuffle=False,
						batch_size=batch_size)

在这里插入图片描述
接下来我们构建网络

class Net(torch.nn.Module):
	def __init__(self):
		super(Net, self).__init__()
		self.l1 = torch.nn.Linear(784, 512) 
		self.l2 = torch.nn.Linear(512, 256) 
		self.l3 = torch.nn.Linear(256, 128) 
		self.l4 = torch.nn.Linear(128, 64) 
		self.l5 = torch.nn.Linear(64, 10)
	def forward(self, x):
		x = x.view(-1, 784)
		x = F.relu(self.l1(x)) 
		x = F.relu(self.l2(x)) 
		x = F.relu(self.l3(x)) 
		x = F.relu(self.l4(x)) 
		return self.l5(x)  # 注意最后一层不做激活
model = Net()

之后定义损失和优化器

criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

接下来就进行训练了

def train(epoch):
	running_loss = 0.0
	for batch_idx, data in enumerate(train_loader, 0): 
		inputs, target = data
		optimizer.zero_grad()
		# forward + backward + update
		outputs = model(inputs)
		loss = criterion(outputs, target)
		loss.backward()
		optimizer.step()
		running_loss += loss.item()
	if batch_idx % 300 == 299:
		print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300)) 
		running_loss = 0.0
def test():
	correct = 0
	total = 0
	with torch.no_grad(): # 这里可以防止内嵌代码不会执行梯度
		for data in test_loader:
			images, labels = data
			outputs = model(images)
			_, predicted = torch.max(outputs.data, dim=1)
			total += labels.size(0)
			correct += (predicted == labels).sum().item()
	print('Accuracy on test set: %d %%' % (100 * correct / total))

最后调用执行

if __name__ == '__main__': 
	for epoch in range(10): 
		train(epoch)
		test()

🍋NLLLoss 和 CrossEntropyLoss

NLLLoss 和 CrossEntropyLoss(也称为交叉熵损失)是深度学习中常用的两种损失函数,用于测量模型的输出与真实标签之间的差距,通常用于分类任务。它们有一些相似之处,但也有一些不同之处。

相同点:

用途:两者都用于分类任务,评估模型的输出和真实标签之间的差异,以便进行模型的训练和优化。
数学基础:NLLLoss 和 CrossEntropyLoss 本质上都是交叉熵损失的不同变种,它们都以信息论的概念为基础,衡量两个概率分布之间的相似度。
输入格式:它们通常期望模型的输出是一个概率分布,表示各个类别的预测概率,以及真实的标签。

不同点:

输入格式:NLLLoss 通常期望输入是对数概率(log probabilities),而 CrossEntropyLoss 通常期望输入是未经对数化的概率。在实际应用中,CrossEntropyLoss 通常与softmax操作结合使用,将原始模型输出转化为概率分布,而NLLLoss可以直接使用对数概率。
对数化:NLLLoss 要求将模型输出的概率经过对数化(取对数)以获得对数概率,然后与真实标签的离散概率分布进行比较。CrossEntropyLoss 通常在 softmax 操作之后直接使用未对数化的概率值与真实标签比较。
输出维度:NLLLoss 更通用,可以用于多种情况,包括多类别分类和序列生成等任务,因此需要更多的灵活性。CrossEntropyLoss 通常用于多类别分类任务。

总之,NLLLoss 和 CrossEntropyLoss 都用于分类任务,但它们在输入格式和使用上存在一些差异。通常,选择哪个损失函数取决于你的模型输出的格式以及任务的性质。如果你的模型输出已经是对数概率形式,通常使用NLLLoss,否则通常使用CrossEntropyLoss

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1099727.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【学习笔记】RabbitMQ02:交换机,以及结合springboot快速开始

参考资料 RabbitMQ官方网站RabbitMQ官方文档噼咔噼咔-动力节点教程 文章目录 四、RabbitMQ :Exchange 交换机4.1 交换机类型4.2 扇形交换机 Fanout Exchange4.2.1 概念4.2.1 实例:生产者4.2.1.1 添加起步依赖4.2.1.2 配置文件4.2.1.3 JavaBean进行配置4.…

NGF ; -R : Trk NTRK

NTRK基因融合的机制与靶向治疗 - 知乎 【NTRK基因】共识已发布,24款获证,2款NGS产品已布局

开关电源测试解决方案之浪涌电流测试 -纳米软件

浪涌电流是由于开关、断电等引起的突发电流,这种电流会直接影响到设备的寿命和可靠性,因此浪涌电流测试是开关电源测试的一个重要项目。本文将会介绍浪涌电流测试要求以及测试方法。 什么是浪涌电流? 浪涌电流是指电源接通瞬间,流入电源设备…

Python中兔子递归函数的例子

本文将详细介绍Python中兔子递归函数的例子,展示递归函数的基本实现方法及其原理。 一、递归函数的概念 递归函数是指在函数内部调用自身的函数。通过递归函数,可以将复杂问题分解成简单的子问题来解决。 这种过程是有限的,当子问题足够小…

分布式存储系统——ceph

目录 一、分布式存储类型 1、块存储 2、文件存储 3、对象存储 二、ceph概述 1、ceph简介 2、ceph的优势 3、ceph架构 1)RADOS 基础存储系统(Reliab1e,Autonomic,Distributed object store 2)LIBRADOS 基础库 …

Linux 下I/O操作

一、文件IO 文件 IO 是 Linux 系统提供的接口,针对文件和磁盘进行操作,不带缓存机制;标准IO是C 语言函数库里的标准 I/O 模型,在 stdio.h 中定义,通过缓冲区操作文件,带缓存机制。   标准 IO 和文件 IO 常…

论文阅读 Memory Enhanced Global-Local Aggregation for Video Object Detection

Memory Enhanced Global-Local Aggregation for Video Object Detection Abstract 人类如何识别视频中的物体?由于单一帧的质量低下,仅仅利用一帧图像内的信息可能很难让人们在这一帧中识别被遮挡的物体。我们认为人们识别视频中的物体有两个重要线索&…

怎么使用动态代理IP提升网络安全,动态代理IP有哪些好处呢

随着互联网的普及和数字化时代的到来,网络安全问题越来越受到人们的关注。动态代理IP作为网络安全中的一种技术手段,被越来越多的人所采用。本文将介绍动态代理IP的概念、优势以及如何应用它来提升网络安全。 一、动态代理IP的概念 动态代理IP是指使用代…

nginx.4——正向代理和反向代理(七层代理和四层代理)

1、正向代理反向代理 nginx当中有两种代理方式 七层代理(http协议) 四层代理(tcp/udp流量转发) 七层代理 七层代理:代理的是http的请求和响应。 客户端请求代理服务器,由代理服务器转发给客户端http请求。转发到内部服务器(可以单台&#…

Mac OS m1 下安装Gradle4.8.1

1. 下载、解压 1.1 下载地址 https://gradle.org 往下翻 或者选择 任何 你想要的版本 ,点击 binary-only 即可下载 . 1.2 解压到指定目录 2. 配置环境变量 2.1 编辑环境文件 vi ~/.bash_profile #GRADLE相关配置 GRADLE_HOME/Users/zxj/Documents/devSoft/gradle-4.8.1 e…

视频集中存储/视频监控管理平台EasyCVR如何免密登录系统?详细操作如下

视频云存储/安防监控EasyCVR视频汇聚平台基于云边端智能协同,支持海量视频的轻量化接入与汇聚、转码与处理、全网智能分发、视频集中存储等。音视频流媒体视频平台EasyCVR拓展性强,视频能力丰富,具体可实现视频监控直播、视频轮播、视频录像、…

【Mac】时间机器频繁提示磁盘没有正常推出

问题描述 有一次在进行时间机器备份的时候总是提示“磁盘没有正常推出”,并且好几次直接导致系统重启… 估计是 MacOS 系统 bug 解决 看了 Vex 一个帖子之后设置了一个硬盘是否休眠就好了,不要勾选让硬盘处于休眠就可以了,在电池选项界面中…

Sonar跨服务扫描Jenkins Job

需求:验收环境新增sonar扫描,(困难点:验收环境jenkins和开发环境sonar不通,验收环境没有办法重新安装Sonar,数据库等服务。 解决方法:验收Jenkins和开发环境jenkins做免密,通过修改验收环境jenk…

基于 nodejs+vue网上考勤系统

目 录 摘 要 I ABSTRACT II 目 录 II 第1章 绪论 1 1.1背景及意义 1 1.2 国内外研究概况 1 1.3 研究的内容 1 第2章 相关技术 3 2.1 nodejs简介 4 2.2 express框架介绍 6 2.4 MySQL数据库 4 第3章 系统分析 5 3.1 需求分析 5 3.2 系统可行性分析 5 3.2.1技术可行性:…

VSCode连接代理

VSCode连接代理 首先有代理 然后在设置里搜代理 然后再在windows的设置–>网络–>代理 拼接上就行 最后重启

问题解决:使用Mybatis Plus的Mapper插入达梦数据库报“数据溢出”错误

前言 使用Mybatis Plus的Mapper插入达梦数据库报“数据溢出”错误 文章目录 前言问题描述错误日志输出排查过程最终解决办法 问题描述 在进行批量插入中,抛出异常为数据溢出 插入方法:this.baseMapper.insertBatchSomeColumn()抛出异常:数据…

Hadoop分布式文件系统——HDFS

1.介绍 HDFS (Hadoop Distributed File System)是 Hadoop 下的分布式文件系统,具有高容错、高吞吐量等特性,可以部署在低成本的硬件上。 2.HDFS 设计原理 2.1 HDFS 架构 HDFS 遵循主/从架构,由单个 NameNode(NN) 和多个 DataNode(DN) 组成:

C++【命名空间详解】

系列文章目录 本篇文章呢主要说的是C中关于命名空间相关知识点哦~ 文章目录 一、C简单介绍二、详解 1.命名空间2.输入输出总结 一、C简单介绍 前面我们学习的C语言,它是结构化和模块化的语言,适合处理较小规模的程序。那么,对于一些复杂的问…

汉得欧洲x甄知科技 | 携手共拓全球化布局,助力出海中企数智化发展

HAND Europe 荣幸获得华为云颁发的 GrowCloud 合作伙伴奖项,进一步巩固了其在企业数字化领域的重要地位。于 2023 年 10 月 5 日,HAND Europe 参加了华为云荷比卢峰会,并因其在全球拓展方面的杰出贡献而荣获 GrowCloud 合作伙伴奖项的认可。 …

华为eNSP配置专题-ACL的配置

文章目录 华为eNSP配置专题-ACL的配置1、前置环境1.1、宿主机1.2、eNSP模拟器 2、基本环境搭建2.1、基本终端构成和连接2.2、各终端基本配置2.2.1、PC1和PC2的配置2.2.2、模拟互联网的路由器的配置2.2.3、财务部服务器的配置2.2.4、路由器AR1的配置 2.3、让各网段能够ping通互联…