9 多分类问题

news2024/11/19 8:46:08

文章目录

    • 问题引入
    • 网络设计
      • 改进网络方法
    • softmax层
    • loss
    • MINIST引入
    • 代码实现

课程内容来源: 链接
课程文本借鉴: 链接
以及Birandaの

突然发现的也挺好:链接

问题引入

前篇中,对糖尿病数据集的问题是一个二分类问题,但实际问题中,二分类问题较少,更多的是以MINIST、CIFAR为例的多分类问题。

在这里插入图片描述

网络设计

把每一个分类作为二分类进行判断。

eg:当输出为1时,对其他的非1输出都规定为0,以此来进行判断。

在这里插入图片描述
但这种情况下,类别之间所存在的互相抑制的关系没有办法体现,当一个类别出现的概率较高时,其他类别出现的概率仍然有可能很高。
换言之,当计算输出为1的概率之后,再计算输出为2的概率时,并不是在输出为非1的条件下进行的,也就是说,所有输出的概率之和实际上是大于1的。
即对于一个多分类问题,其解决方案应该基于如下要求:

每个分类的出现概率大于等于0
P ( y = i ) ≥ 0 P(y=i) \geq 0 P(y=i)0

各个分类出现概率之和为1
∑ i = 0 n P ( y = i ) = 1 \sum_{i=0}^{n} P(y=i) = 1 i=0nP(y=i)=1
综上,多分类输出之间是需要有竞争性的

改进网络方法

改最后的sigmod层为softmax层,来实现多分类问题的基本要求。

在这里插入图片描述

softmax层

假定 Z l Z^l Zl为最后一层线性层的输出, Z i Z_i Zi为第i类的输出。则最终的softmax层函数应为
P ( y = i ) = e z i ∑ j = 0 K − 1 e z i , i ∈ { 0 , ⋯ , K − 1 } P(y=i)=\frac{e^{z_i}}{\sum_{j=0}^{K-1}e^{z_i}}, i \in \{0,{\cdots},K-1\} P(y=i)=j=0K1eziezi,i{0,,K1}

在这里插入图片描述
事实上,对于多分类问题输出,Softmax会先对所有输出进行指数运算,以满足(1)式要求,再对结果进行归一化处理,以满足(2)式要求。

loss

依照前篇所提及的交叉熵相关理论可知,交叉熵的计算公式如下
H ( P , Q ) = − ∑ i = 1 n P ( X i ) l o g ( Q ( X i ) ) H(P,Q) =-\sum^n_{i=1} P(X_i)log(Q(X_i)) H(P,Q)=i=1nP(Xi)log(Q(Xi))
在多分类问题中,该公式可扩展为
H ( P , Q ) = − ∑ i = 1 n ∑ j = 1 m P ( X i j ) l o g ( Q ( X i j ) ) H(P,Q) =-\sum^n_{i=1}\sum^m_{j=1} P(X_{ij})log(Q(X_{ij})) H(P,Q)=i=1nj=1mP(Xij)log(Q(Xij))

由于上述计算过程中 P ( X i j ) P(X_{ij}) P(Xij)非0即1,且有且只能有一个1,因此一个样本所有分类的loss计算过程可以简化为
L o s s = − l o g ( P ( X ) ) = − Y l o g Y ^ Loss = -log(P(X)) = -Ylog \widehat Y Loss=log(P(X))=YlogY
其中, X X X表示事件预测值与实际值相同, Y Y Y表示非0即1的指示变量, Y ^ \widehat Y Y 表示Softmax的输出。
此时 Y Y Y其实是作为独热编码(One-hot)输入的,以对离散的变量进行分类。即只在实际值处为1,其他均为0.

在这里插入图片描述
在这里插入图片描述

MINIST引入

MINIST数据集中每个数字都是一个 28 ∗ 28 = 784 28*28=784 2828=784大小的灰度图,将灰度图中的每个像素值映射到 ( 0 , 1 ) (0,1) (0,1)区间内,可以进行映射。

在这里插入图片描述
在这里插入图片描述

代码实现

包引入

import torch
#组建DataLoader
from torchvision import transforms #图像
from torchvision import datasets
from torch.utils.data import DataLoader
#激活函数和优化器
import torch.nn.functional as F
import torch.optim as optim

数据准备

#Dataset&Dataloader必备
bacth_size = 64
#pillow(PIL)读的原图像格式为W*H*C,原值较大
# 转为格式为C*W*H值为0-1的Tensor
transform = transforms.Compose([
    #变为格式为C*W*H的Tensor
    transforms.ToTensor(),
    #第一个是均值,第二个是标准差,变值为0-1
    transforms.Normalize((0.1307, ), (0.3081, ))
])

train_dataset = datasets.MNIST(root='../dataset/mnist/',
                               train=True,
                               download=True,
                               transform = transform)

