深度学习pytorch——减少过拟合的几种方法(持续更新)

news2025/1/10 16:33:19

1、增加数据集

2、正则化(Regularization)

正则化:得到一个更加简单的模型的方法。

以一个多项式为例:

随着最高次的增加,会得到一个更加复杂模型,模型越复杂就会更好的拟合输入数据的模型(图-1),拟合的程度越大,表现在参数上的现象就是高次的系数趋近于0,如果直接将趋近于0的高次去掉,就可以得到一个更加简单的模型,这种方法称为正则化

图-1

 直观的看,经过正则化的模型更加平滑(图-2).

图-2

 正则化的方法:

(1)L1-正则化:在原来的模型基础上加上一个 1-范数(这里使用二分类模型作为示例):

 (2)L2-正则化:在原来的模型基础上加上一个 2-范数(这里使用二分类模型作为示例):

 代码示例:

# L2-正则化
device = torch.device('cuda:0')
net = MLP.to(device)
optimizer = optim.SGD(net.parameters,lr = learning_rate,weight_decay=0.01) #weight_decay=0.01就代表进行L2-正则化
criteoon = nn.CrossEntropyLoss().to(device)
# L1-正则化
# 对于L1-正则化,pytorch并没有提供直接的方法,就只能使用人工去做了
regularization_loss = 0
for param in model.parameters():                # 相求1-范数的总和
    regularization_loss += torch.sum(torch.abs(param))

classify_loss = criteon(logits,target)
loss = classify_loss + 0.01*regularization_loss     # 再将得到的正则损失加入模型损失,其中0.01是1-范数总和前面的系数

optimizer.zero_grad()
loss.backward()
optimizer.step()

3、加入动量(momentum)

动量即惯性——本次向哪移动,还需要考虑上一次移动的方向。

正常更新梯度的公式(公式-1):

公式-1

加入动量之后的公式(公式-2):

公式-2

将z(k+1)带入梯度更新公式,即公式-1减去,其中Z(k)相当于上一次的梯度,系数\alpha和β的大小决定了是当前梯度对方向的决定性大,还是上一梯度对方向的决定性大。

当动量为0时的梯度更新情况(图-3):

图-3

 动量不为0时的梯度更新情况(图-4):

图-4

将图-3和图-4对比,可以得出动量不为0,即考虑上一梯度,梯度更新更加稳定,不会出现巨大的跳跃情况,并且不加动量的没有找到最小点,一直在局部最小值点徘徊,如果加入动量,考虑到上一梯度,可以在一定程度上解决这种情况(图-4是加入动量之后最好的情况)。

代码演示,直接在优化器部分使用momentum属性就可以了,但是如果使用Adam优化器,就不需要添加,因为在Adam优化器内部定义的有momentum属性:

4、学习率(Learning Rate ) 

不同学习率梯度更新情况(图-5):

图-5

当学习率太小的时候,梯度更新比较慢,需要较多次的更新。

当学习率太大的时候,梯度更新比较激烈,找到的极值点Loss太大。

如何找到正确的的学习率?

在训练之初,可以先设置一个较大的学习率加快更新的速度,然后逐步减小学习率,即设置一个动态学习率。

图-6

 从图-6,可以看到有一个突然下降的点,这个点就是学习率训练一些数据之后,学习率突然变小导致的结果。在此之前可以看到Loss趋于不变,可以合理的猜测是因为学习率太大了,出现了来回摇摆不定的情况(图-7):

图-7

 当学习率突然减小,梯度更新变慢,易找到极小点(图-8):

图-8

 代码演示:

 5、dropout

dropout:减少神经元之间的连接,减少模型的学习量。标准的神经网络是全连接的,相比经过dropout的神经网络减少了一些连接(图-9)。

图-9

