从0开始深度学习(8)——softmax回归

news2024/11/26 12:30:07

1 分类问题

深度学习从大方向上来说,就是回归预测和分类问题。

假设输入一个 2 ∗ 2 2*2 22的灰度图像,可能属于“鸡、猫、狗”三个类别中的一个,那如何在计算机中表示标签呢?最常见的想法是 y = { 1 , 2 , 3 } y= \{1,2,3\} y={1,2,3},其中的数字分别代表 { 狗 , 猫 , 鸡 } \{ 狗,猫,鸡 \} {,,}

但是一般的分类问题并不与类别之间的自然顺序有关,所以会使用one-hot编码, 即类别对应的分量设置为1,其他所有分量设置为0。

在我们的例子中,标签将是一个三维向量, 其中 ( 0 , 1 , 0 ) (0,1,0) (0,1,0)对应于“猫”、 ( 0 , 0 , 1 ) (0,0,1) (0,0,1)对应于“鸡”、 ( 1 , 0 , 0 ) (1,0,0) (1,0,0)对应于“狗”。

2 网络架构

为了估计所有可能类别的条件概率,我们需要一个有多个输出的模型,每个类别对应一个输出。

我们假设有4个特征和3个可能对的输出,所有有12个标量表示权重 w w w,3个标量表示偏置项 b b b o o o 是预测输出:
在这里插入图片描述

为了解决这种多分类问题,这里使用softmax网络:
在这里插入图片描述
softmax网络是一个有多个输入、多个输出的单层神经网络,我们使用向量表达式 o = W x + b o=Wx+b o=Wx+b 简洁的表达模型

3 全连接层的参数开销

具体来说,对于任何具有 d d d 个输入和 q q q 个输出的全连接层, 参数开销为 O ( q d ) O(qd) O(qd) ,这个数字在实践中可能高得令人望而却步。

幸运的是,将 d d d 个输入转换为 q q q 个输出的成本可以减少到 O ( d q / n ) O(dq/n) O(dq/n), 其中超参数 n n n 可以由我们灵活指定。

4 softmax运算

我们希望模型最后输出的 数据 X 数据X 数据X 对应的 各个标签 各个标签 各个标签 的是一个概率,然后把最大的概率标签视为我们最后的预测结果。要将输出视为概率,我们必须保证在任何数据上的输出都是非负的且总和为1。

此外,我们需要一个训练的目标函数,来激励模型精准地估计概率。 softmax函数会将每个元素转换为一个介于0和1之间的值,同时保证所有输出的概率总和为1,同时让模型保持可导的性质,公式如下:

P = e o i ∑ j = 1 k e o j P = \frac{e^{o_{i} } }{ {\textstyle \sum_{j=1}^{k}}e^{o_{j} } } P=j=1keojeoi

o o o是把输入的特征向量 x x x经过线性变换得到的向量,长度为 k k k P P P表示输入向量 x x x属于类别 i i i的概率

尽管softmax是一个非线性函数,但softmax回归的输出仍然由输入特征的仿射变换决定。 因此,softmax回归是一个线性模型。

5 小批量样本的矢量化

为了提高计算效率并充分利用GPU,我们通常会对小批量样本的数据执行矢量计算。

假设我们读取一个批量的样本 X X X,特征维度为 d d d,批量大小为 n n n,我们输出有 q q q 个类别。

所以小批量样本的特征为 X ∈ R n × d X \in \mathbb{R}^{n\times d} XRn×d,权重为 W ∈ R d × q W \in \mathbb{R}^{d\times q} WRd×q,偏置为 W ∈ R 1 × q W \in \mathbb{R}^{1\times q} WR1×q

所以softmax回归的矢量表达式为:
在这里插入图片描述
相对于一次处理一个样本, 小批量样本的矢量化使用了矩阵-向量乘法,可以充分利用GPU进行加速计算。

6 损失函数

在线性回归中,我们使用的是MSE作为损失函数,但之前那个例子是回归预测,这里是分类任务,所以这里使用最大似然估计,概念如下:

