【深度学习】(7)--神经网络之保存最优模型

news2024/11/18 14:52:30

文章目录

  • 保存最优模型
    • 一、两种保存方法
      • 1. 保存模型参数
      • 2. 保存完整模型
    • 二、迭代模型
  • 总结

保存最优模型

我们在迭代模型训练时,随着次数初始的增多,模型的准确率会逐渐的上升,但是同时也随着迭代次数越来越多,由于模型会开始学习到训练数据中的噪声或非共性特征,发生过拟合现象,使得模型的准确率会上下震荡甚至于下降。

本篇就是介绍我们如何在进行那么多次迭代之中,找到训练最好效果时,模型的参数或完整模型。也方便以后使用模型时直接使用。

一、两种保存方法

我们知道,一个模型到底好不好,主要体现在对测试集数据结果上的表现,所以我们的方法主要从测试集入手,计算每次迭代测试集数据的准确率,取到准确率最大时对应的模型和参数

那么,我们该如何保存模型和参数呢?介绍一个小东西:

  • 文件拓展名pt\pth,t7,使用pt\pth或t7作为模型文件扩展名,保存模型的整个状态(包括模型架构和参数)或仅保存模型的参数(即状态字典,state_dict)。

1. 保存模型参数

方法

torch.save(model.state_dict(),path)
# model.state_dict()是一个从参数名称映射到参数张量的字典对象,它包含了模型的所有权重和偏置项
# path为创建的保存模型的文件

通过比较每一次迭代准确率的大小,取准确率最大时模型的参数

