深度学习PyTorch 之 DNN-多分类

news2025/1/10 10:46:46

前面讲了深度学习&PyTorch 之 DNN-二分类,本节讲一下DNN多分类相关的内容,这里分三步进行演示

结构化数据

我们还是以iris数据集为例,因为这个与前面的流程完全一样,只有在模型定义时有些区别

  • 损失函数不一样
    二分类时用的损失函数是:loss_fn = nn.BCELoss()
    在多分类时需要使用: loss_fn = torch.nn.CrossEntropyLoss()

  • 输出类别不一样
    二分类输出时,需要使用sigmoid函数进行激活,x = torch.sigmoid(self.hidden3(x))
    多分类不需要使用激活函数,只需要输出全连接后的数据就可以

所以模型定义如下

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_1 = nn.Linear(4, 120)
        self.linear_2 = nn.Linear(120, 84)
        self.linear_3 = nn.Linear(84, 4)
    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.linear_1(x))
        x = torch.relu(self.linear_2(x))
        logits = self.linear_3(x)
        return logits    # 未激活的输出,叫做logits

训练与之前一样,就不写了

重点讲非结构化数据,图片

图片多分类 - Minst

原理部分

MNIST数据集是由0〜9手写数字图片和数字标签所组成的,由60000个训练样本和10000个测试样本组成,每个样本都是一张28 * 28像素的灰度手写数字图片,如下
在这里插入图片描述

我们其实就是识别图片中的数字,没一个数字图片其实是有一个一个像素组成的,
在这里插入图片描述
神经元装着数字代表对应像素的灰度值,0代表纯黑色,1代表纯白像素
我们要想训练这些像素点,需要将像素进行重组,就是将这些像素重新排列,将每一行的像素首尾相连,最终连接成一个长串,因为一行有28个像素点,一共28行,即最终有28*28个特征
将这些转换好的数据带入到模型中进行训练。
在这里插入图片描述

代码部分

1.数据准备

数据准备直接包含了数据导入+数据拆分+ToTensor

  • 导入有两种方式,一种是使用下载到本地的数据集,另一种是使用torchvision直接在线下载,速度还是比较快的
  • 训练和测试数据分别导入
  • 在导入时,可以设置transform=ToTensor(),进行转换
train_ds_m = torchvision.datasets.MNIST('data',
                                      train=True,
                                      transform=ToTensor(),
                                      download=True)
test_ds_m = torchvision.datasets.MNIST('data',
                                      train=False,
                                      transform=ToTensor(),
                                      download=True)

2. 数据重构

这里与之前一直,不多说

train_dl_m = torch.utils.data.DataLoader(train_ds_m, 
                                       batch_size=64,
                                       shuffle=True)
test_dl_m = torch.utils.data.DataLoader(test_ds_m, 
                                       batch_size=64)

3. 数据查看

因为是图片数据,我们加一步数据查看,看一下导入的数据格式,加深我们的理解

idx_to_class = dict((v, k) for  k, v in train_ds_m.class_to_idx.items())
idx_to_class
#label格式
{0: '0 - zero',
 1: '1 - one',
 2: '2 - two',
 3: '3 - three',
 4: '4 - four',
 5: '5 - five',
 6: '6 - six',
 7: '7 - seven',
 8: '8 - eight',
 9: '9 - nine'}

dataloader本质上是一个可迭代对象,可以使用iter()进行访问,采用iter(dataloader)返回的是一个迭代器,然后可以使用next()访问。

imgs, labels = next(iter(train_dl_m))
imgs.shape
#torch.Size([64, 1, 28, 28])

我们可以看到,ims的数据格式是64,1,28,27

  • 64是我们定义的batch_size = 64
  • 1是指通道数,这里是黑白图片,所以通道数是1;如果是彩色图片通道数应该是3,即RGB三个通道
  • 28*28,就是图片的大小,我们前面原理部分说过了

图片展示一下

plt.figure(figsize=(16, 6))
for i,(img, label) in enumerate(zip(imgs[:16],labels[:16])):
    img = (img.permute(1,2,0).numpy() + 1)/2
    plt.subplot(2, 8, i+1)
    plt.title(idx_to_class.get(label.item()))
    plt.imshow(img)

在这里插入图片描述

4. 定义模型

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_1 = nn.Linear(28*28, 120)
        self.linear_2 = nn.Linear(120, 84)
        self.linear_3 = nn.Linear(84, 10)
    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.linear_1(x))
        x = torch.relu(self.linear_2(x))
        logits = self.linear_3(x)
        return logits    # 未激活的输出,叫做logits

这个跟之前一致,只是需要注意两点

  • x.view(x.size(0), -1),将数据拉平,就是将图片连城一个长串,28*28的格式
  • 最终输出的种类需要与我们预测的label类别数一致
model = Model()
loss_fn = torch.nn.CrossEntropyLoss()
opt = torch.optim.SGD(model.parameters(), lr=0.001)
model

