动手学深度学习——softmax分类

news2024/12/28 23:27:43

1. 分类问题

回归与分类的区别:

  • 回归可以用于预测多少的问题, 比如"预测房屋被售出价格",它是个单值输出。
  • softmax可以用来预测分类问题,例如"某个图片中是猫、鸡还是狗?",这是一个多值输出,输出个数等于类别个数,输出的第i个值表示预测为第i类别的概率。

两者的区别在于是问多少还是问哪一个?

分类可以用来描述下面两个问题:

  1. 样本属于哪个类别
  2. 样本属于每个类别的概率

比较经典的分类问题有:

  1. MNIST数据集,手写数字识别,有0-9十个类别。
  2. ImageNet数据集,从一百万张图片中识别自然物体,有1000个类别。
  3. kaggle上的恶意软件类别识别。
  4. 区分淘宝商品的评论是正面还是负面评论。

2. 分类编码

由于自然语言表示的类别不方便运算,所以为了计算的需要,有必要对类别进行编码。

对于分类问题,最常用的编码方式为一位有效编码,也称为独热编码(one-hot encoding)。它可以表示为一个向量,长度等于类别数量,向量中只有一个特征为1,其它特征均为0。

这里我们以一个图像分类问题为例来讨论, 假设要预测一张图片是猫、鸡还是狗,那么我们对这三种类别进行一位有效编码的形式如下:

  • (1,0, 0)对应于“猫”
  • (0,1,0)对应于“鸡”
  • (0,0,1)对应于“狗”
  1. 正确类别对应的分量设置为1,其它所有分量均为0.
  2. 类别数量等于分量数量(这里的分量是指向量在具体一个维度上的值)

分类问题对模型的要求:正确类的置信度要远远大于非正确类的置信度,即Oy > Oi。

相比具体每个类别的预测值大小,我们更关心正确类别的预测值是否远大于其它非正确类别的预测值,只有这样,才能表明模型能真正区分出正确类别。

3. 网络架构

与线性回归一样,softmax回归也是一个单层神经网络。

接着上面的例子,假设每次输入是一个2*2的灰度图像,我们可以用一个标量表示每个像素值,每个图像对应四个特征[x1,x2,x3,x4]。

我们可以定义输出向量y=[o1,o2,o3], 其中o1、o2、o3分别表示输入i是猫、鸡、狗的预测值大小。

由于我们有4个特征和3个可能的输出类别, 我们将需要12个标量来表示权重w, 3个标量来表示偏置b。则每个类别的计算可以表示为:
在这里插入图片描述

由于计算每个输出o1、o2和o3取决于 所有输入x1、x2、x3和x4, 所以softmax回归的输出层也是全连接层。
在这里插入图片描述
如同线性回归一样,可以将计算公式简洁表示,o = Wx + b。这是将所有权重放到一个W矩阵中。 对于给定数据样本的特征x, 我们的输出y是由权重W与输入特征x进行矩阵-向量乘法再加上偏置b得到。

4. 输出概率化

对于分类问题,我们希望模型的输出yj可以视为它属于类别j的概率,然后只需要选择具有最大输出值的类别argmax(xj,yj) 作为我们的预测即可, 这样能同时方便人脑理解和算术运算。

例如,如果为猫、鸡和狗的概率分别为0.1、0.8和0.1, 因为0.8概率最大,所以我们预测的类别是2,在我们的例子中代表“鸡”。

这里之所以要进行标准化概率计算,而不直接将预测o作为输出,其原因在于将线性层的输出o视作概率会存在一些问题:

  1. 线性层输出没有限制输出数字的总和为1,不符合概率分布。
  2. 根据输入的不同,线性层的输出是可以为负值的,会影响我们的计算。

要将输出视为概率,我们必须保证以下两点:

  1. 在任何类别上的输出都是非负
  2. 所有类别的预测值总和为1。

而softmax函数则正好能够将未规范化的预测变换为非负数并且总和为1,同时让模型保持可导的性质。它的作法为:

  • 对每个未规范化的预测求幂(指数),这样可以确保输出非负
  • 让每个求幂后的结果除以它们的总和,就能确保最终输出的概率值总和为1
    在这里插入图片描述通过对输出向量o进行softmax运算后,预测值就是一个概率分布。

而真实的值经过独热编码后也符合这个特征,因为它也符合概率的特性:

  1. 非负数:只有0和1两种值;
  2. 和为1:只有一个值为1,其它均为0;

