深度学习模型笔记

news2024/11/15 7:58:27

加载和保存模型参数

保存模型参数

net = MLP()
# 此处省略训练过程,在训练之后,保存模型参数
# 保存字典格式的模型参数,模型参数名
torch.save(net.state_dict(), 'mlp.params') 

加载模型参数

clone = MLP()
# 加载模型参数
clone.load_state_dict(torch.load('mlp.params'))
# 进入评估模式
clone.eval()
# 开始推理
Y_clone = clone(X)

当训练时使用了nn.DataParallel时,正确的加载方法

nn.DataParallel是一种用于多GPU计算的封装模块。它可以将模型和数据分布到多个GPU上进行并行计算,以加速训练过程。在PyTorch中,nn.DataParallel可以接受一个已存在的模型作为输入,并将模型和数据分布到多个GPU上进行并行计算。这种并行计算可以显著减少训练时间,并提高模型的收敛速度。在使用nn.DataParallel时,需要将模型和数据传递给封装后的nn.DataParallel对象,并在训练过程中使用封装后的模型进行前向传播和反向传播。

直接使用上面的加载方式会报错

RuntimeError: Error(s) in loading state_dict for Sequential:

该错误通常与使用了nn.DataParallel进行训练有关

是指模型中的参数key中字符串与torch.load获取的key中字符串不匹配

因此,我们只需要修改torch.load获取的dict,令其匹配。

 

例如:

我torch.save时,参数key中字符串前自动添加了'module.'

因此,在torch.load后,需要去掉'module.'
model = Model()
model_para_dict_temp = torch.load('xxx.pth')
model_para_dict = {}
for key_i in model_para_dict_temp.keys():
    model_para_dict[key_i[7:]] = model_para_dict_temp[key_i]  # 删除掉前7个字符'module.'
del model_para_dict_temp
model.load_state_dict(model_para_dict)

pytorch加载时出现“RuntimeError: Error(s) in loading state_dict for Sequential:”时的解决方案
https://www.cnblogs.com/s-tyou/p/16514693.html

kaggle CIFAR-10图像分类笔记

对数据集的处理

在这里插入图片描述

train: 训练集
train_valid: 训练集中的验证集,每个epoch结束后使用它进行验证。
test: 模型所有轮训练结束后,使用它来进行测试。
valid: 验证集,评估在未见过的数据上的数据集。

损失函数

loss = nn.CrossEntropyLoss(reduction='none')

reduction的字面意思是“减少”或“缩减”,指的是对损失进行某种操作以减小或简化损失的计算。

当reduction='none'时,表示不进行任何归约操作,即每个样本的损失都单独计算并返回。

当reduction='sum'时,表示对所有样本的损失进行求和操作,得到一个标量值,表示所有样本损失的总和。

当reduction='elementwise_mean'时,表示对所有样本的损失进行平均操作,得到一个标量值,表示所有样本损失的平均值。

lr_period, lr_decay = 4, 0.9 的含义

lr_period = 4: lr_period表示学习率更新的周期。设置为4可能意味着在每4个epoch后更新学习率。

lr_decay = 0.9: lr_decay表示学习率的衰减率。这里设置为0.9,意味着每次更新学习率时,学习率会乘以0.9,从而逐渐降低学习率。

图像分类的通用训练函数(训练并保存最优训练结果参数)

该函数的输入为:
train(网络,训练集的dataloader,评估集的dataloader,训练轮数,学习率,权重衰减,设备,更新学习率的周期,学习率衰减)
输出为训练结果的图和损失和准确率的值。

def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period,lr_decay):
    trainer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9,weight_decay=wd)
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)
    num_batches, timer = len(train_iter), d2l.Timer()
    legend = ['train loss', 'train acc']
    if valid_iter is not None:
        legend.append('valid acc')
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=legend)
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    # 设定一个参数 统计最优的训练准确率
    best_train_acc = 0
    for epoch in range(num_epochs):
        net.train()
        metric = d2l.Accumulator(3)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = d2l.train_batch_ch13(net, features, labels,loss, trainer, devices)
            metric.add(l, acc, labels.shape[0])
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,(metric[0] / metric[2], metric[1] / metric[2],None))
        if valid_iter is not None:
            valid_acc = d2l.evaluate_accuracy_gpu(net, valid_iter)
            animator.add(epoch + 1, (None, None, valid_acc))
        scheduler.step()
    train_acc = metric[1] / metric[2]
    measures = (f'train loss {metric[0] / metric[2]:.3f}, 'f'train acc {train_acc:.3f}')
    if train_acc>best_train_acc:
        best_train_acc = train_acc
        # 更新模型最优准确率的参数
        torch.save(net.state_dict(),'weights/best_train_acc_params')
    if valid_iter is not None:
        measures += f', valid acc {valid_acc:.3f}'
    print(measures + f'\n{metric[2] * num_epochs / timer.sum():.1f}'
                     f' examples/sec on {str(devices)}')
    print(f'best_train_acc={best_train_acc}')