代码演示,可以使用Dropout方法断开连接,0.5代表断开两层之间的50% :

 这种方法被用在模型训练中,但当模型测试过程中,为了提高test的表现,要结束这个操作,将所有的连接都使用上,可以使用net_dropped.eval()方法结束这个操作,代码演示如下:

6、随机梯度下降 (Strochastic Gradient Descent )

这里的随机并不是指任意,这里面是有一套规则的,是一套映射的关系,即将原来的数据x送入f(x)得到一种分布。经过随机从原数据中得到一组小数据,使用这一小组数据训练模型。

 学习:课时60 Early stopping, dropout等_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1546541.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【数据结构与算法】用染色法判定二分图

问题描述 给定一个 n 个点 m 条边的无向图,图中可能存在重边和自环。 请你判断这个图是否是二分图。 输入格式 第一行包含两个整数 n 和 m。 接下来 m 行,每行包含两个整数 u 和 v,表示点 u 和点 v 之间存在一条边。 输出格式 如果给定图…

js中如何用点击地图获取经纬度

要实现在地图上点击并获取被点击地址的经纬度,然后渲染至页面中的功能,你需要首先确保你使用的地图API支持点击事件,并且能够返回点击位置的经纬度。以高德地图(AMap)为例,你可以按照以下步骤实现这个功能&…

wma怎么转换成mp3?无损转换!

WMA(Windows Media Audio)文件格式诞生于微软公司的数字音频技术研发。由于其高压缩性能和较好的音质,在推出初期主要用于Windows Media Player等微软产品。然而,随着MP3格式的盛行,WMA的使用范围逐渐受到限制。 MP3文…

ES6 基础

文章目录 1. 初识 ES62. let 声明变量3. const 声明常量4. 解构赋值 1. 初识 ES6 ECMAScript6.0(以下简称ES6)是JavaScript语言的下一代标准,已经在2015年6月正式发布了。它的目标,是使得」JavaScript语言可以用来编写复杂的大型应用程序,成为…

MySQL 经典练习 50 题 (记录)