假设我们有一个概率模型 P ( x ∣ θ ) P(x∣θ) P(xθ),其中 x x x 是观测数据, θ θ θ 是模型的参数。我们的目标是找到参数 θ θ θ 的最优值,使得观测数据 x x x 出现的概率最大。

6.1 对数似然

softmax函数会输出一个向量 y ^ \hat{y} y^,即“输入的 x x x 对应的每个标签的条件概率”,例如 y ^ 1 = P ( y = 猫 ∣ X ) \hat{y}_{1}=P(y=猫 | X) y^1=P(y=X)

假设整个数据集 { X , Y } \{X,Y\} {X,Y},其中索引为 i i i的样本由特征向量 x ( i ) x^{(i)} x(i)和one-hot编码 y ( i ) y^{(i)} y(i)组成,所以我们可以将估计值和真实值进行比较:
在这里插入图片描述
表示在给定整个数据集的特征 X X X 的情况下,标签 Y Y Y 的概率等于每个样本中在给定该样本特征 x ( i ) x^{(i)} x(i) 的情况下标签 y ( i ) y^{(i)} y(i) 的概率之积。

我们要最大化 P ( Y ∣ X ) P(Y|X) P(YX) ,所以应该取负对数,因为取对数可以把累积转化为累加,同时因为对数函数是单调递增的,对概率取对数后再取负,最小化负对数似然就等价于最大化原始的概率,所以损失函数如下:
在这里插入图片描述

6.2 softmax及其导数

将softmax函数带入损失函数得到:
**注意:**因为是 y y y 独热标签向量,即除了对应真实类别的那个位置为 1 1 1,其余位置都为 0 0 0,所以 ∑ j = 1 q y i = 1 \sum_{j=1}^{q}y_{i}=1 j=1qyi=1
在这里插入图片描述
然后对损失函数求 o j o_{j} oj的导数(为了计算梯度),步骤如下:
在这里插入图片描述

在这里插入图片描述
换句话说,这个导数是我们softmax模型分配的概率与实际发生的情况(由独热标签向量表示)之间的差异。

6.3 交叉熵损失

最后输出的是 ( 0.1 , 0.2 , 0.7 ) (0.1,0.2,0.7) (0.1,0.2,0.7),而不是 ( 0 , 0 , 1 ) (0,0,1) (0,0,1),所以所有标签分布的预期损失值,称为交叉熵损失(cross-entropy loss),它是分类问题最常用的损失之一。

下面将通过介绍信息论来帮助理解交叉熵损失

7 信息论基础

信息论(Information Theory)是研究信息的量化、存储、传输和处理的数学理论,涉及编码、解码、发送以及尽可能简洁地处理信息或数据。

7.1 熵

信息论的核心思想是量化数据中的信息内容。 在信息论中,该数值被称为分布 P P P 的(entropy)。

对于一个随机变量 X X X ,其概率分布为 P ( X = j ) = p ( j ) P(X=j)=p(j) P(X=j)=p(j),即一个事件 X = j X=j X=j 的概率为 P ( j ) P(j) P(j),所以该事件的自信息量(表示一个事件发生所带来的信息量) 被定义为 I ( j ) = − l o g P ( j ) I(j)=-logP(j) I(j)=logP(j)

之所以这样定义,是因为概率 P ( j ) P(j) P(j) 越小, − l o g P ( j ) -logP(j) logP(j) 的值就越大,符合低概率事件带来高信息量的直观理解。

所以熵 H [ P ] H[P] H[P]就是所有可能发生的事件的自信息量的期望:
在这里插入图片描述

7.2 重新审视交叉熵

交叉熵(Cross-Entropy)是信息论中的一个重要概念,用于衡量两个概率分布之间的差异。

