解决方案:炼丹师养成计划 Pytorch如何进行断点续训——DFGAN断点续训实操

news2024/12/24 9:20:12

我们在训练模型的时候经常会出现各种问题导致训练中断,比方说断电、系统中断、内存溢出、断连、硬件故障、地震火灾等之类的导致电脑系统关闭,从而将模型训练中断。

所以在实际运行当中,我们经常需要每100轮epoch或者每50轮epoch要保存训练好的参数,以防不测,这样下次可以直接加载该轮epoch的参数接着训练,就不用重头开始。下面我们来介绍Pytorch断点续训原理及DFGAN20版本和22版本断点续训实操

文末评论【人生苦短,我用Pytorch!】抽一位小伙伴送出《PyTorch教程:21个项目玩转PyTorch实战》书籍一本,包邮到家。

一、Pytorch断点续训

1.1、保存模型

pytorch保存模型等相关参数,需要利用torch.save(),torch.save()是PyTorch框架中用于保存Python对象到磁盘上的函数,一般为

torch.save(checkpoint, checkpoint_path)

其中checkpoint为保存模型的所有参数和缓存的键值对,checkpoint_path表示最终保存的模型,通常以.pth格式保存。

torch.save()函数会将obj序列化为字节流,并将字节流写入f指定的文件中。在读取数据时,可以使用torch.load()函数来将文件中的字节流反序列化成Python对象。使用这两个函数可以轻松地将PyTorch模型保存到磁盘上,并在需要的时候重新加载使用。

一般在实际操作中,我们写为:

torch.save(netG.state_dict(),'%s/netG_epoch_%d.pth' % (self.model_dir, epoch))

它接受两个参数:要保存的对象(即状态字典)和文件路径。在这里,状态字典是通过调用netG.state_dict()方法获得的,而文件路径是使用字符串格式化操作构建的。字符串'%s/netG_epoch_%d.pth' % (self.model_dir, epoch) 中,%s表示第一个字符串占位符将被替换为self.model_dir(即保存.pth文件的目录路径),%d表示第二个字符串占位符将被替换为epoch(即当前训练的轮数)。这样就可以在每一轮训练结束后将当前的网络模型参数保存到一个新的.pth文件中,文件名中包含轮数以便于后续的查看和比较。

1.2、读取模型

对应的,torch.load()函数是PyTorch框架中用于从磁盘上加载Python对象的函数。一般为:

 checkpoint = torch.load(log_dir)
 model.load_state_dict(checkpoint['model'])

torch.load()函数会从文件中读取字节流,并将其反序列化成Python对象。对于PyTorch模型,可以直接将其反序列化成模型对象。

一般实际操作中,我们常常写为:

model.load_state_dict(torch.load(path))

首先使用torch.load()函数从指定的路径中加载模型参数,得到一个字典对象,即state_dict。其中,字典的键是各个层次结构的名称,而键所对应的值则是该层次结构中各个参数的值。

然后,使用model.load_state_dict()函数将state_dict中的参数加载到已经定义好的模型中。这个函数的作用是将state_dict中每个键所对应的参数加载到模型中对应的键所指定的层次结构上。

需要注意的是,由于模型的结构和保存的参数的结构必须匹配,因此在加载参数之前,需要先定义好模型的结构,使其与保存的参数的结构相同。如果结构不匹配,会导致加载参数失败,甚至会引发错误。

二、DFGAN20版本

在DFGAN20版本当中,模型保存在DFGAN/code/models当中,其中netG_300.pth就是代表生成器第300轮的模型netD_300.pth也就是代表鉴别器第300轮的模型。
在这里插入图片描述

我们可以将需要的模型的路径记下来,然后打开main.py文件,其中在270行左右的# # validation data #下面
在这里插入图片描述
可以在下面这段代码的后面

netG = NetG(cfg.TRAIN.NF, 100, sentencelstm, wordlstm).to(device)
netD = NetD(cfg.TRAIN.NF).to(device)

增加两句:

netG.load_state_dict(torch.load('models/%s/netG_300.pth' % (cfg.CONFIG_NAME)))
netD.load_state_dict(torch.load('models/%s/netD_300.pth' % (cfg.CONFIG_NAME)))

这样,就成功读取了所选文件夹目录下的netG_300.pthnetD_300.pth,如果要在这个epoch下进行采样,只需要把code/cfg/bird.ymlB_VALIDATION改为 True,如果需要在这个epoch下进行断点续训则B_VALIDATION改为False就可以了。

三、DFGAN22版本

DFGAN22版本与DFGAN20版本代码结构有所不同,但是在断点续训的原理上是一样的。

DFGAN22版本在保存模型时并没有单独保存netG, netD, netC, optG, optD等模型,而且将他们的模型都保存为一个.pth文件,如名为state_epoch_940.pth代表的就是第940轮的所有断点文件。这些断点文件保存在code/saved_models/bird或cooc下,如:
在这里插入图片描述
如果要进行断点续训,我们可以把这个文件路径记下来或者将文件挪到需要的位置,我一般将需要断点续训或者采样的模型放在pretrained文件夹下。

然后下一步,打开code/cfg/bird.yml文件,如果是coco数据集则打开coco.yml
在这里插入图片描述
修改state_epoch为自己选定的第几轮模型(想读取state_epoch_940.pth,则state_epoch改为940,这样后面打印结果、保存模型就是从941开始了),然后修改checkpoint为相应模型的路径如:./saved_models/bird/pretrained/state_epoch_940.pth,最终如下所示:

state_epoch: 940
checkpoint: ./saved_models/bird/pretrained/state_epoch_940.pth

如果你想更深层次了解其原理,即DFGAN22 版是如何保存模型和读取模型的,可以打开code/lib/utils.py文件,在第140行附近写了保存模型的函数,与我们之前讲的原理是一样的,只不过他将netG, netD, netC, optG, optD等又做了一层,然后将其统一保存到state_epoch_中:

def save_models(netG, netD, netC, optG, optD, epoch, multi_gpus, save_path):
    if (multi_gpus==True) and (get_rank() != 0):
        None
    else:
        state = {'model': {'netG': netG.state_dict(), 'netD': netD.state_dict(), 'netC': netC.state_dict()}, \
                'optimizers': {'optimizer_G': optG.state_dict(), 'optimizer_D': optD.state_dict()},\
                'epoch': epoch}
        torch.save(state, '%s/state_epoch_%03d.pth' % (save_path, epoch))

在第90行到140行附近,也写了读取模型的方法,也就是读相应checkpoint的checkpoint['model']['netG'],看完你会发觉,原理很简单,代码也不算很难,遇到问题建议大家多多阅读源码。

def load_opt_weights(optimizer, weights):
    optimizer.load_state_dict(weights)
    return optimizer


def load_model_opt(netG, netD, netC, optim_G, optim_D, path, multi_gpus=False):
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    netG = load_model_weights(netG, checkpoint['model']['netG'], multi_gpus)
    netD = load_model_weights(netD, checkpoint['model']['netD'], multi_gpus)
    netC = load_model_weights(netC, checkpoint['model']['netC'], multi_gpus)
    optim_G = load_opt_weights(optim_G, checkpoint['optimizers']['optimizer_G'])
    optim_D = load_opt_weights(optim_D, checkpoint['optimizers']['optimizer_D'])
    return netG, netD, netC, optim_G, optim_D


def load_models(netG, netD, netC, path):
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    netG = load_model_weights(netG, checkpoint['model']['netG'])
    netD = load_model_weights(netD, checkpoint['model']['netD'])
    netC = load_model_weights(netC, checkpoint['model']['netC'])
    return netG, netD, netC


def load_netG(netG, path, multi_gpus, train):
    checkpoint = torch.load(path, map_location="cpu")
    netG = load_model_weights(netG, checkpoint['model']['netG'], multi_gpus, train)
    return netG