图像分类模型训练代码

训练并统计训练的时间

'''先训练20轮看看有没有问题'''
devices, num_epochs, lr, wd = d2l.try_all_gpus(), 20, 2e-4, 5e-4
# lr_period表示学习率的更新周期,设置为4表示,每四个epoch后更新学习率, lr_decay表示权重衰减,每次更新学习率,学习率都会乘0.9
lr_period, lr_decay, net = 4, 0.9, get_net()

# 开始计时
start_time = time.time()
# 开始训练
train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period,lr_decay)
# 结束计时
end_time = time.time()
run_time = end_time - start_time
# 将输出的秒数保留两位小数
if int(run_time)<60:
    print(f'{round(run_time,2)}s')
else:
    print(f'{round(run_time/60,2)}minutes')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1134659.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第11期 | GPTSecurity周报

GPTSecurity是一个涵盖了前沿学术研究和实践经验分享的社区&#xff0c;集成了生成预训练 Transformer&#xff08;GPT&#xff09;、人工智能生成内容&#xff08;AIGC&#xff09;以及大型语言模型&#xff08;LLM&#xff09;等安全领域应用的知识。在这里&#xff0c;您可以…

知识图谱+推荐系统 文献阅读

文献阅读及整理 知识图谱推荐系统 知识图谱 1 基于知识图谱的电商领域智能问答系统研究与实现 [1]蒲海坤. 基于知识图谱的电商领域智能问答系统研究与实现[D].西京学院,2022.DOI:10.27831/d.cnki.gxjxy.2021.000079. 知识点 BIO标记策略进行人工标记,构建了电商领域商品…

【LeetCode 算法专题突破】链表(⭐)

文章目录 前言1. 移除链表元素题目描述代码 2. 设计链表题目描述代码 3. 反转链表题目描述代码 4. 两两交换链表中的节点题目描述代码 5. 删除链表的倒数第 N 个结点题目描述代码 6. 链表相交题目描述代码 7. 环形链表 II题目描述代码 总结 前言 链表题目一向是面试算法考察的…

【2024秋招】万得后端开发java 2023-7-13 2.30pm 一二面面经(附答案)

一面&#xff1a;20min 1 自我介绍 2 微服务架构 1 nacos作为配置中心&#xff0c;如果nacos服务失效了&#xff0c;各个服务之间的调用如何保持高可用呢&#xff1f; 答&#xff1a;nacos注册中心本地有缓存&#xff0c;所以请求来了还是能够正常提供一段时间的服务&#xff…

c语言进制的转换2进制转换8进制

c语言进制的转换之2进制转换8进制 c语言的进制的转换 c语言进制的转换之2进制转换8进制一、八四二一法则二、二进制转换八进制方法 一、八四二一法则 二、二进制转换八进制方法 如&#xff1a;111000110101001转换成八进制 按照八四二一法则 将二进制3个一等分变成&#xff1a…

LVS集群-DR模式【部署高可用LVS-DR集群】

文章目录 2.2 实战&#xff1a;配置LVS-DR集群2.2.1 配置IP2.2.2 生成ens33:1配置文件2.2.3 配置LVS-DR规则2.2.4 两台RealServer的IP配置Alastor62&#xff08;配置IP&#xff1a;192.168.1.62&#xff09;Alastor64&#xff08;配置IP&#xff1a;192.168.1.64&#xff09;客…

Macos视频增强修复工具:Topaz Video AI for mac

Topaz Video AI是一款使用人工智能技术对视频进行增强和修复的软件。它可以自动降噪、去除锐化、减少压缩失真、提高清晰度等等。Topaz Video AI可以处理各种类型的视频&#xff0c;包括低分辨率视频、老旧影片、手机录制的视频等等。 使用Topaz Video AI非常简单&#xff0c;…

C++设计模式_13_Flyweight享元模式

Flyweight享元模式仍然属于“对象性能”模式。 文章目录 1. 动机(Motivation)2. 模式定义3. 结构( Structure)4. 代码演示5. 要点总结6. 其他参考 1. 动机(Motivation) 在软件系统采用纯粹对象方案的问题在于大量细粒度的对象会很快充斥在系统中&#xff0c;从而带来很高的运行…

Web攻防05_MySQL_二次注入堆叠注入带外注入