这样就得到两个概率:预测值概率和真实值概率。接下来,就可以比较两个概率来作为损失。

5. 损失计算

交叉熵损失:用来衡量两个概率分布之间的差异。

对于分类问题,我们不关心非正确类别的预测值,只关心对正确类的预测值置信度有多大。

假设模型对每个类别的预测概率分别是0.7、0.2和0.1,实际该样本属于第一个类别。交叉熵损失会根据模型对第一个类别的预测概率和实际概率来计算一个损失值。用数学表示如下:

H(p, q) = -Σ p(x) * log(q(x))
  • p(x)表示实际的概率分布,q(x)表示模型预测的概率分布。
  • 前面加负号的目的是为了保证交叉熵为正值。log(q(x))的值通常是小于0的(小于1时,对数为负数),p(x)是一个概率值,介于0和1之间。
  • 交叉熵越小,表示两个概率分布越接近,模型的预测效果越好。

可以把交叉熵H(P,Q)想象为“主观概率为Q
的观察者在看到根据概率P生成的数据时的意外程度”。 当P=Q时,这种意外程度降到最低。

训练的目的:最小化交叉熵来优化模型的参数,使得模型的预测结果更接近于实际标签。

由于真实值p(x)是一个独热编码向量,只有一项为1,其它项均为0,所以这里的交叉熵又可以简写成:
在这里插入图片描述
所以,对于分类问题来说,我们不关心非正确类别的预测值,只关心正确类别的预测值有多大。

而梯度则是预测概率与真实概率之间的差异,损失函数对输出o求导为:
在这里插入图片描述

softmax回归模型训练的目标:给出任何样本特征,我们可以预测每个输出类别的概率。 通常我们使用预测概率最高的类别作为输出类别。 如果预测与实际类别(标签)一致,则预测是正确的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1636406.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Springboot+Vue项目-基于Java+MySQL的入校申报审批系统(附源码+演示视频+LW)

大家好!我是程序猿老A,感谢您阅读本文,欢迎一键三连哦。 💞当前专栏:Java毕业设计 精彩专栏推荐👇🏻👇🏻👇🏻 🎀 Python毕业设计 &…

Tomact安装配置及使用(超详细)

文章目录 web相关知识概述web简介(了解)软件架构模式(掌握)BS:browser server 浏览器服务器CS:client server 客户端服务器 B/S和C/S通信模式特点(重要)web资源(理解)资源分类 URL请求路径(理解)作用介绍格式浏览器通过url访问服务器的过程 服务器(掌握)…

Python | Leetcode Python题解之第57题插入区间

题目: 题解: class Solution:def insert(self, intervals: List[List[int]], newInterval: List[int]) -> List[List[int]]:left, right newIntervalplaced Falseans list()for li, ri in intervals:if li > right:# 在插入区间的右侧且无交集…

ASP.NET数据存储与交换系统设计

摘 要 该系统以Microsoft Visual Studio 2003作为开发工具,选用SQL Server 2000数据库来实现数据存储,并设计开发了一种基于B/S模式的数据存储与交换系统。该系统完成了用户注册管理、后台管理和用户空间管理功能;为每个用户提供了个人的存…

企业家必须提升演讲口才的原因(3篇)

企业家必须提升演讲口才的原因(3篇) **篇:企业家必须提升演讲口才的原因——建立品牌影响力 一、引言 在当今竞争激烈的市场环境中,企业家作为企业的灵魂和代表,其个人形象和品牌影响力对于企业的成功至关重要。而演…

【大语言模型LLM】-基于ChatGPT搭建客服助手(1)

🔥博客主页:西瓜WiFi 🎥系列专栏:《大语言模型》 很多非常有趣的模型,值得收藏,满足大家的收集癖! 如果觉得有用,请三连👍⭐❤️,谢谢! 长期不…

Java File类

1. File类概述 1.1 什么是File类 File是java.io包下作为文件和目录的类。File类定义了一些与平台无关的方法来操作文件,通过调用File类中的方法可以得到文件和目录的描述信息,包括名称、所在路径、读写性和长度等,还可以对文件和目录进行新建…

C语言 | Leetcode C语言题解之第61题旋转链表

