【深度学习】(7)--保存最优模型

news2024/11/20 2:32:14

文章目录

  • 保存最优模型
    • 一、两种保存方法
      • 1. 保存模型参数
      • 2. 保存完整模型
    • 二、迭代模型
  • 总结

保存最优模型

我们在迭代模型训练时,随着次数初始的增多,模型的准确率会逐渐的上升,但是同时也随着迭代次数越来越多,由于模型会开始学习到训练数据中的噪声或非共性特征,发生过拟合现象,使得模型的准确率会上下震荡甚至于下降。

本篇就是介绍我们如何在进行那么多次迭代之中,找到训练最好效果时,模型的参数或完整模型。也方便以后使用模型时直接使用。

一、两种保存方法

我们知道,一个模型到底好不好,主要体现在对测试集数据结果上的表现,所以我们的方法主要从测试集入手,计算每次迭代测试集数据的准确率,取到准确率最大时对应的模型和参数

那么,我们该如何保存模型和参数呢?介绍一个小东西:

  • 文件拓展名pt\pth,t7,使用pt\pth或t7作为模型文件扩展名,保存模型的整个状态(包括模型架构和参数)或仅保存模型的参数(即状态字典,state_dict)。

1. 保存模型参数

方法

torch.save(model.state_dict(),path)
# model.state_dict()是一个从参数名称映射到参数张量的字典对象,它包含了模型的所有权重和偏置项
# path为创建的保存模型的文件

通过比较每一次迭代准确率的大小,取准确率最大时模型的参数