train_loader = DataLoader(train_dataset,shuffle=True,batch_size=bacth_size)

test_dataset = datasets.MNIST(root='../dataset/mnist/',
                               train=False,
                               download=True,
                               transform = transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=bacth_size)

模型设计
在这里插入图片描述

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # 线性层1,input784维 output512维
        self.l1 = torch.nn.Linear(784, 512)
        # 线性层2,input512维 output256维
        self.l2 = torch.nn.Linear(512, 256)
        # 线性层3,input256维 output128维
        self.l3 = torch.nn.Linear(256, 128)
        # 线性层4,input128维 output64维
        self.l4 = torch.nn.Linear(128, 64)
        # 线性层5,input64维 output10维
        self.l5 = torch.nn.Linear(64, 10)
    
    def forward(self, x):
        # 改变张量形状view\reshape
        # view 只能用于内存中连续存储的Tensor,transpose\permute之后的不能用
        # 变为二阶张量(矩阵),-1用于计算填充batch_size
        x = x.view(-1, 784)
        # relu 激活函数
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = F.relu(self.l3(x))
        x = F.relu(self.l4(x))
        # 第五层不再进行relu激活
        return self.l5(x)

model = Net()

Loss&Optimizer

#交叉熵损失
criterion = torch.nn.CrossEntropyLoss()
#随机梯度下降,momentum表冲量,在更新时一定程度上保留原方向
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

模型训练及测试

def train(epoch):
    running_loss = 0.0
    #提取数据
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        #优化器清零
        optimizer.zero_grad()
        #前馈+反馈+更新
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
        #累计loss
        running_loss += loss.item()

        if batch_idx % 300 == 299:
            print('[%d, %5d] loss: %.3f' % (epoch+1, batch_idx+1, running_loss/300))
            running_loss = 0.0

def test():
    correct = 0
    total = 0
    #避免计算梯度
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            #取每一行(dim=1表第一个维度)最大值(max)的下标(predicted)及最大值(_)
            _, predicted = torch.max(outputs.data, dim=1)
            #加上这一个批量的总数(batch_size),label的形式为[N,1]
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print('Accuracy on test set: %d %%' % (100 * correct/total))
        
if __name__=='__main__':
    for epoch in range(10):
        train(epoch)
        test()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/191734.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue-node解决 rollbackFailedOptional: verb npm-session fd23ceb3f5797b77进度条卡住的问题

一、文章引导 #mermaid-svg-qv5tmCFBaoUwQojc {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-qv5tmCFBaoUwQojc .error-icon{fill:#552222;}#mermaid-svg-qv5tmCFBaoUwQojc .error-text{fill:#552222;stroke:#55222…

RabbitMQ常见场景问题

RabbitMQ常见场景问题 文章目录RabbitMQ常见场景问题6种工作模式1.直连模式2.发布订阅模式3.Routing路由模式4.Topic通配符模式5.Header模式6.RPC消息不丢失消息发送到交换机失败1.配置文件开启发布确认2.配置回调函数3.测试4.如何处理失败消息RabbitMQ服务器故障持久化消息发送…

存量房贷利率,一种简单估算其自然年利率调整的方法。

1.摘要2022年过去了,总所周知LPR被多次下调,目前有存量房贷的朋友,如果(普遍)设置的是根据自然年LPR动态调整利率,到2023年2月应该注意到了比较明显的房贷金额变动。这里主要给出一种根据这个变动&#xff…

Plecs电力电子仿真专业教程-第一季 第一节 Plecs简介

Plecs电力电子仿真专业教程-第一季 第一章 Plecs是什么? 第一节 Plecs简介 Plecs是瑞士Plexim GmbH公司开发的系统级电力电子仿真软件PLECS。PLECS是一个用于电路和控制结合的多功能仿真软件,尤其适用于电力电子和传动系统。不管您是工业领域中的开发…

[架构之路-96]:《软件架构设计:程序员向架构师转型必备》-6-需求与用户用例User Case/Senario建模

第6章 需求与用户用例User Case建模备注:严格意义上讲,用户用例属于需求分析领域,不属于架构设计。用户用例是架构设计最重要的输入参考之一。User Case和User Senario是非常重要的描述需求的重要手段6.1 常用的4种用例技术6.1.1 用例图6.1.2…

学习Java开发按此路线规划,从10K到40K全都有了,我就是这样过来的

如果有一天我醒来时,发现自己的几年Java开发经验被抹掉,重新回到了一个小白的状态。我想要重新自学Java,然后找到一份自己满意的Java工作,我想大概只需要6个月的时间就够了,如果顺利的话,4个月也差不多。如…

用光盘怎样重装电脑系统