题目: 题解: struct ListNode* rotateRight(struct ListNode* head, int k) {if (k 0 || head NULL || head->next NULL) {return head;}int n 1;struct ListNode* iter head;while (iter->next ! NULL) {iter iter->next;n;}int add n…

c#数据库: 9.删除和添加新字段/数据更新

先把原来数据表的sexy字段删除,然后重新在添加字段sexy,如果添加成功,sexy列的随机内容会更新.原数据表如下: using System; using System.Collections.Generic; using System.Data; using System.Data.Common; using System.Data.SqlClient; using System.Linq; using System.…

3D看车有哪些强大的功能?适合哪些企业使用?

3D看车是一种创新的汽车展示方式,它提供了许多强大的功能,特别适合汽车行业的企业使用。 3D看车可实现哪些功能? 1、细节展示: 51建模网提供全套汽车行业3D数字化解决方案,3D看车能够将汽车展示得更加栩栩如生&…

ClickHouse高原理与实践

ClickHouse高原理与实践 1 ClickHouse的特性1.1. OLAP1.2. 列式存储1.3. 表引擎1.4. 向量化执行1.5. 分区1.6. 副本与分片1.7 其他特性 2. ClickHouse模块设计2.1 Parser分析器与Interpreter解释器2.2 Storage2.3 Column与Field2.4 DataType2.4 Block2.5 Cluster与Replication …

C语言.自定义类型:结构体

自定义类型:结构体 1.结构体类型的声明1.1结构体回顾1.1.1结构体的声明1.1.2结构体变量的创建和初始化 1.2结构体的特殊声明1.3结构体的自引用 2.结构体内存对齐2.1对齐规则2.2为什么存在内存对齐2.3修改默认对齐数 3.结构体传参4.结构体实现位段4.1什么是位段4.2位…

【Linux系统编程】30.pthread_exit、pthread_join、pthread_cancel

目录 pthread_exit 参数retval 测试代码1 测试结果 pthread_join 参数thread 参数retvsl 返回值 测试代码2 测试结果 pthread_cancel 参数thread 返回值 测试代码3 测试结果 pthread_exit 退出当前线程。 man 3 pthread_exit 参数retval 退出值。 NULL&#xf…

我用文心4.0给你做了一个“五一旅行助手”!行程规划、实时查询、景区讲解!

大家好,我是木易,一个持续关注AI领域的互联网技术产品经理,国内Top2本科,美国Top10 CS研究生,MBA。我坚信AI是普通人变强的“外挂”,所以创建了“AI信息Gap”这个公众号,专注于分享AI全维度知识…

使用Gradio搭建聊天UI实现质谱AI智能问答

一、调用智谱 AI API 1、获取api_key 智谱AI开放平台网址: https://open.bigmodel.cn/overview 2、安装库pip install zhipuai 3、执行一下代码,调用质谱api进行问答 from zhipuai import ZhipuAIclient ZhipuAI(api_key"xxxxx") # 填写…

Visual studio 2019 编程控制CH341A芯片的USB设备

1、硬件 买了个USB可转IIC、或SPI、或UART的设备,主芯片是CH341A 主要说明USB转SPI的应用,绿色跳线帽选择IIC&SPI,用到CS0、SCK、MOSI、MISO这4个引脚 2、软件 2.1、下载CH341A的驱动 点CH341A官网https://www.wch.cn/downloads/CH34…

OpenCV如何实现背投(58)

返回:OpenCV系列文章目录(持续更新中......) 上一篇:OpenCV直方图比较(57) 下一篇:OpenCV如何模板匹配(59) 目标 在本教程中,您将学习: 什么是背投以及它为什么有用如何使用 OpenCV 函数 cv::calcBackP…

Mac好用又好看的终端iTerm2 + oh-my-zsh

Mac好用又好看的终端iTerm2 1. iTerm2的下载安装2. oh-my-zsh的安装2.1 官网安装方式2.2 国内镜像源安装方式 3. oh-my-zsh配置3.1 存放主题的路径3.2 存放插件的路径3.3 配置文件路径 1. iTerm2的下载安装 官网下载: iTerm2 2. oh-my-zsh的安装 oh-my-zsh是一…

设备能源数据采集新篇章

在当今这个信息化、智能化的时代,设备能源数据的采集已经成为企业高效运营、绿色发展的重要基石。而今天,我们要向大家介绍的就是一款颠覆传统、引领未来的设备能源数据采集神器——HiWoo Box网关! 一、HiWoo Box网关:一站式解决…

C++11:shared_ptr循环引用问题

一、shared_ptr的弊端 struct Listnode {int _val;std::shared_ptr<Listnode> _prev;std::shared_ptr<Listnode> _next;Listnode(int val ):_val(val),_prev(nullptr),_next(nullptr){}~Listnode(){cout << "~Listnode()" << endl;} }; in…