假设有两个概率分布 P P P Q Q Q ,其中 P P P 表示真实分布, Q Q Q 表示模型预测的分布,所以交叉熵分布 H ( P , Q ) H(P,Q) H(P,Q) 定义为:
H ( P , Q ) = ∑ x P ( x ) l o g Q ( x ) H(P,Q)=\sum_{x}P(x)logQ(x) H(P,Q)=xP(x)logQ(x)
x x x 表示所有可能的事件或者类别。

8 模型预测和评估

在训练softmax回归模型后,给出任何样本特征,我们可以预测每个输出类别的概率。 通常我们使用预测概率最高的类别作为输出类别。 如果预测与实际类别(标签)一致,则预测是正确的。 在接下来的实验中,我们将使用精度(accuracy)来评估模型的性能。 精度等于正确预测数与预测总数之间的比率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2207339.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

现金1.8kw, 年入150w, 财务不自由...

他的生活状态无疑是许多人梦寐以求的。拥有两套无贷款的房产,家庭和睦,两台车价值约50万元。现金资产高达1800万元,家庭年收入约150万元,职业稳定,属于中层管理阶层。 重要的是,孩子们的成绩优异&#xff0…

ListView的Items绑定和comboBox和CheckBox组合使用实现复选框的功能

为 ListView 控件的内容指定视图模式的方法,参考官方文档。 ComboBox 样式和模板 案例说明:通过checkBox和ComboBox的组合方式实现下拉窗口的多选方式,同时说明了ListView中Items项目的两种绑定方式. 示例: 设计样式 Xaml代码…

机器学习 | 特征选择如何减少过拟合?

在快速发展的机器学习领域,精确模型的开发对于预测性能至关重要。过度拟合的可能性,即模型除了数据中的潜在模式外,还拾取训练集特有的噪声和振荡,这是一个固有的问题。特征选择作为一种有效的抗过拟合武器,为提高模型…

好用,易用,高效,稳定 基于opencv 的 图像模板匹配 - python 实现

在定位、搜索固定界面图块时,经常用到模板匹配,opencv自带的图像模板匹配好用,易用,高效,稳定,且有多种匹配计算方式。 具体示例如下: 模板图: 待搜索图: 具体实现代码…

苹果正式宣布:iPhone全面开放近场通信(Near Field Communication,简称NFC)【使用安全元件提供app内NFC数据交换功能】

文章目录 引言I iPhone的NFC功能开发者用户数据交换的体验革新安全与隐私II 知识扩展:近场通信(NFC)技术钱包NFC开关打开读取NFC标签(NFC tags )权限demo引言 2014年iPhone 6开始,苹果首次引入了NFC功能,但最初只允许自家的Apple Pay进行移动支付。慢慢地适配了交通卡,增…

基于go开发的终端版即时通信系统(c-s架构)

项目架构图 类似一个聊天室一样 整体是一个客户端和服务端之间的并发多线程网络通信,效果可以翻到最后面看。 为了巩固基础的项目练手所以分为9个阶段进行迭代开发 版本⼀:构建基础Server 新建一个文件夹就叫golang-IM_system 第一阶段先将server的大…

沈阳化工大学第十一届程序设计沈阳区竞赛:凿冰 Action(博弈论,思维)

链接:登录—专业IT笔试面试备考平台_牛客网 来源:牛客网 题目描述 北极探险队有新收获了!!! 北极探险队发现了NNN条长度不一的冰柱,由于冰柱里封存有价值的生物,现在需要两名生物学家小A和小…

TON生态小游戏开发:推广、经济模型与UI设计的建设指南

随着区块链技术的快速发展,基于区块链的Web3游戏正引领行业变革。而TON生态小游戏,借助Telegram庞大的用户基础和TON(The Open Network)链上技术,已成为这一领域的明星之一。国内外开发者正迅速涌入,开发和…

如何在算家云搭建Kolors(图像生成)

一、模型介绍 Kolors 是快手 Kolors 团队基于潜在扩散的大规模文本转图片生成模型。经过数十亿个文本-图片对的训练,Kolors 在视觉质量、复杂语义准确性和中英文文本渲染方面均比开源和闭源模型具有显著优势。此外,Kolors 支持中英文输入,在…