前言: 记录一下sql学习,仅供参考基本都对了,不排除有些我做的太快做错了。里面sql不存在任何sql优化操作,只以完成最后输出结果为目的,包含我做题过程和思路最后一行才是结果。 1.过程: 1.1.插入数据 /* SQLyog Ul…

3.学习前后端关联

目录 1.接口类型 2.错误状态码 3.如何定义路由 4.那如何要求前端传入一个JSON数据呢? 4.解决前后端口不同源,跨域问题 1.使用CrossOrigin 2.直接复制代码使用 5.用户登录校验 1.接口类型 POST(新增数据)、PUT(更新更改数据)、GET(查询)、DELET(删除数据) …

Vivado ECO Flow

Vivado ECO流量 重要!ECOs只在设计检查点上工作。ECO布局仅在设计后可用检查点已在Vivado IDE中打开。 工程变更单(ECOs)是对实施后网表的修改意图在对原始设计影响最小的情况下实施更改。Vivado提供ECO流,允许您修改设计检查点、…

Xshell连接虚拟机非常慢

问题: 打开虚拟机连接时发现过了几分钟依然卡着在,但是主机可以ping通虚拟机,虚拟机也可以ping通主机感觉很奇怪,查询后得知需要修改ssh设置 打开配置 vim /etc/ssh/sshd_config 修改配置 找到 UseDNS,去掉前面的#号…

大学宠物医疗试题及答案,分享几个实用搜题和学习工具 #学习方法#笔记#知识分享

大学开学,就意味着又回到了被线性代数、大学物理等测验题折磨的状态了……网站无法手动输入题干公式,初高中用过的搜题软件又都搜不到,想找个答案解析仿佛在大海捞针!不过不用怕,今天小林就把从大学攒到毕业工作都在使…

苍穹外卖项目-01(开发流程,介绍,开发环境搭建,nginx反向代理,Swagger)

目录 一、软件开发整体介绍 1. 软件开发流程 1 第1阶段: 需求分析 2 第2阶段: 设计 3 第3阶段: 编码 4 第4阶段: 测试 5 第5阶段: 上线运维 2. 角色分工 3. 软件环境 1 开发环境(development) 2 测试环境(testing) 3 生产环境(production) 二、苍穹外卖项目介绍 …

ES 进阶知识

索引Index 一个索引就是一个拥有几分相似特征的文档的集合。比如说,你可以有一个客户数据的索引,另一个产品目录的索引,还有一个订单数据的索引。一个索引由一个名字来标识(必须全部是小写字母),并且当我们…

LeetCode 309—— 买卖股票的最佳时机含冷冻期

阅读目录 1. 题目2.解题思路3. 代码实现 1. 题目 2.解题思路 根据题意,每一天有这样几个状态:买入股票、卖出股票、冷冻期、持有股票,因此,我们假设 f 为每天这几个状态下对应的最大收益,由于持有股票时不知道是哪天买…

3.26日总结

1.Fliptile Sample Input 4 4 1 0 0 1 0 1 1 0 0 1 1 0 1 0 0 1 Sample Output 0 0 0 0 1 0 0 1 1 0 0 1 0 0 0 0 题意:在题目输入的矩阵,在这个矩阵的基础上,通过最少基础反转,可以将矩阵元素全部变为0,如果不能达…

dbgpt部署教程,纯小白教程

1.打开git下载zip文件 下载地址: GitHub - eosphoros-ai/DB-GPT at v0.5.0 2.容器部署 2.1 先启动python3.10环境 docker run -itd --name dbgpt1 --gpus all --shm-size"32g" -p 60035:5000 -p 60037:7860 -p 60038:8000 \ -v /home/tmn/OAPD/jiayq/…

住在我心里的猴子:焦虑那些事儿 - 三余书屋 3ysw.net

精读文稿 您好,本期我们解读的是《住在我心里的猴子》。这是一本由患有焦虑症的作家所著,关于焦虑症的书。不仅如此,作者的父母和哥哥也都有焦虑症,而作者的母亲后来还成为了治疗焦虑症的专家。这本书的中文版大约有11万字&#x…

鸿蒙 HarmonyOS应用开发之API:Context

Context 是应用中对象的上下文,其提供了应用的一些基础信息,例如resourceManager(资源管理)、applicationInfo(当前应用信息)、dir(应用文件路径)、area(文件分区&#x…

Git基础(23):Git分支合并实战保姆式流程

文章目录 前言准备正常分支合并1. 创建两个不冲突分支2. 将dev合并到test 冲突分支合并1. 制造分支冲突2. 冲突合并 前言 Git分支合并操作 准备 这里先在Gitee创建了一个空仓库,方便远程查看内容。 正常分支合并 1. 创建两个不冲突分支 (1&#xf…

淘宝app商品数据API接口|item_get_app-获得淘宝app商品详情原数据

获得淘宝app商品详情原数据 API返回值说明 item_get_app-获得淘宝app商品详情原数据 公共参数​​​​​​ 名称类型必须描述keyString是调用key(必须以GET方式拼接在URL中)secretString是调用密钥api_nameString是API接口名称(包括在请求地…

pta L1-082 种钻石

L1-082 种钻石 分数 5 全屏浏览 切换布局 作者 陈越 单位 浙江大学 2019年10月29日,中央电视台专题报道,中国科学院在培育钻石领域,取得科技突破。科学家们用金刚石的籽晶片作为种子,利用甲烷气体在能量作用下形成碳的等离子体…

网络层介绍,IP地址分类以及作用

IP地址组成: TTL:生存时间 基于ICMP报文 特殊地址: 0.0.0.0-0.255.255.255 1.代表未指定的地址 默认路由 DHCP下发地址的时候,发个报文给DHCP服务器 临时用0.0.0.0借用地址,未指定地址。 2.全网地址:目…