PyTorch 训练集、验证集、测试集、模型存档、正则化项

news2024/12/23 17:19:07

为什么要将数据集划分为三个部分?三个部分的作用?三个部分数据集的比例应如何设定?

另外一种常见的数据集划分方法是将数据集划分为两个部分(训练集和测试集),这种划分方法存在的问题在于,模型利用训练集数据进行训练,测试集数据进行模型泛化性测试。但我们不能利用测试集测试的bad case或者根据测试集的测试精度调整模型的参数。这是因为对于模型来说,数据集应该是只有训练集可见的,其他数据均不可见,如果利用测试集的结果调整模型相对于模型也”看到了“测试集的数据。将数据集划分为是独立同分布的三个部分就可以解决这个问题,将训练集数据用于模型训练,验证集(开发集)数据用于模型调参,测试集数据用于验证模型泛化性。

介绍

训练集

训练集(Training Dataset)是用来训练模型使用的。
在这里插入图片描述

验证集

当我们的模型训练好之后,我们并不知道他的表现如何。这个时候就可以使用验证集(Validation Dataset)来看看模型在新数据(验证集和测试集是不同的数据)上的表现如何。同时通过调整超参数,让模型处于最好的状态
在这里插入图片描述
验证集有2个主要的作用:
1、评估模型效果,为了调整超参数而服务
2、调整超参数,使得模型在验证集上的效果最好
说明:
1、验证集不像训练集和测试集,它是非必需的。如果不需要调整超参数,就可以不使用验证集,直接用测试集来评估效果。
2、验证集评估出来的效果并非模型的最终效果,主要是用来调整超参数的,模型最终效果以测试集的评估结果为准。

在机器学习中,训练集是用来训练模型的数据,而验证集通常是从原始数据集中划分出来的一个子集,用于在训练过程中检查模型的性能,是在过拟合或欠拟合的情况下对模型进行评估和调整的数据。

验证集的主要目的是为了找到一个最佳的模型及参数,使得模型在未知数据上的表现最好。

之前提到,训练集一般会占用60%或80%的比例,对应的验证集则一般会占用20%或10%的比例。划分比例的依据可以根据实际需求和数据集的大小来确定。

通常情况下,我们可以使用随机抽样的方法从原始数据集中划分验证集

验证集在整个模型训练的过程中起着关键的作用,我们从几个方面出发,聊聊其重要性。

1.调整模型的超参数

在机器学习模型中,有许多超参数需要我们参与设置,例如学习率、隐藏层神经元数量等。这些超参数的选择对模型的性能有很大影响。

为了找到最优的超参数组合,我们可以将训练过程分为多个阶段,每个阶段使用不同的超参数组合进行训练。然后,我们可以使用验证集来评估每个阶段模型的性能,从而选择出最优的超参数组合。

2.早停策略

在训练过程中,如果我们发现模型在验证集上的性能不再提高时,可以提前停止训练。

具体来说,我们可以设置一个小的阈值,当模型在连续多个迭代周期内,验证集上的误差没有降低到这个阈值以下时,我们就认为模型已经收敛,可以停止训练。

这样既可以节省训练时间,也可以降低不必要的成本。

3. 防止过拟合

过拟合是机器学习中的一个常见问题,指的是模型在训练集上表现很好,但在测试集上表现较差的现象。这是因为模型过于复杂,学习到了训练集中的一些噪声和异常数据。

为了解决这个问题,我们可以使用验证集来监控模型在训练过程中的性能。

如果发现模型在训练集上的表现越来越好,但在验证集上的表现越来越差,那么我们可以考虑减少模型的复杂度或者增加正则化项,以防止过拟合的发生。

在训练过程中可以得到验证集的Loss或者acc.的曲线,在曲线上就能大致判断发生over-fitting的点,选取在这个点之前的模型的参数作为学习到的参数,能让模型有较好的泛化能力。

可以记录下在每个时间戳(在验证集上测试时)验证集上的performance和模型参数,然后最后再去选取认为最好的模型。这个可以用checkpoint来做。