def load_model_weights(model, weights, multi_gpus=False, train=True):
    if list(weights.keys())[0].find('module')==-1:
        pretrained_with_multi_gpu = False
    else:
        pretrained_with_multi_gpu = True
    if (multi_gpus==False) or (train==False):
        if pretrained_with_multi_gpu:
            state_dict = {
                key[7:]: value
                for key, value in weights.items()
            }
        else:
            state_dict = weights
    else:
        state_dict = weights
    model.load_state_dict(state_dict)
    return model

三、可能遇见的问题

问题1:模型中断后继续训练出错

在有些时候我们需要保存训练好的参数为path文件,以防不测,下次可以直接加载该轮epoch的参数接着训练,但是在重新加载时发现类似报错:

size mismatch for block0.affine0.linear1.linear2.weight: copying a param with shape torch.Size([512, 256]) from checkpoint, the shape in current model is torch.Size([256, 256]).
size mismatch for block0.affine0.linear1.linear2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([256]).

问题原因:这是说明某个超参数出现了问题,可能你之前训练时候用的是64,现在准备在另外的机器上面续训的时候某个超参数设置的是32,导致了size mismatch,也有可能是你动过了模型的代码,导致现在代码和训练的模型匹配不上了。

解决方案:查看size mismatch的模型部分,将超参数改回来,并将代码和原本训练的代码保持一致。

问题2:模型中断后继续训练 效果直降

加载该轮epoch的参数接着训练,继续训练的过程是能够运行的,但是发现继续训练时效果大打折扣,完全没有中断前的最后几轮好。
问题原因:暂时未知,推测是续训时模型加载的问题,也有可能是保存和加载的方式问题
解决方案:统一保存和加载的方式,当我采用以下方式时,貌似避免了这个问题:
模型的保存:

torch.save(netG.state_dict(), 'models/%s/netG_%03d.pth' % (cfg.CONFIG_NAME, epoch))

模型的重新加载:

netD.load_state_dict(torch.load('models/%s/netD_300.pth' % (cfg.CONFIG_NAME), map_location='cuda:0'))

四、好书推荐(评论送书)

4.1、好书推荐

《PyTorch教程:21个项目玩转PyTorch实战》
在这里插入图片描述
阅读这本书,可以通过经典项目入门 PyTorch,通过前沿项目提升 PyTorch,基于PyTorch玩转深度学习,本书适合人工智能、机器学习、深度学习方面的人员阅读,也适合其他 IT 方面从业者,另外,还可以作为相关专业的教材。

京东自营购买链接:https://item.jd.com/13522327.html

评论区评论【人生苦短,我用Pytorch!】抽一位小伙伴送出《PyTorch教程:21个项目玩转PyTorch实战》书籍一本,包邮到家。

目前买还有优惠:北京大学出版社4月“423世界读书日”促销活动
当当活动日期:4.6-4.11,4.18-4.23
京东活动日期: 4.6 一天, 4.17-4.23
活动期间满100减50或者半价5折销售
希望大家关注参与423读书日北大社促销活动
京东自营购买链接:https://item.jd.com/13522327.html

4.2、内容简介

PyTorch 是基于 Torch 库的开源机器学习库,它主要由 Meta(原 Facebook)的人工智能研究实验室开发,在自然语言处理和计算机视觉领域都具有广泛的应用。本书介绍了简单且经典的入门项目,方便快速上手,如 MNIST数字识别,读者在完成项目的过程中可以了解数据集、模型和训练等基础概念。本书还介绍了一些实用且经典的模型,如 R-CNN 模型,通过这个模型的学习,读者可以对目标检测任务有一个基本的认识,对于基本的网络结构原理有一定的了解。另外,本书对于当前比较热门的生成对抗网络和强化学习也有一定的介绍,方便读者拓宽视野,掌握前沿方向。

4.3、作者简介

王飞,2019年翻译了PyTorch官方文档,读研期间研究方向为自然语言处理,主要是中文分词、文本分类和数据挖掘。目前在教育行业工作,探索人工智能技术在教育中的应用。
何健伟,曾任香港大学助理研究员,研究方向为自然语言处理,目前从事大规模推荐算法架构研究工作。
林宏彬,硕士期间研究方向为自然语言处理,现任阿里巴巴算法工程师,目前从事广告推荐领域的算法研究工作。
史周安,软件工程硕士,人工智能技术爱好者、实践者与探索者。目前从事弱监督学习、迁移学习与医学图像相关工作。