Model(
(linear_1): Linear(in_features=784, out_features=120, bias=True)
(linear_2): Linear(in_features=120, out_features=84, bias=True)
(linear_3): Linear(in_features=84, out_features=10, bias=True)
)

5. 训练及查看

epochs = 100
train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
    epoch_acc, epoch_loss = train(train_dl_m, model, loss_fn, opt)
    epoch_test_acc, epoch_test_loss = test(test_dl_m, model, loss_fn)
    if epoch%10==0:
        train_acc.append(epoch_acc)
        train_loss.append(epoch_loss)
        test_acc.append(epoch_test_acc)
        test_loss.append(epoch_test_loss)

        template = ("epoch:{:2d}, 训练损失:{:.5f}, 训练准确率:{:.1f},验证损失:{:.5f}, 验证准确率:{:.1f}")
        print(template.format(epoch, epoch_loss, epoch_acc*100, epoch_test_loss, epoch_test_acc*100))
print('Done')

训练函数和测试函数与之前一致

epoch: 0, 训练损失:2.28837, 训练准确率:17.5,验证损失:2.27378, 验证准确率:20.1
epoch:10, 训练损失:0.72655, 训练准确率:81.6,验证损失:0.66988, 验证准确率:83.0
epoch:20, 训练损失:0.42698, 训练准确率:88.2,验证损失:0.40460, 验证准确率:88.7
epoch:30, 训练损失:0.35895, 训练准确率:89.9,验证损失:0.34320, 验证准确率:90.0
epoch:40, 训练损失:0.32467, 训练准确率:90.8,验证损失:0.31260, 验证准确率:90.9
epoch:50, 训练损失:0.30045, 训练准确率:91.4,验证损失:0.29069, 验证准确率:91.5
epoch:60, 训练损失:0.28021, 训练准确率:92.0,验证损失:0.27321, 验证准确率:92.1
epoch:70, 训练损失:0.26272, 训练准确率:92.6,验证损失:0.25764, 验证准确率:92.6
epoch:80, 训练损失:0.24705, 训练准确率:93.0,验证损失:0.24356, 验证准确率:93.0
epoch:90, 训练损失:0.23274, 训练准确率:93.4,验证损失:0.23069, 验证准确率:93.4
Done

查看损失值和准确率的变化

import matplotlib.pyplot as plt

plt.plot(range(len(train_loss)), train_loss, label='train_loss')
plt.plot(range(len(test_loss)), test_loss, label='test_loss')
plt.legend()

plt.plot(range(len(train_acc)), train_acc, label='train_acc')
plt.plot(range(len(test_acc)), test_acc, label='test_acc')
plt.legend()

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/170440.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Pollard Rho算法

生日悖论 假设一年有nnn天,房间中有kkk人,每个人的生日在这nnn天中,服从均匀分布,两个人的生日相互独立 问至少要有多少人,才能使其中两个人生日相同的概率达到ppp 解:考虑k≤nk\le nk≤n 设kkk个人生日互…

Spring框架介绍及使用

文章目录1.概述1.1 Spring是什么1.2 Spring 的优势1.3 spring 的体系结构2. IoC 的概念和作用2.1 什么是程序的耦合2.2 IoC容器3. AOP的概念和作用超链接: Spring重点内容学习资料1.概述 1.1 Spring是什么 Spring 是分层的 Java SE/EE 应用 full-stack 轻量级开源…

使用docker-compose搭建Prometheus+Grafana监控系统

一、角色分配 Prometheus 采集数据Grafana 用于图表展示redis_exporter 用于收集redis的metricsnode-exporter 用于收集操作系统和硬件信息的metricscadvisor 用于收集docker的相关metrics 二、安装Docker 可以参考:https://ximeneschen.blog.csdn.net/article/d…

JVM调优实战:to-space exhausted Evacuation Failure

一次线上dubbo问题的定位,进行JVM调优实战。问题线上dubbo接口provider抛出异常:org.apache.dubbo.rpc.RpcException: Failfast invoke providers ... RandomLoadBalance select from all providers ... use dubbo version 2.7.16, but no luck to perfo…

vulnhub DC系列 DC-8

总结:exim4提权 目录 下载地址 漏洞分析 信息收集 网站爆破 后台webshell 提权 下载地址 DC-8.zip (Size: 379 MB)Download: http://www.five86.com/downloads/DC-8.zipDownload (Mirror): https://download.vulnhub.com/dc/DC-8.zip使用方法:解压后&#xff…

Cosmos 基础(二)-- Ignite CLI

官网 DOC GitHub 你的项目值得拥有自己的区块链。 Ignite使开发、增长和启动区块链项目比以往任何时候都更快。 Ignite CLI是一个一体化平台,可以在主权和安全的区块链上构建、启动和维护任何加密应用程序 Install Ignite 一、安装 你可以在基于web的Gitpod…

23种设计模式(七)——桥接模式【单一职责】