过拟合:随着模型复杂度和训练Epoch的增加,CNN模型在训练集上的误差降低,但是在测试集上的误差会降低后升高,如下图所示:
在这里插入图片描述
这是由于训练过程过于细致,以至于把样本的特有特征也当做识别特征一起训练了,所以在验证集上会表现得较差。
针对过拟合的情况,可以从两个方面来解决(1)模型训练方式,(2)数据集。
1、模型训练方式

  • 正则化(Regulation)
  • Dropout:随机失活
  • 提前停止训练

2、数据集
合理的从给定的数据集中拆分出训练集和验证集,将大大减低模型过拟合的可能,常用的验证集划分的方法有:

  • 留出法(Hold-out):按一定比例直接将训练集划分为两部分。
  • K折交叉验证(K-flod Cross Validation):将训练集划分为K份,将其中K-1份作为训练集,剩下一份作为验证集,循环K次训练。
  • 自助采样(BootStrap):通过有放回的采样方式得到新的训练集和验证集,这样每次的训练集和验证集都是有区别的。这对小数据量的训练较为适用。

4. 对比不同模型结构

我们可以通过对比不同模型结构在验证集上的性能,选择最适合任务的模型结构。

这有助于避免选择过于简单或过于复杂的模型,从而提高模型的实际效果。

比如,比较卷积神经网络(CNN) 和循环神经网络(RNN) 在情感分析任务上的性能。通过观察它们的表现,选择更适合处理文本数据的模型结构。

测试集

当我们调好超参数后,就要开始「最终考试」了。我们通过测试集(Test Dataset)来做最终的评估。

通过测试集的评估,我们会得到一些最终的评估指标,例如:准确率、精确率、召回率、F1等。

扩展阅读:分类模型评估指标:Accuracy、Precision、Recall、F1、ROC曲线、AUC、PR曲线

图解

训练集&测试集 模式

在这里插入图片描述

需要注意的是这里C点所在的测试集一定只能用来测试,不能用来训练
在这里插入图片描述

训练集&验证集&测试集 模式

Training set & Validation set & Test set
在这里插入图片描述
既然有多种模型可选,且每种模型的参数可调。那么我们如何选择最好的模型呢?
——上一节我们提到测试集永远只能用来测试,因此我们将训练集拆成两个部分——训练集 -> 训练集+验证集
在这里插入图片描述
在训练集训练模型得到模型参数后,验证集就要对这个在各种超参数下得到的模型进行评估,找到一组最优的超参数。然后将超参数固定,再拿到整个训练集上重新训练模型,反复迭代比较,获取到基本的模型。最后,由测试集评估最终模型的性能

如何合理的划分数据集?

在这里插入图片描述
下面的数据集划分方式主要针对「留出法」的验证方式,除此之外还有其他的交叉验证法,详情见下文 — — 交叉验证法。

数据划分的方法并没有明确的规定,不过可以参考3个原则:

1、对于小规模样本集(几万量级),常用的分配比例是 60% 训练集、20% 验证集、20% 测试集。
2、对于大规模样本集(百万级以上),只要验证集和测试集的数量足够即可,例如有 100w 条数据,那么留 1w 验证集,1w 测试集即可。1000w 的数据,同样留 1w 验证集和 1w 测试集。
3、超参数越少,或者超参数很容易调整,那么可以减少验证集的比例,更多的分配给训练集。

交叉验证法

为什么要用交叉验证法?

假如我们教小朋友学加法:1个苹果+1个苹果=2个苹果

当我们再测试的时候,会问:1个香蕉+1个香蕉=几个香蕉?

如果小朋友知道「2个香蕉」,并且换成其他东西也没有问题,那么我们认为小朋友学习会了「1+1=2」这个知识点。

如果小朋友只知道「1个苹果+1个苹果=2个苹果」,但是换成其他东西就不会了,那么我们就不能说小朋友学会了「1+1=2」这个知识点。