C语言基础语法——类型转换

数据有不同的类型,不同类型数据之间进行混合运算时涉及到类型的转换问题。 转换的方法有两种: 自动类型转换(隐式转换):遵循一定的规则,由编译系统自动完成强制类型转换(显示转换)…

http协议概述与状态码

目录 1.HTTP概述 1.1请求报文起始行与开头 ​1.2响应报文起始行 ​ 1.3响应报文开头 ​ 2.http状态协议码 1.HTTP概述 默认端口 80 HTTP超文本传输与协议: 数据请求和响应 传输:将网站的数据传递给用户 超文本:图片 视频等 请求request:打开网站 访问网站 响应r…

Python数据分析-垃圾邮件分类

一、研究背景 随着电子通信技术的飞速发展,电子邮件已经成为人们日常工作和生活中不可或缺的一部分。然而,伴随着这一趋势,垃圾邮件(Spam)的数量也在急剧增加。垃圾邮件不仅会占用用户的邮箱空间,还可能含…

设置dl服务解决github pushTimed out问题

提交代码到GitHub,一直提示提交失败 我们一般是fq挂的dl服务器进行的,而git需要配置下dl,此时我们要将dl服务器对应的IP地址和端口为我们所调用。 查找dl服务器(windows直接搜索dl服务器设置,mac参考官网&#xff09…

【北京迅为】《STM32MP157开发板嵌入式开发指南》-第二十二章 安装VMware Tool 工具

iTOP-STM32MP157开发板采用ST推出的双核cortex-A7单核cortex-M4异构处理器,既可用Linux、又可以用于STM32单片机开发。开发板采用核心板底板结构,主频650M、1G内存、8G存储,核心板采用工业级板对板连接器,高可靠,牢固耐…

学习python自动化——pytest单元测试框架

一、什么是pytest 单元测试框架,unittest(python自带的),pytest(第三方库)。 用于编写测试用例、收集用例、执行用例、生成测试结果文件(html、xml) 1.1、安装pytest pip instal…

【ARM Linux驱动开发】嵌入式ARM Linux驱动开发基本步骤

【ARM Linux驱动开发】嵌入式ARM Linux驱动开发基本步骤 文章目录 开发环境驱动开发(以字符设备为例)安装驱动应用程序开发附录:压缩字符串、大小端格式转换压缩字符串浮点数压缩Packed-ASCII字符串 开发环境 首先需要交叉编译器和Linux环境…

豆包PixelDance指南:字节跳动推出的AI视频生成大模型,突破多主体互动难关

豆包PixelDance是由字节跳动旗下火山引擎发布的AI视频生成大模型。它是业界首个突破多主体互动难关的视频生成模型,支持多风格多比例的一致性多镜头生成。PixelDance基于DiT架构,具备高效的DiT融合计算单元,能够实现复杂的多主体运动交互和多…

【高等数学】 一元函数积分学

1. 不定积分的计算 1.1. 基本积分表 知识点 例题 1.2. 凑微分(第一类换元法) 知识点 本质:利用复合函数求导法则的逆运算 第一步,识别或者凑出来复合函数的导函数 如果被积函数具备以下特点: 1.它由两项相乘来表…

《案例》—— OpenCV 实现2B铅笔填涂的答题卡答案识别

文章目录 一、案例介绍二、代码解析 一、案例介绍 下面是一张使用2B铅笔填涂选项后的答题卡 使用OpenCV 中的各种方法进行真确答案识别,最终将正确填涂的答案用绿色圈出,错误的答案不圈出,用红色圈出错误题目的正确答案最终统计正确的题目数…

PCL用KDtree,给搜索到的邻近点上色

用KDtree&#xff0c;给搜索到的邻近点上色。 #include <pcl/io/pcd_io.h> #include <pcl/point_types.h>#include <pcl/search/kdtree.h> // 包含kdtree头文件 #include <pcl/visualization/pcl_visualizer.h> #include <boost/thread/thread.hpp&…