【计算机视觉】Softmax代码实现、过拟合和欠拟合的表现与解决方法

news2025/1/16 14:09:12

Softmax原理


Softmax函数用于将分类结果归一化,形成一个概率分布。作用类似于二分类中的Sigmoid函数。

对于一个k维向量z,我们想把这个结果转换为一个k个类别的概率分布p(z)。softmax可以用于实现上述结果,具体计算公式为:

image-20210825001951092

对于k维向量z来说,其中zi∈Rzi∈R,我们使用指数函数变换可以将元素的取值范围变换到(0,+∞)(0,+∞),之后我们再所有元素求和将结果缩放到[0,1],形成概率分布。

常见的其他归一化方法,如max-min、z-score方法并不能保证各个元素为正,且和为1。

Softmax性质


输入向量x加上一个常数c后求softmax结算结果不变,即:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-5nEeYQfX-1673278208682)(null)]

我们使用softmax(x)的第i个元素的计算来进行证明:

image-20210825002106122

函数实现


由于指数函数的放大作用过于明显,如果直接使用softmax计算公式[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-c3Fqizxd-1673278208641)(null)]进行函数实现,容易导致数据溢出(上溢)。所以我们在函数实现时利用其性质:先对输入数据进行处理,之后再利用计算公式计算。具体使得实现步骤为:

  1. 查找每个向量x的最大值c;
  2. 每个向量减去其最大值c, 得到向量y = x-c;
  3. 利用公式进行计算 s o f t m a x ( x ) = s o f t m a x ( x − c ) = s o f t m a x ( y ) softmax(x) = softmax(x-c) = softmax(y) softmax(x)=softmax(xc)=softmax(y)
import numpy as np
def softmax(x, axim=1):
    '''
    x: m*n m个样本,n个分类输出
    return s:m*n
    '''
    row_max = np.max(x, axis=axis) # 计算最大值
    row_max = row_max.reshape(-1, 1) # 将数据展开为m*1的形状,方便使用广播进行作差
    x = x - row_max # 减去最大值
    x_exp = np.exp(x) # 求exp
    s = x_exp / np.sum(x_exp, axis=axis, keepdim=True) # 求softmax
    return s

问题

过拟合和欠拟合的表现和解决方法。

其实除了欠拟合和过拟合,还有一种是适度拟合,适度拟合就是我们模型训练想要达到的状态,不过适度拟合这个词平时真的好少见,在做酷狗音乐的笔试题时还懵逼了一会,居然还真的有这样的说法。

这应该是基础中的基础了,笔试题都做烂了。那就当做今天周末,继续放个假吧……

过拟合

过拟合的表现

模型在训练集上的表现非常好,但是在测试集、验证集以及新数据上的表现很差,损失曲线呈现一种高方差,低偏差状态。(高方差指的是训练集误差较低,而测试集误差比训练集大较多)

过拟合的原因

从两个角度去分析:

  1. 模型的复杂度:模型过于复杂,把噪声数据的特征也学习到模型中,导致模型泛化性能下降
  2. 数据集规模大小:数据集规模相对模型复杂度来说太小,使得模型过度挖掘数据集中的特征,把一些不具有代表性的特征也学习到了模型中。例如训练集中有一个叶子图片,该叶子的边缘是锯齿状,模型学习了该图片后认为叶子都应该有锯齿状边缘,因此当新数据中的叶子边缘不是锯齿状时,都判断为不是叶子。