评估模型是否学会了「某项技能」时,也需要用新的数据来评估,而不是用训练集里的数据来评估。这种「训练集」和「测试集」完全不同的验证方法就是交叉验证法。

3 种主流的交叉验证法

在这里插入图片描述
留出法(Holdout cross validation)
上文提到的,按照固定比例将数据集静态的划分为训练集、验证集、测试集的方式就是留出法。

留一法(Leave one out cross validation)
每次的测试集都只有一个样本,要进行 m 次训练和预测。 这个方法用于训练的数据只比整体数据集少了一个样本,因此最接近原始样本的分布。但是训练复杂度增加了,因为模型的数量与原始数据样本数量相同。 一般在数据缺乏时使用。

k 折交叉验证(k-fold cross validation)
静态的「留出法」对数据的划分方式比较敏感,有可能不同的划分方式得到了不同的模型。「k 折交叉验证」是一种动态验证的方式,这种方式可以降低数据划分带来的影响。具体步骤如下:
1、将数据集分为训练集和测试集,将测试集放在一边
2、将训练集分为 k 份
3、每次使用 k 份中的 1 份作为验证集,其他全部作为训练集。
4、通过 k 次训练后,我们得到了 k 个不同的模型。
5、评估 k 个模型的效果,从中挑选效果最好的超参数
6、使用最优的超参数,然后将 k 份数据全部作为训练集重新训练模型,得到最终模型。
在这里插入图片描述
k 一般取 10。数据量小的时候,k 可以设大一点,这样训练集占整体比例就比较大,不过同时训练的模型个数也增多。 数据量大的时候,k 可以设小一点。

PyTorch中的 训练和验证

首先应该设置模型的状态:如果是训练状态,那么模型的参数应该支持反向传播的修改;如果是验证/测试状态,则不应该修改模型参数。在PyTorch中,模型的状态设置非常简便,如下的两个操作二选一即可:

model.train()   # 训练状态
model.eval()   # 验证/测试状态
# 构造训练集
train_loader = torch.utils.data.DataLoader(
	train_dataset,
	batch_size=10,
	shuffle=True,
	num_workers=10, )

# 构造验证集
val_loader = torch.utils.data.DataLoader(
	val_dataset,
	batch_size=10,
	shuffle=False,
	num_workers=10, )

model = SVHN_Model1()  # 加载模型
criterion = nn.CrossEntropyLoss (size_average=False) # 损失函数
optimizer = torch.optim.Adam(model.parameters(), 0.001) # 优化器
best_loss = 1000.0 # 最优损失

for epoch in range(20):
	print('Epoch: ', epoch)
	train(train_loader, model, criterion, optimizer, epoch)
	val_loss = validate(val_loader, model, criterion)

	# 记录下验证集精度
	if val_loss < best_loss:
		best_loss = val_loss
		torch.save(model.state_dict(), './model.pt')

def train(train_loader, model, criterion, optimizer, epoch):
	# 切换模型为训练模式
	model.train()
	for i, (input, target) in enumerate(train_loader):
		c0, c1, c2, c3, c4, c5 = model(data[0])
		loss = criterion(c0, data[1][:, 0]) + \
				criterion(c1, data[1][:, 1]) + \
				criterion(c2, data[1][:, 2]) + \
				criterion(c3, data[1][:, 3]) + \
				criterion(c4, data[1][:, 4]) + \
				criterion(c5, data[1][:, 5])
	loss /= 6
	optimizer.zero_grad()
	loss.backward()
	optimizer.step()

验证/测试的流程基本与训练过程一致,不同点在于:

  • 需要预先设置torch.no_grad,以及将model调至eval模式
  • 不需要将优化器的梯度置零
  • 不需要将loss反向回传到网络
  • 不需要更新optimizer