best_acc = 0
"""-----测试集-----"""
def test(dataloader,model,loss_fn):
    global best_acc
    size = len(dataloader.dataset) # 总数据大小
    num_batches = len(dataloader) # 划分的小批次数量
    model.eval()
    test_loss,correct = 0,0
    with torch.no_grad():
        for x,y in dataloader:
            x,y = x.to(device),y.to(device)
            pred = model.forward(x)
            test_loss += loss_fn(pred,y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item() # 预测正确的个数
    test_loss /= num_batches
    correct /= size
    correct = round(correct, 4)
    print(f"Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}")

    # 保存最优模型的方法(文件扩展名一般:pt\pth,t7)
    if correct > best_acc:
        best_acc = correct
    # 1. 保存模型参数方法:torch.save(model.state_dict(),path)  (w,b)
        print(model.state_dict().keys()) # 输出模型参数名称cnn
        torch.save(model.state_dict(),"best.pth") 

2. 保存完整模型

方法

torch.save(model,path)
# 直接得到整个模型

依旧是通过比较每一次迭代准确率的大小,但是取准确率最大时的整个模型

def test(dataloader,model,loss_fn):
    global best_acc
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss,correct = 0,0
    with torch.no_grad():
        for x,y in dataloader:
            x,y = x.to(device),y.to(device)
            pred = model.forward(x)
            test_loss += loss_fn(pred,y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            a = (pred.argmax(1) == y)
            b = (pred.argmax(1) == y).type(torch.float)
    test_loss /= num_batches
    correct /= size
    correct = round(correct, 4)
    print(f"Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}")

# 保存最优模型的方法(文件扩展名一般:pt\pth,t7)
    if correct > best_acc:
        best_acc = correct
    # 2. 保存完整模型(w,b,模型cnn)
        torch.save(model,"best1.pt")

二、迭代模型

接下来就要迭代模型,得到最优的模型:

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.001,weight_decay=0.0001)

epochs = 150
# training_data、test_data:数据预处理好的数据
train_dataloader = DataLoader(training_data,batch_size=64,shuffle=True)
test_dataloader = DataLoader(test_data,batch_size=64,shuffle=True)
for t in range(epochs):
    print(f"Epoch {t+1} \n-------------------------")
    train(train_dataloader,model,loss_fn,optimizer)
    test(test_dataloader,model,loss_fn)
print("Done!")

在每轮数据迭代后,project工程栏中的best1.ptbest.pth文件中模型会随着迭代及时更新,迭代结束后,文件中保存的就是最优模型以及最优的模型参数。

在这里插入图片描述

总结

本篇介绍了:

  1. 为什么随着迭代次数越来越多,模型的准确率会上下震荡甚至于下降。—> 过拟合
  2. pt\pth,t7三个扩展名,用于保存完整模型或者模型参数。
  3. 模型的好坏,通过体现在测试集的结果上。
  4. 保存最优模型的两种方法:保存模型参数和保存完整模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2166690.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

今日不错的讲企业架构的好图

今日不错的讲企业架构的好图,来源B站不错的UP主:企业架构知识体系-业务技术管理的知识框架_哔哩哔哩_bilibili

grafana频繁DataSourceError问题

背景 随着 Grafana 数据量的不断增加,逐渐暴露出以下问题: Grafana 页面加载缓慢;Grafana 告警频繁出现 DatasourceError 错误。 对于第一个问题,大家可以参考这篇文章:Grafana 加载缓慢的解决方案。 不过&#xf…

【Unity踩坑】Textmesh Pro是否需要加入Version Control?

问题:如果Unity 项目中用到了Textmesh pro,相关的文件是否也需要签入呢? 回答: 在使用 Unity 的 Version Control(例如 Plastic SCM 或 Git)时,如果你的项目中使用了 TextMesh Pro&#xff0c…

条件字段有索引,为什么查询也这么慢?

如果我们想在某一本书中找到特定的主题,一般最快的方法是先看索引,找到对应的主题在哪个页码。 而对于 MySQL 而言,如果需要查找某一行的值,可以先通过索引找到对应的值,然后根据索引匹配的记录找到需要查询的数据行。…

家政服务预约系统小程序的设计

管理员账户功能包括:系统首页,个人中心,客户管理,员工管理,家政服务管理,服务预约管理,员工风采管理,客户需求管理,接单信息管理 微信端账号功能包括:系统首…

Java | Leetcode Java题解之第430题扁平化多级双向链表

题目: 题解: class Solution {public Node flatten(Node head) {dfs(head);return head;}public Node dfs(Node node) {Node cur node;// 记录链表的最后一个节点Node last null;while (cur ! null) {Node next cur.next;// 如果有子节点&#xff0…

后端(实例)08

设计一个前端在数据库调取数据的表格&#xff0c;并完成基础点击增删改查的功能&#xff1a; 1.首先写一个前端样式&#xff08;空壳&#xff09; <!DOCTYPE html> <html> <head> <meta charset"UTF-8"> <title>Insert title here&l…

VUE条件树查询

看如下图所示的功能&#xff0c;是不是可高级了&#xff1f;什么&#xff0c;你没看懂&#xff1f;拜托双击放大看&#xff01; 是的&#xff0c;我最近消失了一段时间就是在研究这个玩意的实现&#xff0c;通过不懈努力与钻研并参考其他人员实现并加以改造&#xff0c;很好&am…

人工智能时代的网络空间战略稳定及其挑战

文章目录 前言一、人工智能时代的网络空间战略稳定及其挑战(一)国内政治与官僚主义二、大国竞争与溯源政治三、国际法规与治理限制总结前言 人工智能的武器化应用在短期内将同时强化网络空间中进攻方和防御方的能力,但从长期看将有利于防御方。这种态势将令传统威慑逻辑重新…

[数据库实验三]安全性

目录 一、实验目的与要求&#xff1a; 二、实验内容&#xff1a; 三、实验小结 一、实验目的与要求&#xff1a; 1、设计用户子模式 2、根据实际需要创建用户角色及用户&#xff0c;并授权 3、针对不同级别的用户定义不同的视图&#xff0c;以保证系统的安全性 二、实验内…

Springboot jPA+thymeleaf实现增删改查

项目结构 pom文件 配置相关依赖&#xff1a; 2.thymeleaf有点类似于jstlel th:href"{url}表示这是一个链接 th:each"user : ${users}"相当于foreach&#xff0c;对user进行循环遍历 th:if进行if条件判断 {变量} 与 ${变量}的区别: 4.配置好application.ym…

【SemeDrive】【X9H】如何修改 SAFETY_FAULT 输出 PWM 频率

前言&#xff1a; SAFETY_FAULT 也是 SEM_FAULT&#xff0c;在原理图上会有不同的标注&#xff0c;但意义一样。 默认的 SAFETY_FAULT 正常时输出 PWM 频率为 100 MHz&#xff0c;过高的频率有时会导致无法通过 EMI 测试&#xff0c;需要降低频率。以下描述如何将正常时的 S…

ssh 命令详解

一、命令简介 ​ssh ​命令用于安全登录远程主机&#xff0c;以便在远程机上执行命令或传输数据。 ‍ 例如登录远程主机 169.10.222.23 ​上的 soulio ​用户&#xff1a; ssh soulio169.10.222.23更多示例参考第三章。 ‍ 了解背景知识&#xff1a;ssh 加密 1. 加密类型…

C++之Person类中调用Date类

main.cpp #include <iostream> #include "Person.h" using namespace std;int main() {Person myPerson;// Person myPerson("S.M.Wang", 070145, "莲花路200号");cout << "请输入姓名:" ;string name;cin >> name…

【文档智能 RAG】浅看开源的同质化的文档解析框架-Docling

前言 RAG的兴起&#xff0c;越来越多的人开始关注文档结构化解析的效果&#xff0c;这个赛道变得非常的同质化。 关于文档智能解析过程中的每个技术环节的技术点&#xff0c;前期文章详细介绍了很多内容&#xff1a; 下面我们简单的看看Docling这个PDF文档解析框架里面都有什…

尚品汇-自动化部署-Jenkins的安装与环境配置(五十六)

目录&#xff1a; 自动化持续集成 &#xff08;1&#xff09;环境准备 &#xff08;2&#xff09;初始化 Jenkins 插件和管理员用户 &#xff08;3&#xff09;工作流程 &#xff08;4&#xff09;配置 Jenkins 构建工具 自动化持续集成 互联网软件的开发和发布&#xf…

AI:颠覆式创新 vs. 持续性创新

随着有关生成式人工智能 (GenAI) 的新闻不断出现在社交媒体上&#xff0c;包括 ChatGPT 4o 如何帮助你与朋友玩石头、剪刀、布&#xff0c;关于 GenAI 的“颠覆性”影响的惊人声明并不难找到。 事实证明&#xff0c;将 GenAI 本身称为“颠覆性”并没有多大意义。 它能成为颠覆…

libvirt中的qemu与kvm

在 libvirt 虚拟机管理中&#xff0c;domain_type 的设置决定了虚拟机使用的虚拟化技术。在 domain_type 中&#xff0c;qemu 和 kvm 是两种不同的虚拟化模式&#xff0c;它们的区别主要在于是否使用硬件虚拟化加速。 qemu 模式 定义&#xff1a;qemu 是一种完全软件模拟的虚…

Recorder录音插件使用日记

目录 一、安装插件 二、导入文件 1.app-xxx-support.js支持文件 2.RecordApp 三 功能的使用 3.1 请求录音权限 3.2 开始录音 3.3 停止录音 3.4 其他接口 四 、使用 4.1 开始录音实例 4.2 请求录音权限 4.3 停止录音——文件的下载与上传 一、安装插件 npm install…

c++ day06

类的栈 实现 #include <iostream>using namespace std;class Stack { private:static const size_t MAX 100; // 定义固定容量int data[MAX]; // 存储栈元素的数组size_t len; // 当前栈的大小public:// 构造函数Stack() : len…