神经网络和深度学习-多分类问题Softmax Classifier

news2025/1/13 13:58:31

多分类问题Softmax Classifier

在之前做糖尿病数据集的时候做的二分类问题,因为只有两类,所以只需要输出一个概率值,另一个概率值用1去减去就可以得到

在这里插入图片描述

实际上在大多数数据集中是在处理一个多分类问题,例如MNIST中有10类标签

在这里插入图片描述

神经网络如何设计

我们在输出的时候,在原来只有p(y=1)的输出变为10个输出,这样就可以输出每一个样本属于每一个分类的概率,可能出现大多数分类都是高概率,这其中肯定是矛盾的,希望在输出的分类的概率需要满足一个分布的要求,满足离散分类

  • 全部概率>0

  • 总概率 = 1

在这里插入图片描述

Softmax Layer

在处理多分类问题的时候,在前面的神经网络还是采用Sigmoid Layer,在最终输出层中我们使用Softmax Layer

在这里插入图片描述

下面我们就要针对分布要求,看看Softmax Layer是如何设计的

P ( y = i ) = e z i ∑ j = 0 K − 1 e z j , i ∈ { 0 , … , K − 1 } P(y=i)=\frac{e^{z_{i}}}{\sum_{j=0}^{K-1} e^{z_{j}}}, i \in\{0, \ldots, K-1\} P(y=i)=j=0K1ezjezi,i{0,,K1}

  • 分子部分,我们可以满足全部概率>0

  • 分母部分,满足总概率 = 1

假设我们有三个分类,经过线性之后我们有了三个输出值(0.2,0.1,-0.1),之后经过exp、sum、divide这三步,最终得到这三个类的概率y hat

在这里插入图片描述

Loss function

我们使用one-hot独热编码来解决多分类这个问题,只会保存Y=1的项

Loss ⁡ ( Y ^ , Y ) = − Y log ⁡ Y ^ \operatorname{Loss}(\hat{Y}, Y)=-Y \log \hat{Y} Loss(Y^,Y)=YlogY^

在torch中有这么一个损失,NLLLoss(Negative Log Likelihood Loss),这个函数的功能是:Y输入的就是标签号

在Numpy中的Cross Entropy

在这里插入图片描述

在PyTorch中的Cross Entropy,提供了交叉熵损失这个函数

在这里插入图片描述

我们来看一个具体的例子,加入有三个样本,分别属于(2,0,1)类

第一个预测中Y_pred1对应得分类都比较准确,所以损失会小

第二个预测中Y_pred2对应得分类都不准确,所以损失会比较大

在这里插入图片描述

交叉熵损失和NLL损失之间的关系

在这里插入图片描述

MNIST Dataset

在数据集中一个图像是28*28=784个像素点组成的,每一个像素点的取值是0-255。

做一个线性映射到0-1的区间,我们可以看到在矩阵中就表示了图的形状

在这里插入图片描述

多分类实现MNIST Dataset

按照四个步骤,最后要加上测试集

在这里插入图片描述

工具包部分

其中用到的transforms针对图像进行一些处理,还用到了relu激活函数所需要用到torch.nn.functional

在这里插入图片描述

Prepare Dataset

用transform把原始PIL的图像转换为Tensor的图像格式

在这里插入图片描述

这个过程就可以用transforms中的ToTensor来实现

在这里插入图片描述

其中Normalize中第一个(0.1307,)就是求mean,第二个(0.3081,)就是std标准化,这两个值是在计算了整个数据集的mean和std得到的结果,所用到的归一化方程如下

P i x e l norm  =  Pixel  origin  −  mean   std  Pixel _{\text {norm }}=\frac{\text { Pixel }_{\text {origin }}-\text { mean }}{\text { std }} Pixelnorm = std  Pixel origin  mean 

Design Model

输入图像为(N,1,28,28)其中有N个样本。

第一步就是把(1,28,28)这个三阶张量变成向量,用view函数来改变张量的形状(-1,784)二阶张量,第一个值是-1代表自动去算它的值是多少,比如N为64,则把-1 变为64,第二个值是图像的像素点,最后我们拿到的是N*784的矩阵,经过一系列的层,输出层得到(N,10),10个类

在这里插入图片描述

我们来看一下代码

在这里插入图片描述

Construct Loss and Optimizer

在criterion中我们使用交叉熵损失CrossEntropyLoss

在optimizer中我们使用更好的优化方法,带有冲量momentum(相当于是赋予梯度惯性,让它尽可能跳出局部最低点),设置为0.5来优化训练过程

在这里插入图片描述

Train and Test

将epoch封装在train函数中,输出300次迭代输出一次损失

在这里插入图片描述

优化器在优化之前就选择清零

在test函数中,我们只需要计算前向传播

我们在做完预测得到输出矩阵,每个样本都有一行,一行有10个量,我们要求最大值的下标是多少,使用torch.max,dim=1指的是行数,反之为0,指列数。返回的值是两个,每一行的最大值和每一行最大值的下标

最后可以计算准确率=正确数/总数

在这里插入图片描述

在训练的过程中只需要调用函数即可,也可以每十轮输出一次测试