💡 最后

我们已经建立了🏤T2I研学社群,如果你还有其他疑问或者对🎓文本生成图像很感兴趣,可以私信我加入社群

📝 加入社群 抱团学习:中杯可乐多加冰-深度学习T2I研习群

🔥 限时免费订阅:文本生成图像T2I专栏

🎉 支持我:点赞👍+收藏⭐️+留言📝

评论区评论【人生苦短,我用Pytorch!】抽一位小伙伴送出《PyTorch教程:21个项目玩转PyTorch实战》书籍一本,包邮到家。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/430099.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python实现哈里斯鹰优化算法(HHO)优化卷积神经网络分类模型(CNN分类算法)项目实战

说明:这是一个机器学习实战项目(附带数据代码文档视频讲解),如需数据代码文档视频讲解可以直接到文章最后获取。 1.项目背景 2019年Heidari等人提出哈里斯鹰优化算法(Harris Hawk Optimization, HHO),该算法有较强的全…

DAY 43 Apache的配置与应用

虚拟Web主机 概述 虚拟web主机指的是在同一台服务器中运行多个web站点,其中每一个站点实际上并不独立占用整个服务器,因此被称为"虚拟"web主机。通过虚拟web主机服务可以充分利用服务器的硬件资源,从而大大降低网站构建及运行成本…

TensorFlow 决策森林详细介绍和使用说明

使用TensorFlow训练、调优、评估、解释和部署基于树的模型的完整教程 两年前TensorFlow (TF)团队开源了一个库来训练基于树的模型,称为TensorFlow决策森林(TFDF)。经过了2年的测试,他们在上个月终于宣布这个包已经准备好发布了,也就是说我们…

在android项目上集成libyuv库以及使用linyuv库完成camera的缩放,旋转,翻转,裁剪操作

目录 一、下拉google官方的libyuv库代码 二、在android项目中集成libyuv库 1.环境配置 2.拷贝libyuv源码文件 ​编辑3.配置cmake libyuv相关的链接编译等 三、使用libyuv库 1.libyuv库完成camera的旋转 2.libyuv库实现翻转 3.libyuv库实现缩放 4.libyuv库实现裁剪 一…

为什么重视安全的公司都在用SSL安全证书?

我们今天来讲一讲为什么重视安全的公司都在用SSL证书 SSL证书是什么? SSL安全证书是由权威认证机构颁发的,是CA机构将公钥和相关信息写入一个文件,CA机构用他们的私钥对我们的公钥和相关信息进行签名后,将签名信息也写入这个文件…

对于数据库而言,其锁范围可以分为全局锁 、表级锁、 行级锁

一、全局锁 全局锁就是对整个数据库实例加锁。 MySQL 提供了一个加全局读锁的方法,命令是 Flush tables with read lock (FTWRL)。当你需要让整个库处于只读状态的时候,可以使用这个命令,之后其他线程的以下语句会被阻塞:数据更新…

DOM(1)

DOM(文档对象模型):处理可扩展标记语言(HTML或XML)的标准编程接口,可以改变网页的内容、结构和样式。DOM树: …

ubuntu18 网络问题