过拟合的解决方法

  1. 获得更多的训练数据:使用更多的训练数据是解决过拟合问题最有效的手段,因为更多的样本能够让模型学习到更多更有效的特征,减少噪声的影响。

    当然直接增加实验数据在很多场景下都是没那么容易的,因此可以通过数据扩充技术,例如对图像进行平移、旋转和缩放等等。

    除了根据原有数据进行扩充外,还有一种思路是使用非常火热的**生成式对抗网络 GAN **来合成大量的新训练数据。

    还有一种方法是使用迁移学习技术,使用已经在更大规模的源域数据集上训练好的模型参数来初始化我们的模型,模型往往可以更快地收敛。但是也有一个问题是,源域数据集中的场景跟我们目标域数据集的场景差异过大时,可能效果会不太好,需要多做实验来判断。

  2. 降低模型复杂度:在深度学习中我们可以减少网络的层数,改用参数量更少的模型;在机器学习的决策树模型中可以降低树的高度、进行剪枝等。

  3. 正则化方法如 L2 将权值大小加入到损失函数中,根据奥卡姆剃刀原理,拟合效果差不多情况下,模型复杂度越低越好。至于为什么正则化可以减轻过拟合这个问题可以看看这个博客,挺好懂的.。

    添加BN层(这个我们专门在BN专题中讨论过了,BN层可以一定程度上提高模型泛化性能)

    使用dropout技术(dropout在训练时会随机隐藏一些神经元,导致训练过程中不会每次都更新(预测时不会发生dropout),最终的结果是每个神经元的权重w都不会更新的太大,起到了类似L2正则化的作用来降低过拟合风险。)

  4. Early Stopping:Early stopping便是一种迭代次数截断的方法来防止过拟合的方法,即在模型对训练数据集迭代收敛之前停止迭代来防止过拟合。

    Early stopping方法的具体做法是:在每一个Epoch结束时(一个Epoch集为对所有的训练数据的一轮遍历)计算validation data的accuracy,当accuracy不再提高时,就停止训练。这种做法很符合直观感受,因为accurary都不再提高了,在继续训练也是无益的,只会提高训练的时间。那么该做法的一个重点便是怎样才认为validation accurary不再提高了呢?并不是说validation accuracy一降下来便认为不再提高了,因为可能经过这个Epoch后,accuracy降低了,但是随后的Epoch又让accuracy又上去了,所以不能根据一两次的连续降低就判断不再提高。一般的做法是,在训练的过程中,记录到目前为止最好的validation accuracy,当连续10次Epoch(或者更多次)没达到最佳accuracy时,则可以认为accuracy不再提高了。

  5. 集成学习方法:集成学习是把多个模型集成在一起,来降低单一模型的过拟合风险,例如Bagging方法。

    如DNN可以用Bagging的思路来正则化。首先我们要对原始的m个训练样本进行有放回随机采样,构建N组m个样本的数据集,然后分别用这N组数据集去训练我们的DNN。即采用我们的前向传播算法和反向传播算法得到N个DNN模型的W,b参数组合,最后对N个DNN模型的输出用加权平均法或者投票法决定最终输出。不过用集成学习Bagging的方法有一个问题,就是我们的DNN模型本来就比较复杂,参数很多。现在又变成了N个DNN模型,这样参数又增加了N倍,从而导致训练这样的网络要花更加多的时间和空间。因此一般N的个数不能太多,比如5-10个就可以了。

  6. 交叉检验,如S折交叉验证,通过交叉检验得到较优的模型参数,其实这个跟上面的Bagging方法比较类似,只不过S折交叉验证是随机将已给数据切分成S个互不相交的大小相同的自己,然后利用S-1个子集的数据训练模型,利用余下的子集测试模型;将这一过程对可能的S种选择重复进行;最后选出S次评测中平均测试误差最小的模型。

欠拟合

欠拟合的表现

模型无论是在训练集还是在测试集上的表现都很差,损失曲线呈现一种高偏差,低方差状态。(高偏差指的是训练集和验证集的误差都较高,但相差很少)

img

欠拟合的原因

同样可以从两个角度去分析:

  1. 模型过于简单:简单模型的学习能力比较差
  2. 提取的特征不好:当特征不足或者现有特征与样本标签的相关性不强时,模型容易出现欠拟合