def validate(val_loader, model, criterion):
	# 切换模型为预测模型
	model.eval()
	val_loss = []
	# 不记录模型梯度信息
	with torch.no_grad():
			for i, (input, target) in enumerate(val_loader):
				c0, c1, c2, c3, c4, c5 = model(data[0])
				loss = criterion(c0, data[1][:, 0]) + \
						criterion(c1, data[1][:, 1]) + \
						criterion(c2, data[1][:, 2]) + \
						criterion(c3, data[1][:, 3]) + \
						criterion(c4, data[1][:, 4]) + \
						criterion(c5, data[1][:, 5])
				loss /= 6
				val_loss.append(loss.item())
	return np.mean(val_loss)

保存模型

# 模型保存到 ./model.pt文件下
torch.save(model_object.state_dict(), 'model.pt')



# 加载模型 ./model.pt
model.load_state_dict(torch.load(' model.pt'))
# 或
checkpoint = torch.load( 'best_model.pt' ) 
model.load_state_dict(checkpoint)

模型训练通常需要花很久的时间,有时候会遇到被断电或各种中止程式的状况,所以随时存档是件重要的事,这样就能接续训练。但如果要接续训练,存档就不能只存模型,要把优化器一起储存,因为它会记录一些训练资讯(像是动量之类的值)。要同时存模型和优化器,可以用「字典」的方式存,简单来讲就是用大括号:{ }。范例如下:

torch.save({ 'model_state_dict' : model.state_dict(), 
           'optimizer_state_dict' : optimizer.state_dict() 
           }, 'best_model.pt' )

可以看到模型参数model.state_dict()存在’model_state_dict’标签下;优化器参数optimizer.state_dict()存在’optimizer_state_dict’标签下,标签名称可以随便设定没关系,要存多一点资讯进去也行,像是epoch、训练误差…等。

要读取模型和优化器,可以执行下面的程式:

checkpoint = torch.load( 'best_model.pt' ) 
model.load_state_dict(checkpoint[ 'model_state_dict' ]) 
optimizer.load_state_dict(checkpoint[ 'optimizer_state_dict' ])

回到刚刚的问题,储存最佳模型。我会先设定一个变数best_loss,用来记录最低的验证误差,我会把值设定的很大(如:999999),如果计算出来的误差比这个值更低,去把误差值取代,并存下模型,然后一直重复。这样存下来的模型就是验证误差最小的模型。程式码如下:

if val_loss < best_loss: 
    best_loss = val_loss 
    torch.save({ 'epoch' : epoch, 
                'model_state_dict' : model.state_dict(), 
               'optimizer_state_dict' : optimizer.state_dict() 
               }, 'best_model.pt' )

PyTorch中使用正则化项

L2 Regularization

若使用L2正则化项:
在这里插入图片描述
只要直接在训练前为optimizer设置正则化项的λ \lambdaλ参数(这里不叫Regularization而是用了Weight Decay这个叫法):

optimizer = optim.SGD(net.parameters(), lr=learning_rate, weight_decay=0.01)

正则化项目是用来克服over-fitting的,如果网络本身就没有发生over-fitting,那么设置了正则化项势必会导致网络的表达能力不足,引起网络的performance变差。

L1 Regularization

若使用L1正则化项,即对所有参数绝对值求和再乘以一个系数:
在这里插入图片描述
在PyTorch中还没有直接设置L1范数的方法,可以在训练时Loss做BP之前(也就是.backward()之前)手动为Loss加上L1范数:

		# 为Loss添加L1正则化项
		L1_reg = 0
		for param in net.parameters():
			L1_reg += torch.sum(torch.abs(param))
		loss += 0.001 * L1_reg  # lambda=0.001

Reference

【小萌五分钟】机器学习 | 数据集的划分(一): 训练集及测试集
一文看懂 AI 数据集:训练集、验证集、测试集(附:分割方法+交叉验证)
Pytorch实现模型训练与验证
深入浅出PyTorch
训练、验证、存档

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2230086.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Flask轻松上手:从零开始搭建属于你的Web应用

目录 一、准备工作 二、安装Flask 三、创建你的第一个Flask应用 创建一个新的Python文件 编写Flask应用代码 运行Flask应用 四、创建一个简单的博客系统 定义路由和文章列表 创建模板文件 运行并测试博客系统 五、使用数据库存储用户信息 安装Flask-SQLAlchemy 修…