在/etc/netplan/*.yaml配置文件中: renderer的值可以是networkd,或者是NetworkManager 它俩的其中一个区别为: networkd在图像界面,networking setting中不显示网卡配置。 版权简介: 从Ubuntu 18.04.2版本开始&…

腾讯云4核8G轻量服务器12M支持多少访客同时在线?并发数怎么算?

腾讯云轻量4核8G12M轻量应用服务器支持多少人同时在线?通用型-4核8G-180G-2000G,2000GB月流量,系统盘为180GB SSD盘,12M公网带宽,下载速度峰值为1536KB/s,即1.5M/秒,假设网站内页平均大小为60KB…

网络安全:网络攻击原理与方法.

网络安全:网络攻击原理与方法. 网络攻击:是损害网络系统安全属性的危害行为。危害行为导致网络系统的机密性、完整性、可控性、真实性、抗抵赖性等受到不同程度的破坏。 目录: 常见的危害行为有四个基本类型: 网络攻击模型&…

项目实践 | 行人跟踪与摔倒检测报警

项目实践 | 行人跟踪与摔倒检测报警 小白学视觉 7月7日 原文地址:项目实践 | 行人跟踪与摔倒检测报警 1.简介 本项目的目的是为了给大家提供跟多的实战思路,抛砖引玉为大家提供一个案例,也希望读者可以根据该方法实现更多的思想与想法&…

为什么Uber从PostgreSQL换成了MySQL

说明:本文翻译自Why Uber Engineering Switched from Postgres to MySQL 引言 Uber的早期架构包括一个用Python编写的单一后端应用程序,它使用Postgres进行数据持久化。从那时起,Uber的架构发生了重大变化,转向了微服务和新数据…

比例放大器设置接线US-DAS1/US-DAS2

US-DAS1、US-DAS2比例放大器接线定义 1 CMD 指令 2 CMD- 指令- 3/4/5 N.C. 不接 6 ENA 使能 7 VREF_5V 参考电压5V 8 VREF_0V 参考电压0V 9 SOL_A 电磁铁A 10 SOL_A- 电磁铁A- 11 PWR 电源 12 PWR- 电源- 13 SOL_B- 电磁铁B- 15 RS485_A - 16 RS485_B -

LeetCode-盛最多水的容器-11题

LeetCode-盛最多水的容器-11题 题目中要求计算最大面积,即需要选择对应的长和宽。 最终解决方法:使用对撞指针 对撞指针的概念:是指在数组的两个端引入两个指针,左指针不断向右移动,右指针不断向左移动。最终到达两个…

Spring AOP切入点表达式

先来认识两个概念吧(其实Spring AOP实现功能增强的方式就是代理模式) 目标对象(Target):原始功能去掉共性功能对应的类产生的对象,这种对象是无法直接完成最终工作的代理(Proxy):目标对象无法直接完成工作,…

【学术搬砖】第一期

“一期一会” —— 珍惜我们遇见的论文,把和每个论文的相遇,当做一种缘分。我们会定期推荐若干优质学术论文,并分享一段总结,非常欢迎提出任何建议和想法。 【NeurIPS2022】ShufflfleMixer: An Effificient ConvNet for Image Su…

R -- 时序分析

brief 横截面数据对应着某个时间点的数据。 纵向的数据对应着一系列时间点的数据&#xff0c;某个变量随着时间的变动被反复测量。 研究纵向数据&#xff0c;也许会得到“时间”的答案。 描述时间序列 生成时序对象 x <- runif(20)ts(x) ts(x,frequency 12) ts(x,frequen…

python整合Django框架初试

1.安装 以下是安装Django的步骤&#xff1a; 确认Python已经安装&#xff1a;在终端&#xff08;Mac/Linux&#xff09;或命令提示符&#xff08;Windows&#xff09;中输入python -V&#xff0c;如果出现Python版本号&#xff0c;则已经安装Python&#xff1b;如果未安装&…

Nginx配置与应用

Nginx 是开源、高性能、高可靠的 Web 和反向代理服务器&#xff0c;而且支持热部署&#xff0c;几乎可以做到 7 * 24 小时不间断运行&#xff0c;即使运行几个月也不需要重新启动&#xff0c;还能在不间断服务的情况下对软件版本进行热更新。性能是 Nginx 最重要的考量&#xf…

PCL源码剖析 -- 欧式聚类

PCL源码剖析 – 欧式聚类 参考&#xff1a; 1. pcl Euclidean Cluster Extraction教程 2. 欧式聚类分析 3. pcl-api源码 4. 点云欧式聚类 5. 本文完整工程地址 可视化结果 一. 理论 聚类方法需要将无组织的点云模型P划分为更小的部分&#xff0c;以便显著减少P的总体处理时间…