best_acc = 0
"""-----测试集-----"""
def test(dataloader,model,loss_fn):
    global best_acc
    size = len(dataloader.dataset) # 总数据大小
    num_batches = len(dataloader) # 划分的小批次数量
    model.eval()
    test_loss,correct = 0,0
    with torch.no_grad():
        for x,y in dataloader:
            x,y = x.to(device),y.to(device)
            pred = model.forward(x)
            test_loss += loss_fn(pred,y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item() # 预测正确的个数
    test_loss /= num_batches
    correct /= size
    correct = round(correct, 4)
    print(f"Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}")

    # 保存最优模型的方法(文件扩展名一般:pt\pth,t7)
    if correct > best_acc:
        best_acc = correct
    # 1. 保存模型参数方法:torch.save(model.state_dict(),path)  (w,b)
        print(model.state_dict().keys()) # 输出模型参数名称cnn
        torch.save(model.state_dict(),"best.pth") 

2. 保存完整模型

方法

torch.save(model,path)
# 直接得到整个模型

依旧是通过比较每一次迭代准确率的大小,但是取准确率最大时的整个模型

def test(dataloader,model,loss_fn):
    global best_acc
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss,correct = 0,0
    with torch.no_grad():
        for x,y in dataloader:
            x,y = x.to(device),y.to(device)
            pred = model.forward(x)
            test_loss += loss_fn(pred,y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            a = (pred.argmax(1) == y)
            b = (pred.argmax(1) == y).type(torch.float)
    test_loss /= num_batches
    correct /= size
    correct = round(correct, 4)
    print(f"Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}")

# 保存最优模型的方法(文件扩展名一般:pt\pth,t7)
    if correct > best_acc:
        best_acc = correct
    # 2. 保存完整模型(w,b,模型cnn)
        torch.save(model,"best1.pt")

二、迭代模型

接下来就要迭代模型,得到最优的模型:

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.001,weight_decay=0.0001)

epochs = 150
# training_data、test_data:数据预处理好的数据
train_dataloader = DataLoader(training_data,batch_size=64,shuffle=True)
test_dataloader = DataLoader(test_data,batch_size=64,shuffle=True)
for t in range(epochs):
    print(f"Epoch {t+1} \n-------------------------")
    train(train_dataloader,model,loss_fn,optimizer)
    test(test_dataloader,model,loss_fn)
print("Done!")

在每轮数据迭代后,project工程栏中的best1.ptbest.pth文件中模型会随着迭代及时更新,迭代结束后,文件中保存的就是最优模型以及最优的模型参数。

在这里插入图片描述

总结

本篇介绍了:

  1. 为什么随着迭代次数越来越多,模型的准确率会上下震荡甚至于下降。—> 过拟合
  2. pt\pth,t7三个扩展名,用于保存完整模型或者模型参数。
  3. 模型的好坏,通过体现在测试集的结果上。
  4. 保存最优模型的两种方法:保存模型参数和保存完整模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2166130.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

大数据-148 Apache Kudu 从 Flink 下沉数据到 Kudu

点一下关注吧!!!非常感谢!!持续更新!!! 目前已经更新到了: Hadoop(已更完)HDFS(已更完)MapReduce(已更完&am…

Spring Boot房屋租赁平台:现代化解决方案

1 绪论 1.1 研究背景 中国的科技的不断进步,计算机发展也慢慢的越来越成熟,人们对计算机也是越来越更加的依赖,科研、教育慢慢用于计算机进行管理。从第一台计算机的产生,到现在计算机已经发展到我们无法想象。给我们的生活改变很…

Recaptcha2 图像识别 API 对接说明

Recaptcha2 图像识别 API 对接说明 本文将介绍一种 Recaptcha2 图像识别2 API 对接说明,它可以通过用户输入识别的内容和 Recaptcha2验证码图像,最后返回需要点击的小图像的坐标,完成验证。 接下来介绍下 Recaptcha2 图像识别 API 的对接说…

8.12DoG (Difference of Gaussians)

基本概念 不同尺度的高斯模糊图像之间的差异(DoG),用于边缘检测。函数: cv::GaussianBlur() 结合 cv::Laplacian() 或者自定义DoG实现。 在OpenCV中并没有直接提供一个名为“DoG”(Difference of Gaussians)的函数&a…

【学术会议征稿】第四届人工智能、机器人和通信国际会议(ICAIRC 2024)

第四届人工智能、机器人和通信国际会议(ICAIRC 2024) 2024 4th International Conference on Artificial Intelligence, Robotics, and Communication 第四届人工智能、机器人和通信国际会议(ICAIRC 2024)定于2024年12月27-29日…

css 自定义滚动条样式

* { scrollbar-color: auto !important; scrollbar-width: auto; } //滚动条宽高 ::-webkit-scrollbar { width: 4px; height: 4px; background: transparent; } ::-webkit-scrollbar-thumb { //滑块部分 border-radius: 5px; background-color: rgba(32, 224, 254, 1); } ::-…

【Python报错已解决】TypeError: can only concatenate str (not “float“) to str

🎬 鸽芷咕:个人主页 🔥 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想,就是为了理想的生活! 专栏介绍 在软件开发和日常使用中,BUG是不可避免的。本专栏致力于为广大开发者和技术爱好者提供一个关于BUG解决的经…

docker compose的使用

docker compose 1.概述 是 Docker 官方提供的一款开源工具,主要用于简化在单个主机上定义和运行多容器 Docker 应用的过程。它的核心作用是容器编排,使得开发者能够在一个统一的环境中以声明式的方式管理多容器应用的服务及其依赖关系。 也就是说Docker…

用 Django 5 快速生成一个简单 进销存 系统 添加 个打印 按钮

一、前置条件: 1.安装好python 【关联网址】 2. 安装好vscode 【关联网址】 插件 3. 登陆海螺AI【关联网址】 4. 安装好 pip install django 【关联网址】 pip install django -i https://mirrors.aliyun.com/pypi/simple/ 二、开始生成 1. 打开vscode 打开…

[数据库实验五] 审计及触发器

一、实验目的与要求: 1.了解MySQL审计功能及实现方式 2.掌握触发器的工作原理、定义及操作方法 二、实验内容: 注: 在同一个触发器内编写多行代码,需要用结构begin ……end 函数current_user()获得当前登录用户名 1.自动保存…

Linux 应用层自定义协议与序列化

文章目录 一、应用层1、协议2、序列化 && 反序列化3、通过Json库进行数据的序列化 && 反序列化Json::Value类Json::Reader类Json::Writer类 二、为什么read、write、recv、send和Tcp支持全双工?发数据的本质:tcp支持全双工通信的原因&am…

gitlab-runner集成CI/CD完整项目部署

目录 1.环境安装 2.gitlab代码仓库搭建 3.gitlab-runner-安装以及注册 4..gitlab-ci.yml脚本 5.脚本说明 6.build.sh 7.test.sh 8. deploy.sh 9.运行流水线 10.选择流水线分支 11.查看运行阶段 12.查看运行日志 13.查看服务器真实日志 1.环境安装 确保服务器的Java环…

Python_异常机制

软件程序在运行过程中,非常可能遇到刚刚提到的这些问题,我们称之为异常,英文是:Exception,意思是例外。遇到这些例外情况,或者叫异常,我们怎么让写的程序做出合理的处理,安全的退出&…

Footprint Growthly Quest 工具:赋能 Telegram 社区实现 Web3 飞速增长

作者:Stella L (stellafootprint.network) 在 Web3 的快节奏世界里,社区互动是关键。而众多 Web3 社区之所以能够蓬勃发展,很大程度上得益于 Telegram 平台。正因如此,Footprint Analytics 精心打造了 Growthly —— 一款专为 Tel…

Tkinter制作登录界面以及登陆后页面切换

Tkinter制作登录界面以及登陆后页面切换 前言序言1. 由来2. 思路3. 项目结构描述4. 项目实战1. 登录界面实现(代码)2. 首页界面实现(代码)3. 打包build.py(与main.py同级目录)4. 打包安装包 前言 本帖子&a…

【nrm】npm 注册表管理器

nrm是什么 nrm(NPM Registry Manager)是一个用于管理 Node.js 包管理器(如 npm 和 Yarn)的注册表工具。它可以帮助用户快速切换不同的 npm 源,以便于提高包安装的速度和效率,特别是在中国大陆地区&#xf…

Ubuntu23.10下处理libncurses5-dev包的安装问题

Ubuntu23.10下处理libncurses5-dev包的安装问题 导语环境准备问题和解决方案总结参考文献 导语 使用Ubuntu23.10的时候,遇到需要termios的场景,结果发现无论是codeblocks还是系统本身的gcc都无法找到term.h和curse.h,网上找了很多解决方案都…

了解云计算工作负载保护的重要性,确保数据和应用程序安全

云计算de小白 云计算技术的快速发展使数据和应用程序安全成为一种关键需求,而不仅仅是一种偏好。随着越来越多的客户公司将业务迁移到云端,保护他们的云工作负载(指所有部署的应用程序和服务)变得越来越重要。云工作负载保护&…

【stm32】TIM定时器输出比较-PWM驱动LED呼吸灯/舵机/直流电机

TIM定时器输出比较 一、输出比较简介1、OC(Output Compare)输出比较2、PWM简介3、输出比较通道(高级)4、输出比较通道(通用)5、输出比较模式6、PWM基本结构配置步骤:程序代码:PWM驱动LED呼吸灯 7、参数计算8、舵机简介程序代码&am…

nginx 安装(Centos)

nginx 安装-适用于 Centos 7.x [rootiZhp35weqb4z7gvuh357fbZ ~]# lsb_release -a LSB Version: :core-4.1-amd64:core-4.1-noarch Distributor ID: CentOS Description: CentOS Linux release 7.9.2009 (Core) Release: 7.9.2009 Codename: Core# 创建文件…