文章目录 意图什么时候使用桥接真实世界类比桥接模式的实现桥接模式的优缺点亦称:Bridge 意图 桥接模式是将抽象部分与实现部分分离,使它们都可以独立地变化。它是一种对象结构型模式,又称为柄体(Handle and Body)模式或接口(Interfce)模式。 什么时候使用桥接 1、如果一个…

详解MySQL数据库索引实现机制 - B树和B+树

详解MySQL数据库索引实现机制 - B树和B树1.索引的出现2.hash算法的缺点3.二叉排序树BST4.平衡二叉树AVL5.红黑树6.B树诞生了7.B树1.索引的出现 索引是一种用于快速查询和检索数据的数据结构,其本质可以看成是一种排序好的数据结构。 索引的作用就相当于书的目录。…

(Netty)Handler Pipeline

Handler & Pipeline ChannelHandler 用来处理 Channel 上的各种事件,分为入站、出站两种。所有 ChannelHandler 被连成一串,就是 Pipeline 入站处理器通常是 ChannelInboundHandlerAdapter 的子类,主要用来读取客户端数据,写…

【嵌入式处理器】CPU、MPU、MCU、DSP、SoC、SiP的联系与区别

1、CPU(Central Processing Unit) CPU(Central Processing Unit),是一台计算机的运算核心和控制核心。CPU由运算器、控制器和寄存器及实现它们之间联系的数据、控制及状态的总线构成。众所周知的三级流水线:取址、译码、执行的对象就是CPU,差…

重学Android之View——TabLayoutMediator解析

重学Android之View——TabLayoutMediator解析 1.前言 在使用TabLayoutViewPager2Fragment的时候,查询别人的使用例子,看到了 TabLayoutMediator这个类,撰写此文,仅当学习思考,本文是在引用material:1.7.0的版本基础…

记2022年秋招经历

自我介绍求职体验求职心得 一、自我介绍 学历普通本科,专业是网络工程,在校期间学习主要的是计算机体系方面的知识,根据课程,自学过前端、后端等内容。包括前端三板斧(htmlcssjs)、常用的前端框架(bootstarp/Vue等)&am…

Android项目接入React Native方案

本篇文章主要介绍在现有的Android项目中接入React Native的接入过程,分析接入过程中的一些问题和解决方案,接入RN的平台为Android,开发环境为Mac,开发工具为Android Studio。 一、环境配置 1、Android配置 因为是现有的Android项…

Vue实现DOM元素拖放互换位置

一、拖放和释放HTML 拖放接口使得 web 应用能够在网页中拖放文件。这里将介绍了 web 应用如何接受从底层平台的文件管理器拖动DOM的操作。拖放的主要步骤是为 drop 事件定义一个释放区(释放文件的目标元素) 和为dragover事件定义一个事件处理程序。触发 drop 事件的目标元素需要…

day20IO流

1.字符流 1.1为什么会出现字符流【理解】 字符流的介绍 由于字节流操作中文不是特别的方便,所以Java就提供字符流 字符流 字节流 编码表 中文的字节存储方式 用字节流复制文本文件时,文本文件也会有中文,但是没有问题,原因是最…

数学建模-分类模型(SPSS)

目录 1.简介 2.样例-二元 1.对于预测结果不理想,在logistics模型里加入平方项交互项等。 2.如果自变量有分类变量(如男女,行业有互联网行业、旅游行业……) 3.分训练集、测试集 4.fisher线性判别分析 3.样例-多元 注意&…

【Nginx】使用Docker完成Nginx反向代理

本机是在CentOS7上面进行操作的 1.首先安装好Dokcer,这里不再赘述 2.Docker安装Nginx容器 2.1首先需要创建Nginx配置文件,之后完成挂载 启动前需要先创建Nginx外部挂载的配置文件( /home/nginx/conf/nginx.conf) 之所以要先创建…

Redis - Redis 6.0 新特性之客户端缓存

1. 为什么需要客户端缓存 antirez 写了一篇有关客户端缓存设计的想法:《Client side caching in Redis 6》。antirez 认为,Redis 接下来的一个重点是配合客户端,因为客户端缓存显而易见的可以减轻 Redis 的压力,速度也快很多。实…

Android从开机到APP启动流程——基于Android9.0

Android从开机到APP启动流程——基于Android9.0 一、 Zygote进程启动流程 二、 System Server启动流程 三、 ActivityManagerService启动流程 四、 Launcher App (Home Activity)启动流程 五、 Zygote fork()子进程,子进程入口为ActivityThread.main() 六、 Acti…

第02讲:使用kubeadm搭建k8s集群的准备工作

官方地址:https://kubernetes.io/docs/reference/setup-tools/kubeadm/kubeadm/ kubeadm 是官方社区推出的一个用于快速部署 kubernetes 集群的工具,这个工具能通过两条指令完成一个 kubernetes 集群的部署: 第1步、创建一个 Master 节点 kubeadm init第2步&#x…