欠拟合的解决方法

  1. 增加模型复杂度:如线性模型增加高次项改为非线性模型、在神经网络模型中增加网络层数或者神经元个数、深度学习中改为使用参数量更多更先进的模型等等。
  2. 增加新特征:可以考虑特征组合等特征工程工作(这主要是针对机器学习而言,特征工程还真不太了解……)
  3. 如果损失函数中加了正则项,可以考虑减小正则项的系数 λ \lambda λ

参考资料

过拟合与欠拟合及方差偏差 (这个博客总结地很好,可以看看)
《百面机器学习》
机器学习+过拟合和欠拟合+方差和偏差
如何判断欠拟合、适度拟合、过拟合

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/153044.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

程序员不了解这些投简历的巨坑,面试注定一开始就失败!

目录 前言第一阶段:练手第二阶段:冲刺第三阶段:收尾 前言 之前写了两篇文章,给大家介绍了一下如何利用短期的时间,尽可能充分的为面试做准备: 1.《我只是把握好了这3点,1个月后成功拿下大厂…

2023春节祝福系列第一弹(下)(放飞祈福孔明灯,祝福大家身体健康)(附完整源代码及资源免费下载)

2023春节祝福系列第一弹(下) (放飞祈福孔明灯,祝福大家身体健康) (附完整源代码及资源免费下载) 目录 四、画一朵真实的祥云 (1)、画一个渐变的白色径向渐变背景 &a…

外业调查工具助手,照片采集、精准定位、导航、地图查看

你是不是在外业调查时要背着一堆图纸 是不是一不小心图纸污损或丢失,工作又得重做 是不是经常会出现图纸标注的空间不足 是不是外业采集中要携带一大堆繁琐的仪器 是不是每次收集的数据、照片等在整理的过程中发现工作量巨大 是不是经常会出现采集回来的内容跟…

《MySQL 入门教程》第 36 篇 Python 访问 MySQL

本篇我们介绍如何利用 Python DB API 连接和操作 MySQL 数据库,包括数据的增删改查操作、存储过程调用以及事务处理等。 Python 是一种高级、通用的解释型编程语言,以其优雅、准确、 简单的语言特性,在云计算、Web 开发、自动化运维、数据科…

Spark / Java - atomic.LongAccumulator 与 Spark.util.LongAccumulator 计数使用

目录 一.引言 二.atomic.LongAccumulator 1.构造方法 2.使用方法 3.创建并使用 三.Spark.util.LongAccumulator 1.构造方法 2.使用方法 一.引言 使用 Spark 进行大数据分析或相关操作时,经常需要统计某个步骤或多个步骤的相对耗时或数量,java.u…

Java设计模式-适配器模式Adapter

介绍 适配器模式(Adapter Pattern)将某个类的接口转换成客户端期望的另一个接口表示,主的目的是兼容性,让原本 因接口不匹配不能一起工作的两个类可以协同工作。其别名为包装器(Wrapper)适配器模式属于结构型模式主要分为三类:类适配器模式、…

树莓派自带的python3.9->python3.7

卸载python3.9:sudo apt-get remove python3卸载之后一些包可以使用sudo apt autoremove这个命令删除卸载成功如果出现问题后续再来更新(出现问题后后续安装python也会失败)(先不要安装先看)安装python3.7:…

C语言第30课笔记

1.strerror(errno要包含头文件errno.h) 2.perror头文件为stdio.h 3.一些字符函数 4.字母大小写转换函数 5.memmove理论上是memcpy的升级版(可以自己拷贝自己)。 6.匿名结构体类型在类型创建好了之后直接创建变量,只能用一次。两个完全相同的匿名结构体类型&#xf…

【八】Netty HTTP协议--文件服务系统开发

Netty HTTP协议--文件服务系统开发介绍HTTP-文件系统场景描述流程图代码展示netty依赖服务端启动类 HttpFileServer服务端业务逻辑处理类 HttpFileServerHandler结果展示错误路径文件夹路径文件路径遗留bugbug版本总结介绍 由于Netty天生是异步事件驱动的架构,因此…