用光盘怎样重装电脑系统?重装系统,听起来好像很难的样子。其实没那么难,用光盘装还是比较容易的。下面一起看看如何用光盘重装系统吧。 工具/原料: 系统版本:win7 品牌型号:联想yoga13 方法/步骤&#xf…

Vue使用axios发送get请求并携带参数

前言 其实关于Vue使用axios发送get请求并携带参数,我之前写过一篇,但是昨天又发现了另外一种方式,所以就单独写一篇进行总结。 之前写的那篇使用get请求并携带参数都是使用的字符串拼接的方式 感兴趣可以参考: Vue使用axios进行g…

基于Android的校园资产管理系统

需求信息: 管理员用户: 1:用户注册登录:通过手机号码、用户名称以及密码完成用户的注册和登录 2:添加资产:添加资产的编号、名称、归属部门之后生成资产二维码,以及查看添加过的资产信息 3&…

amCharts Javascript Web 5.3.0 Crack

添加新的 JSON 插件,允许您将序列化 (JSON) 配置解析为图表。 2023 年 1 月 31 日 - 16:00新版本 特征 添加了新JSON插件,允许将序列化 (JSON) 配置序列化和解析为图表。 crisp(默认:)false设置已添加到Sprite。如果设…

已经拿到IB成绩的学生,应该怎么为申请大学做准备呢?

2023年将会是过渡的一年,前几年的高分可能一去不复返了,大家心里也是要做好准备。对于今年已经拿到IB成绩的孩子们,应该怎么为申请大学做准备呢?老师也给了大家一些建议。1.如何递交IB成绩给申请的大学?今年1月出成绩的…

Shell + Datax 动态传递时间参数模式

Datax 数据同步模式Shell 脚本实现Datax 数据同步四种模式Datax 数据全量同步模式此脚本省略...Datax 数据实时增量(T1)模式功能:实现前一天日期 00:00:00 至前一天日期 23:59:59 数据同步#!/bin/bash # 切换至增量脚本文件存储目…

[NOI Online #3 入门组] 最急救助

题目描述: 救助中心每天都要收到很多求救信号。收到求救信号后,救助中心会分析求救信号,找出最紧急的求救者给予救助。 求救信号是一个由小写英文字母组成的字符串,字符串中连续三个字符依次组成sos的情况越多(即包含子串sos的数…

【蓝桥杯单片机】工厂灯光控制系统案例解析(小蜜蜂老师基础综合实训)

工厂灯光控制系统案例解析题目流程图关键点复盘参考代码(IO模式)题目 流程图 关键点复盘 设备检测——移位 L1~L8在板子上是从左至右,但是在对P0口赋值时是16进制从高位(L8)—>低位(L0) 根据原理图,LED赋值0亮1灭 为了方便赋值…

OpenShift 4 - 在单节点的 OpenShift 上用 NFS Operator 实现以 RWX 访问存储

《OpenShift / RHEL / DevSecOps 汇总目录》 文本已在 OpenShift Local 4.12 环境中进行验证。 文章目录OpenShift 支持的存储访问模式用 NFS Provisioner Operator 实现 RWX 访问存储安装 NFS Operator解决安装 Operator 过程无法访问谷歌 gcr.io 上的容器镜像配置 NFSProvisi…

《零基础学机器学习》读书笔记三之基本机器学习术语

《零基础学机器学习》读书笔记三之基本机器学习术语 一、机器学习快速上手路径(续) 1.3 基本机器学习术语 1.3.1 特征 特征是机器学习中的输入,原始的特征描述了数据的属性。特征的维度指的是特征的数目。 把向量、矩阵和其他张量的维度统…

React脚手架应用(二)

1、react脚手架 脚手架简介 用来帮助程序员快速创建一个基于xxx库的模板项目 1、包含了所有需要的配置(语法检查、jsx编译、devServer…); 2、下载好了所有相关的依赖; 3、可以直接运行一个简单效果; create-react-a…

加速企业数字化进展,小程序容器来帮忙

近年来,由于新冠疫情,诸多企业面临经济挑战,高效办公常常无法正常保证。在此期间,不少企业纷纷加快了数字化进展。 2021年,在Gartner新型技术成熟度曲线中我们看到:组装式应用、实时事件中心即服务、生成式…

软考高级系统架构师背诵要点---系统安全与系统可靠性分析

系统安全与系统可靠性分析 系统安全: 信息摘要、数字签名、数字信封 被动攻击:收集信息为主,破坏保密性 窃听:用各种可能的合法手段和非法手段窃取系统中的信息资源和敏感信 业务流分析:通过对系统进行长期的监听&a…

Spark08: Spark Job的三种提交模式

一、三种任务提交方式 1. 第一种,standalone模式 基于Spark自己的standalone集群。指定–master spark://bigdata01:7077 2. 第二种,是基于YARN的client模式。 指定–master yarn --deploy-mode client 使用场景:这种方式主要用于测试&am…