游戏启动失败:8种修复xinput1_3.dll错误的几种方法教程,轻松解决xinput1_3.dll错误

当你准备好在一天的工作后放松一下&#xff0c;启动你最爱的游戏&#xff0c;却突然收到一个“xinput1_3.dll 丢失”的错误消息&#xff0c;这无疑是令人沮丧的。幸运的是&#xff0c;xinput1_3.dll丢失问题通常可以通过几个简单的步骤来解决。本文将详细介绍这些步骤&#xff…

Halcon-模板匹配(WPF)

halcon的代码 dev_open_window (0, 0, 512, 512, black, WindowHandle) read_image (Image, C:/Users/CF/Desktop/image.jpg) dev_display (Image)draw_rectangle1 (WindowHandle, Row1, Column1, Row2, Column2) gen_rectangle1 (Rectangle, Row1, Column1, Row2, Column2) r…

《AI从0到0.5》之提示工程

参考资料&#xff1a;《AI提示工程&#xff1a;基础 应用 实例》万欣 主要内容&#xff1a; 该文章是对《AI提示工程&#xff1a;基础 应用 实例》这本书的浓缩整理&#xff0c;旨在让读者快速的了解AI提示工程的概念和设计原则、策略和技巧、部分应用案例。并结合笔者自…

@FISCO BCOS的朋友们,年度生态大会邀您查收成果集结令

七载春秋&#xff0c;繁星相映。站在开源七周年的重要节点上&#xff0c;FISCO BCOS年度生态大会再次面向全社区发出产业数字化成果集结令&#xff0c;邀请FISCO BCOS的朋友们于今年12月份共探区块链产业的发展现状与未来。 作为深圳国际金融科技节的重要组成部分和特色活动&a…

Linux文件清空的五种方法总结分享

简介&#xff1a; 每种方法各有优势&#xff0c;选择最合适的一种或几种&#xff0c;可以极大提高您的工作效率。更多有关Linux系统管理的技巧与资源&#xff0c;欢迎访问&#xff0c;持续提升您的运维技能。 在Linux操作系统环境下&#xff0c;清空文件内容是日常维护和管理中…

Android文件选择器[超级轻量级FilePicker测试没有问题][挣扎解决自带文件管理器获取不到绝对地址问题而是返回msf%3A1000038197]

超级轻量级FilePicker测试没有问题 本文摘录于&#xff1a;https://blog.csdn.net/gitblog_00365/article/details/141449437只是做学习备份之用&#xff0c;绝无抄袭之意&#xff0c;有疑惑请联系本人&#xff01; 今天真的是发了疯的找文件管理器,因为调用系统自带的文件管理…

向量的基础知识和矩阵向量的坐标旋转

向量的基础&#xff1a; 定义&#xff1a; 既有大小&#xff0c;又有方向的量叫做向量&#xff08;Vector&#xff09;。 在几何上&#xff0c;向量用有向线段来表示&#xff0c;有向线段长度表示向量的大小&#xff0c;有向线段的方向表示向量的方向。其实有向线段本身也是向…

java控制台打印加法口诀

具体代码&#xff1a; public class AdditionTable {public static void main(String[] args) {//add();//add2();//add3();add1();}public static void add(){for(int i2;i<10;i){for(int j1;j<i;j){String format String.format("%-7s",j""(i-j)…

【Deno运行时】深入解析Deno:下一代JavaScript和TypeScript运行时

&#x1f9d1;‍&#x1f4bc; 一名茫茫大海中沉浮的小小程序员&#x1f36c; &#x1f449; 你的一键四连 (关注 点赞收藏评论)是我更新的最大动力❤️&#xff01; &#x1f4d1; 目录 &#x1f53d; 前言1️⃣ Deno简介2️⃣ Deno的核心特性3️⃣ Deno与Node.js的区别4️⃣ …

OpenCV开发笔记(八十二):两图拼接使用渐进色蒙版场景过渡缝隙