java EE初阶 — Synchronized 的原理

文章目录1. Synchronized 的优化操作1.1 偏向锁1.2 轻量级锁(自旋锁)1.3 重量级锁2. 其他的优化操作2.1 锁消除2.2 锁粗化3. 相关面试题1. Synchronized 的优化操作 两个线程针对同一个对象加锁,就会产生阻塞等待。 Synchronized 内部其实还有…

ubuntu docker elasticsearch kibana安装部署

ubuntu docker elasticsearch 安装部署 所有操作尽量在root下操作. 安装docker 1. 由于是基于宝塔面板安装的所以简答的点击操作即可完成安装. 我这里已经是正常的安装好了. 2.dcoker 镜像加速 https://cr.console.aliyun.com/cn-hangzhou/instances访问这个网址进去进行了…

快速上手Golang

自动推导赋值:自动推导赋值Go中 不同的数据类型不能进行计算对于浮点型默认都是float64 精确到小数点后15位单引号的 为字节类型 一位0~255的字符转换双引号的 为字符串类型多重赋值多重赋值a,b:1,2格式输出格式输出printf“%3d”三位整数,不满足三位时头部补空格“…

录制课程用什么软件好?3款超好用的课程视频录课软件

在互联网技术的飞速发展下,在线教学已经成为一种新型的教学形式,与传统的教学方法相比,在线教学具有低成本、突破地域、时间灵活、形式多样的教学方式。那录制课程用什么软件好?今天小编就跟大家分享3款超好用的课程视频录课&…

认真研究MySQL的主从复制(一)

【1】主从复制概述 ① 如何提升数据库并发能力 在实际工作中,我们常常将Redis作为缓存与MySQL配合使用,当有请求的时候,首先会从缓存中进行查找。如果存在就直接取出,如果不存在再访问数据库。这样就提升了读取的效率&#xff0…

中国数据库的诸神之战

作者 | 唐小引出品 | 《新程序员》编辑部“现在的数据库产品实在是太多了!”前几天,我和深耕数据库/大数据近 30 年的卢东明老师相聊时,他发出了这样的感慨。将包括 DB-Engines Ranking 以及国内数据库排行等在内的数据库产品列表进行汇总&am…

快速入门Freemarker模块引擎技术

1、 freemarker 介绍 ​ FreeMarker 是一款 模板引擎: 即一种基于模板和要改变的数据, 并用来生成输出文本(HTML网页,电子邮件,配置文件,源代码等)的通用工具。 它不是面向最终用户的,而是一个Java类库&am…

采场的车辆管理及卸料点计数管理有哪些难题需要解决

近期,安环部检查采矿区域工程车辆驾驶人员情况时,发现有部分驾驶员及工作人员存在违规顶替情况,有非注册备案人员驾驶矿用工程车辆违规作业。为了进行统一有效的人员车辆管理,同时能监督安全员定期对采矿作业区进行安全巡查&#…

Camtasia Studio2023喀秋莎新增功能及电脑配置要求介绍

Camtasia Studio2023具有强大的视频播放和视频编辑功能,录制屏幕后,根据时间轴对视频剪辑进行各种标记、媒体库、画中画、画中画、画外音当然,也可以导入现有视频并对其进行编辑操作。编辑完成后,可以将录制的视频输出为最终的视频…

光伏废水深度除氟装置,用于高盐废水除氟的工艺

光伏行业废水根据生产产品可细分为单品硅生产线排水、多品硅生产线排水。其生产工序中有污水排放的工段主要是:制绒和清洗工段。废水中的主要污染物为由异丙醇引起的高浓度COD、氟离子及酸碱污染,其中以含异丙醇的废水一直是水处理中的难题。如果不对废水…

【自学Python】Python input()函数

Python input()函数 Python input()函数教程 在 Python 中,input() 函数用于获取用于的输入,并给出提示。input() 函数,总是返回 string 类型,因此,我们可以使用 input() 函数,获取用户输入的任何数据类型…