在这里插入图片描述

在输出中我们可以看到loss在减少,accuracy在上升,但可能会存在极限

在这里插入图片描述

完整代码

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim

# prepare dataset

batch_size = 64
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])  # 归一化,均值和方差
train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)


# design model using class

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.l1 = torch.nn.Linear(784, 512)
        self.l2 = torch.nn.Linear(512, 256)
        self.l3 = torch.nn.Linear(256, 128)
        self.l4 = torch.nn.Linear(128, 64)
        self.l5 = torch.nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 784)  # -1其实就是自动获取mini_batch
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = F.relu(self.l3(x))
        x = F.relu(self.l4(x))
        return self.l5(x)  # 最后一层不做激活,不进行非线性变换


model = Net()

# construct loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)


# training cycle forward, backward, update

def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        # 获得一个批次的数据和标签
        inputs, target = data
        optimizer.zero_grad()
        # 获得模型预测结果(64, 10)
        outputs = model(inputs)
        # 交叉熵代价函数outputs(64,10),target(64)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 300 == 299:
            print('[%d,%5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0


def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)  # dim = 1 列是第0个维度,行是第1个维度
            total += labels.size(0)
            correct += (predicted == labels).sum().item()  # 张量之间的比较运算
    print('accuracy on test set: %d %% ' % (100 * correct / total))


if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()

运行结果:

[1,  300] loss: 2.161
[1,  600] loss: 0.838
[1,  900] loss: 0.429
accuracy on test set: 89 % 
[2,  300] loss: 0.329
[2,  600] loss: 0.269
[2,  900] loss: 0.237
accuracy on test set: 93 % 
[3,  300] loss: 0.192
[3,  600] loss: 0.179
[3,  900] loss: 0.157
accuracy on test set: 95 % 
[4,  300] loss: 0.139
[4,  600] loss: 0.126
[4,  900] loss: 0.119
accuracy on test set: 96 % 
[5,  300] loss: 0.096
[5,  600] loss: 0.098
[5,  900] loss: 0.101
accuracy on test set: 96 % 
[6,  300] loss: 0.080
[6,  600] loss: 0.077
[6,  900] loss: 0.078
accuracy on test set: 97 % 
[7,  300] loss: 0.063
[7,  600] loss: 0.064
[7,  900] loss: 0.064
accuracy on test set: 97 % 
[8,  300] loss: 0.051
[8,  600] loss: 0.058
[8,  900] loss: 0.048
accuracy on test set: 97 % 
[9,  300] loss: 0.041
[9,  600] loss: 0.044
[9,  900] loss: 0.045
accuracy on test set: 97 % 
[10,  300] loss: 0.033
[10,  600] loss: 0.036
[10,  900] loss: 0.036
accuracy on test set: 97 % 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/45259.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Hifiasm-meta | 你没看错!基于宏基因组的完成图!!

哈佛大学医学院Dana-Farber癌症研究所李恒课题组重磅推出三代HiFi宏基因组组装软件——hifiasm-meta。研究论文“Metagenome assembly of high-fidelity long reads with hifiasm-meta”预印本在线发布。 宏基因组样本的do novo组装是研究微生物群落的常用方法。与单个物种的组…

RNA-seq 详细教程:分析准备(3)

学习目标 了解 RNA-seq 和差异表达基因的分析流程了解如何设计实验了解如何使用 R 语言进行数据分析1. 简介 在过去的十年中,RNA-seq 已成为转录组差异表达基因和 mRNA 可变剪切分析不可或缺的技术。正确识别哪些基因或转录本在特定条件下的表达情况,是理…

【FreeRTOS(四)】显示任务详细信息

文章目录显示任务详细信息 vTaskList代码示例显示任务详细信息 vTaskList 通过 vTaskList来协助分析操作系统当前 task 状态,以帮助优化内存,帮助定位栈溢出问题。 void vTaskList( char *pcWriteBuffer );parameterdescriptionpcWriteBuffer保存任务状态…

11.21~11.28日学习总结

首先这一周,主要进行的几个事情。 1.在星期一~星期二图书报账的相关事情处理已经完毕,记录了现在图书报账的相关流程,比以前的流程有不少改变,已经写了word记录了流程,给下一任图书管理员做参考。 2.进行了项目的中期…

mysql集群的主从复制搭建

1.master上和slave分别安装好mysql(5.7) 2.按照下面的方式进行安装 3.安装完成后,进行初始化,并找到默认的密码进行登录 4.设置为开机自,并检查状态 5.进行登录,用root账户,密码为生成的那个密码…

C++:STL之Vector实现

vector各函数 #include<iostream> #include<vector> using namespace std;namespace lz {//模拟实现vectortemplate<class T>class vector{public:typedef T* iterator;typedef const T* const_iterator;//默认成员函数vector(); …

SpringBoot项目如何引入外部jar及将外部jar打包到项目发布jar包

1、创建一个SpringBoot项目 下载项目之后将项目导入IDEA 2、如何添加外部jar包 准备一个外部的jar包&#xff0c; 我这里使用的是guava-31.1-jre.jar作为演示 下载地址&#xff1a;https://repo1.maven.org/maven2/com/google/guava/guava/31.1-jre/guava-31.1-jre.jar 在项…