若该文为原创文章&#xff0c;转载请注明原文出处 本文章博客地址&#xff1a;https://hpzwl.blog.csdn.net/article/details/143432922 长沙红胖子Qt&#xff08;长沙创微智科&#xff09;博文大全&#xff1a;开发技术集合&#xff08;包含Qt实用技术、树莓派、三维、OpenCV…

数字IC后端实现之Innovus Place跑完density爆涨案例分析

下图所示为咱们社区a7core后端训练营学员的floorplan。 数字IC后端实现 | Innovus各个阶段常用命令汇总 该学员跑placement前density是59.467%&#xff0c;但跑完place后density飙升到87.68%。 仔细查看place过程中的log就可以发现Density一路飙升&#xff01; 数字IC后端物…

一文总结AI智能体与传统RPA机器人的16个关键区别

基于LLM的AI Agent&#xff08;智能体&#xff09;与**RPA&#xff08;机器人流程自动化&#xff0c;Robotic Process Automation&#xff09;**两种技术在自动化任务领域中扮演着至关重要的角色。AI智能体能够借助LLM拥有极高的灵活性&#xff0c;可以实时理解和响应环境的变化…

ES(2)(仅供自己参考)

Java代码的索引库&#xff1a; package cn.itcast.hotel;import lombok.AccessLevel; import org.apache.http.HttpHost; import org.elasticsearch.action.admin.indices.delete.DeleteIndexRequest; import org.elasticsearch.client.RequestOptions; import org.elasticsea…

【机器学习】24. 聚类-层次式 Hierarchical Clustering

1. 优势和缺点 优点&#xff1a; 无需提前指定集群的数量 通过对树状图进行不同层次的切割&#xff0c;可以得到所需数量的簇。树状图提供了一个有用的可视化-集群过程的可解释的描述树状图可能揭示一个有意义的分类 缺点&#xff1a; 计算复杂度较大, 限制了其在大规模数据…

Rust 力扣 - 2379. 得到 K 个黑块的最少涂色次数

文章目录 题目描述题解思路题解代码题目链接 题目描述 题解思路 本题可以转换为求长度为k的子数组中白色块的最少数量 我们遍历长度为k的窗口&#xff0c;我们只需要记录窗口内的白色块的数量即可&#xff0c;遍历过程中刷新白色块的数量的最小值 题解代码 impl Solution {…

Git创建和拉取项目分支的应用以及Gitlab太占内存,如何配置降低gitlab内存占用进行优化

一、Git创建和拉取项目分支的应用 1. 关于git创建分支&#xff0c; git创建分支&#xff0c;可以通过git管理平台可视化操作创建&#xff0c;也可以通过git bash命令行下创建&#xff1a; A. 是通过git管理平台创建&#xff1a; 进入gitlab管理平台具体的目标项目中&#xff…

ubuntu-开机黑屏问题快速解决方法

开机黑屏一般是由于显卡驱动出现问题导致。 快速解决方法&#xff1a; 通过ubuntu高级选项->recovery模式->resume->按esc即可进入recovery模式&#xff0c;进去后重装显卡驱动&#xff0c;重启即可解决。附加问题&#xff1a;ubuntu的默认显示管理器是gdm3,如果重…

《高频电子线路》 —— 反馈型振荡器

文章内容来源于【中国大学MOOC 华中科技大学通信&#xff08;高频&#xff09;电子线路精品公开课】&#xff0c;此篇文章仅作为笔记分享。 反馈型振荡器基本工作原理 振荡器分类 自激&#xff1a;没有信号输入他激&#xff1a;有信号输入RC振荡器主要产生低频的正弦波&#x…

unity发布webGL

1.安装WebGL板块 打开unity&#xff0c;进入该界面&#xff0c;然后选择圈中图标 选择添加模块 选择下载WebGL Build Support 2.配置项目设置 打开一个unity项目&#xff0c;如图进行选择 如图进行操作 根据自己的情况进行配置&#xff08;也可直接点击构建和运行&#xff09…