文章目录 MYSQL-二次注入-74CMS思路描述&#xff1a;注入条件&#xff1a;案例&#xff1a;74CMS个人中心简历功能 MYSQL-堆叠注入-CTF强网思路描述注入条件案例&#xff1a;2019强网杯-随便注&#xff08;CTF题型&#xff09; MYSQL-带外注入-DNSLOG注入条件使用平台带外应用场…

代碼隨想錄算法訓練營|第四十九天|139.单词拆分、关于多重背包、背包问题总结。刷题心得(c++)

目录 讀題 139.单词拆分 自己看到题目的第一想法 看完代码随想录之后的想法 139.单词拆分 - 實作 思路 Code 關於多重背包 與01背包與完全背包的差別 轉化成01背包問題 背包问题总结 背包問題分類 背包問題 - 遞推公式 最多裝多少/能否裝滿 最大價值 裝滿背包有…

OpenFeign实现分析、源码解析

什么是openfeign? 是springcloud全家桶的组件之一&#xff0c;其核心作用是为Rest API提供高效简洁的rpc调用方式。 为什么只定义接口而没有实现类&#xff1f; 源码解读&#xff08;省略&#xff09; 总结&#xff1a; 源码分析&#xff1a;如何发送http请求&#xff1f; …

基于单片机设计的智能窗帘控制系统

一、前言 智能家居技术在近年来取得了巨大的发展&#xff0c;并逐渐成为人们日常生活中的一部分。智能家居系统带来了便利、舒适和高效的生活体验&#xff0c;拥有广泛的应用领域&#xff0c;其中之一就是智能窗帘控制系统。 传统窗帘需要手动操作&#xff0c;打开或关闭窗帘…

微服务-Ribbon负载均衡

文章目录 负载均衡原理流程原理源码分析负载均衡流程 负载均衡策略饥饿加载总结 负载均衡原理 流程 原理 LoadBalanced 标记RestTemplate发起的http请求要被Ribbon进行拦截和处理 源码分析 ctrlshiftN搜索LoadBalancerInterceptor&#xff0c;进入。发现实现了ClientHttpRequ…

Snipaste--强大的截图贴图软件--非常实用

一.软件介绍&#xff1a; Snipaste 是一个简单但强大的截图工具&#xff0c;也可以让你将截图贴回到屏幕上&#xff01;下载并打开Snipaste&#xff0c;按下 F1 来开始截图&#xff0c;再按 F3&#xff0c;截图就在桌面置顶显示了。就这么简单&#xff01;你还可以将剪贴板里的…

SK-Net eca注意力机制应用于ResNet (附代码)

resnet发展历程 论文地址&#xff1a;https://arxiv.org/pdf/1903.06586.pdf 代码地址&#xff1a;https://github.com/pppLang/SKNet 1.是什么&#xff1f; SK-net网络是一种增加模块嵌入到一些网络中的注意力机制&#xff0c;它可以嵌入和Resnet中进行补强&#xff0c;嵌入…

composer安装thinkphp6报错

composer安装thinkphp6报错&#xff0c; 查看是否安装了对应的PHP扩展&#xff0c;我这边使用的是宝塔的环境&#xff0c;全程可以可视化操作 这样就可以安装完成了

linux工具篇

文章目录 linux工具篇1. linux 软件包管理器-yum1.1 什么是软件包1.2 yum的使用1.3 yum源 2. linux编辑器-vim2.1 vim概念2.2 vim各个模式切换2.3 vim正常模式命令汇总2.4 vim底行模式各命令汇总2.5 vim的简单配置 3. Linux编译器-gcc/g使用3.1 复习程序编译过程(1) 预处理(2) …

【Oracle】Navicat Premium 连接 Oracle的两种方式

Navicat Premium 使用版本说明 Navicat Premium 版本 11.2.16 (64-bit) 一、配置OCI 1.1 配置OCI环境变量 1.1.2 设置\高级系统设置 1.1.2 系统属性\高级\环境变量(N) 1.1.3 修改/添加系统变量 ORACLE_HOME ORACLE_HOME D:\app\root\product\12.1.0\dbhome_11.1.4 添加系…

Redis快速上手篇(一)(安装与配置)

NoSQL NoSQL 是 Not Only SQL 的缩写&#xff0c;意即"不仅仅是 SQL"的意思&#xff0c;泛指非关系型的数据库。强调 Key-Value Stores 和文档数据库的优点。 NoSQL 产品是传统关系型数据库的功能阉割版本&#xff0c;通过减少用不到或很少用的功能&#xff0c;来大…

vue3中使用svg并封装成组件

打包svg地图 安装插件 yarn add vite-plugin-svg-icons -D # or npm i vite-plugin-svg-icons -D # or pnpm install vite-plugin-svg-icons -D使用插件 vite.config.ts import { VantResolver } from unplugin-vue-components/resolvers import { createSvgIconsPlugin } from…