黎曼的几何基础,维度

黎曼的几何基础&#xff0c;让数学领先物理100年&#xff0c;维度是人类最大的障碍 - 知乎 高斯很早就有了“高维几何”的想法&#xff0c;他曾经向他的同事们说起假想完全生活在二维表面上的“书虫”&#xff0c;并想要把这推广到高维空间的几何学中去。然而&#xff0c;由于害…

Java安全编码规范之Web安全漏洞

Java安全编码规范之Web安全漏洞安全现状漏洞案例事件一事件二事件三安全编码规范之常见的安全漏洞敏感数据编码概述漏洞危害常见关键字举例解决方案代码硬编码秘钥错误示例日志打印导致的敏感信息泄露漏洞概述关键字举例解决方案代码中在日志打印token 错误示例文件上传漏洞概述…

CSDN客诉周报第12期|修复10个重大bug,解决29个次要bug,采纳1个用户建议

听用户心声&#xff0c;解用户之需。hello&#xff0c;大家好&#xff0c;这里是《CSDN客诉周报》第12期&#xff0c;接下来就请大家一同回顾我们最近几周解决的bug&#xff5e; 一、重大问题 1、【博客】主页无法访问 反馈量&#xff1a;80 发生时间&#xff1a;10月30日下…

外汇天眼:Axi收回在RGT Capital的全部控制权,Eurotrader获得FCA牌照

在过去的一周里&#xff0c;国外外汇市场上有哪些值得关注的新闻&#xff0c;跟着天眼君一起了解下吧~具体新闻如下&#xff1a; 1、Axi收回在RGT Capital的全部控制权 据天眼君了解&#xff0c;总部位于澳大利亚的零售外汇和差价合约经纪商Axi在澳大利亚投资公司RGT Capital的…

AutoDL使用手册

官网&#xff1a;AutoDL-品质GPU租用平台-租GPU就上AutoDL 1.服务器购买 2.新建虚拟环境 conda create -n tf python3.8 # 构建一个虚拟环境&#xff0c;名为&#xff1a;tf conda init bash && source /root/.bashrc # 更新bashrc中的环境变量 conda acti…

【Flink】使用水位线实现热门商品排行以及Flink如何处理迟到元素

文章目录一 WaterMark1 水位线特点总结2 实时热门商品【重点】&#xff08;1&#xff09;数据集&#xff08;2&#xff09;实现思路a 分流 - 开窗 - 聚合分流&#xff1a;开窗&#xff1a;聚合&#xff1a;b 再分流 -- 统计再分流&#xff1a;统计&#xff1a;&#xff08;3&am…

【Hack The Box】Linux练习-- Seventeen

HTB 学习笔记 【Hack The Box】Linux练习-- Seventeen &#x1f525;系列专栏&#xff1a;Hack The Box &#x1f389;欢迎关注&#x1f50e;点赞&#x1f44d;收藏⭐️留言&#x1f4dd; &#x1f4c6;首发时间&#xff1a;&#x1f334;2022年9月7日&#x1f334; &#x1f…

SpringBoot结合Liquibase实现数据库变更管理

《从零打造项目》系列文章 工具 比MyBatis Generator更强大的代码生成器 ORM框架选型 SpringBoot项目基础设施搭建 SpringBoot集成Mybatis项目实操 SpringBoot集成MybatisPlus项目实操 SpringBoot集成Spring Data JPA项目实操 数据库变更管理 数据库变更管理&#xff1a;Li…

[附源码]Python计算机毕业设计Django的党务管理系统

项目运行 环境配置&#xff1a; Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术&#xff1a; django python Vue 等等组成&#xff0c;B/S模式 pychram管理等等。 环境需要 1.运行环境&#xff1a;最好是python3.7.7&#xff0c;…

行为型模式-命令模式

package per.mjn.pattern.command;import java.util.HashMap; import java.util.Map;// 订单类 public class Order {// 餐桌号码private int diningTable;// 点的餐品和份数private Map<String, Integer> foodDir new HashMap<>();public int getDiningTable() {…

[附源码]计算机毕业设计springboot高校车辆管理系统

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

家居建材企业竞争白热化,如何通过供应商协同系统转型升级,提高核心竞争力

伴随房地产高景气时代逐渐退去&#xff0c;新房销售红利期或已接近尾声&#xff0c;家居建材需求正迈入平稳发展新阶段&#xff0c;企业之间竞争更加白热化。在面对数字化时代的快速发展&#xff0c;很多家居建材企业已达成这样的共识&#xff1a;数字化是企业未来发展的必由之…

人工智能岗位可以考什么证书?考试难不难?

最近几年人工智能在市场上的热度越来越大&#xff0c;很多企业都会利用这个项目来发展自己新渠道&#xff0c;那么想进入这一行的人需要怎么提升自己的技能呢&#xff1f;那就是考取人工智能相关的证书&#xff0c;目前阿里云人工智能是国内市场最热门的认证分为两个等级&#…