pytorch训练后pt模型中保存内容详解(yolov8n.pt为例)

news2024/9/28 7:28:02

在 PyTorch 中,.pt 模型文件通常包含以下几类数据:

        模型参数:

                存储模型的权重和偏置参数。

        优化器状态:

                包含优化器的状态信息,以便在恢复训练时能够从中断的地方继续。

        训练状态:

                一些训练过程中的信息,例如当前的 epoch 数和训练进度。

        其他元数据:

                包括模型的配置、训练时使用的超参数等。

        在讲解pytorch pt(pth)文件中保存了什么内容之前,需要先了解pt在保存时保存了那些参数。

以YOLO系列pt保存代码来介绍说明:

1. 模型保存代码:

 def save_model(self):
        ckpt = {
            'epoch': self.epoch, #
            'best_fitness': self.best_fitness,
            'model': deepcopy(de_parallel(self.model)).half(),
            'ema': deepcopy(self.ema.ema).half(),
            'updates': self.ema.updates,
            'optimizer': self.optimizer.state_dict(),
            'train_args': vars(self.args),  # save as dict
            'date': datetime.now().isoformat(),
            'version': __version__}
        # Use dill (if exists) to serialize the lambda functions where pickle does not do this
        try:
            import dill as pickle
        except ImportError:
            import pickle
        # Save last, best and delete
        torch.save(ckpt, self.last, pickle_module=pickle)
        if self.best_fitness == self.fitness:
            torch.save(ckpt, self.best, pickle_module=pickle)
        if (self.epoch > 0) and (self.save_period > 0) and (self.epoch % self.save_period == 0):
            torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt', pickle_module=pickle)
        del ckpt

参数说明:

        'epoch': 当前的训练轮次数。

        'best_fitness': 最佳性能指标的数值。

        'model': 深拷贝(deepcopy)并将模型参数进行半精度(half)转换后的模型。

        'ema': 深拷贝并将指数移动平均模型参数进行半精度转换后的指数移动平均模型。

        'updates': 指数移动平均模型的更新次数。

        'optimizer': 优化器的状态字典(state_dict)。

        'train_args': 训练参数的字典表示,使用vars(self.args)将self.args对象转换为字典。

        'date': 当前的日期和时间,使用datetime.now().isoformat()获取。

        'version': 代码的版本号,通过__version__获取。

        其中:model中保存的模型的结构,train_args中保存训练时的一些参数(超参数)。

通过上述功能函数可以看到pytorch保存的pt文件中的内容。

补充说明:

        torch.save()函数用于将PyTorch模型保存到磁盘上的文件中,以便以后可以重新加载和使用。它的基本语法如下:

        torch.save(obj, f, pickle_module=<module 'pickle' from '...'>, pickle_protocol=2)

                obj是要保存的对象,通常是一个模型的状态字典(state_dict())。

                f是文件的路径或文件对象,用于存储模型。

                pickle_module是用于序列化的Python模块,默认为pickle。

                pickle_protocol是序列化时使用的协议版本,默认为2。

2. 模型加载介绍

下面通过Debug来详解pt中的具体内容:

首先加载模型,代码如下:

import sys
import argparse
import os
import struct
import torch
pt_file = "./yolov8n.pt"
wts_file = "./yolov8n.wts"
# Initialize
device = 'cpu'
# Load model
modelAll = torch.load(pt_file, map_location=device)
model = modelAll['model'].float()  # load to FP32
#model = torch.load(pt_file, map_location=device)['model'].float()  # load to FP32

anchor_grid = model.model[-1].anchors * model.model[-1].stride[..., None, None]
delattr(model.model[-1], 'anchors')
model.to(device).eval()
with open(wts_file, 'w') as f:
    f.write('{}\n'.format(len(model.state_dict().keys())))
    for k, v in model.state_dict().items():
        print("key={0}, v={1}".format(k,v))
        vr = v.reshape(-1).cpu().numpy()
        f.write('{} {} '.format(k, len(vr)))
        for vv in vr:
            f.write(' ')
            f.write(struct.pack('>f', float(vv)).hex())
        f.write('\n')

 Debug结果如下所示,分别对应save_model()中保存的内容

其中model(model = modelAll['model'].float())中内容如下:

       model的类型为DetectionModel,里面包含了模型结构(model.model)以及参数信息(model.args)及构造网络时的配置参数信息(model.yaml)以及目标类别及个数、stride等信息。 

3. 模型权重解析保存

        model.state_dict()是一个字典,键是参数的名称,值是对应的 tensor。

        其中保存着模型的权重(Weights)和偏置值(Biases)以及运行均值和方差(例如,Batch Normalization 层的 running_mean 和 running_var,用于推理时)等信息。

        权重解析保存代码如下:

with open(wts_file, 'w') as f:
    f.write('{}\n'.format(len(model.state_dict().keys())))
    for k, v in model.state_dict().items():
        print("key={0}, v={1}".format(k,v))
        vr = v.reshape(-1).cpu().numpy()
        f.write('{} {} '.format(k, len(vr)))
        for vv in vr:
            f.write(' ')
            f.write(struct.pack('>f', float(vv)).hex())
        f.write('\n')

代码功能介绍:

  1. 使用写模式打开一个文件 wts_file,以便保存模型的参数。
  2. 将模型参数的数量写入文件。
  3. 循环遍历每个参数的键名 k 和对应的值 v。
  4. 将参数 v 重塑为一维数组,并将其从 GPU 移动到 CPU(如果适用),然后转换为 NumPy 数组。
  5. 写入参数的名称和长度。
    for vv in vr:
        f.write(' ')
        f.write(struct.pack('>f', float(vv)).hex())

        遍历每个参数值,使用大端格式(‘>’)将其转换为浮点数并写入文件.

pt解包后保存后的文件内容如下:

上述代码可以将pt格式模型,转化为Nvidia TensorRT部署需要的文件。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2066978.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SpringBoot的自动配置原理探究

目录 什么是SpringBoot的自动配置&#xff08;Auto-Configuration&#xff09; 举例&#xff1a;SpringBoot自动配置&#xff08;Redis的自动配置&#xff09;的实例&#xff1a; 步骤1.&#xff1a;引入Redis启动器pom依赖 步骤2.在application.yml或者&#xff08;proper…

火狐浏览器应用商店不支持下载

前言 之前手机一直用的火狐浏览器&#xff0c;现在换了新的手机&#xff0c;又想下载使用&#xff0c;从官网直接下载现在直接跳载到Google Play才能下载&#xff0c;但是国内又用不了的&#xff0c;这里就记录一下怎么在手机应用商店不支持情况下载。 从FTP服务器下载Beta版…

C++学习笔记----4、用C++进行程序设计(四)---- 复合关系与继承关系之间的细线

在现实世界只是很容易区分对象之间是复合关系还是继承关系。没有人会说桔子有一个水果--而只能是桔子是一种水果。但是&#xff0c;在代码中&#xff0c;有时候就不是那么清晰了。 设想有一个代表关联数组的假想类&#xff0c;将一个键影射到一个值的数据结构。例如&#xff0c…

python画图高斯平滑均值曲线

注:细线是具体值,粗线是高斯平滑处理后的均值曲线 #codinggbk import matplotlib.pyplot as plt import numpy as np from scipy.ndimage import gaussian_filter1d# 生成一些示例数据 np.random.seed(0) timesteps np.linspace(1000, 0, 1000) data 0.4 0.2 * np.random.r…

并查集(路径压缩、按秩合并、按大小合并)

文章目录 并查集简单介绍&#xff1a;初始化&#xff1a;如何查找&#xff1f;如何合并&#xff1f;优化如下&#xff1a;路径压缩&#xff1a;代码&#xff1a; 按秩合并&#xff1a;**代码&#xff1a;** 启发式合并&#xff08;按大小合并&#xff09;&#xff1a;代码: 例题…

E5063A-011 时域分析/测试向导程序

矢量网络分析 E5063A 选件 011 E5063A-011 时域分析/测试向导程序 不容错过&#xff01; 概述 Keysight E5063A ENA 系列 PCB 分析仪是较佳的 PCB 生产测试解决方案&#xff0c;可提供阻抗&#xff08;TDR&#xff09;和回波损耗&#xff08;S 参数&#xff09;测量能力。…

11091 最优自然数分解问题(优先做)

### 简短思路 #### 问题&#xff08;1&#xff09;&#xff1a;将n分解为若干个互不相同的自然数之和&#xff0c;且使这些自然数的乘积最大 1. 对于n < 4的情况&#xff0c;直接返回特定值。 2. 对于n > 4的情况&#xff0c;使用贪心策略&#xff0c;将n分解为从2开始的…

证书学习(一)keytool 工具使用介绍

目录 一、keytool 简介1.1 什么是 keytool&#xff1f;1.2 主要功能&#xff1a;1.3 使用场景1.4 常用命令1.5 默认参数 二、keytool 用法说明2.1 基本使用2.2 创建密钥库和密钥条目2.3 查看密钥库信息2.4 导出密钥库条目证书2.5 导入信任证书到密钥库2.6 打印证书内容2.7 删除…

零工市场小程序应该有什么功能?

数字经济现如今正飞速发展&#xff0c;零工市场小程序在连接雇主与自由职业者方面发挥着越来越重要的作用。一个高效的零工市场小程序不仅需要具备基础的信息发布与匹配功能&#xff0c;还应该涵盖交易管理、安全保障以及个性化服务等多个方面。 那么&#xff0c;零工市场小程…

为什么企业跨国组网建议用SD-WAN?

SD-WAN成为企业跨国组网的首选方案&#xff0c;主要因为它在灵活性、智能化管理以及数据安全等方面具备显著优势。在企业进行跨国组网时&#xff0c;往往会面临网络连接复杂、流量管理难度大以及数据安全等诸多挑战&#xff0c;而SD-WAN能够有效应对这些难题。 首先&#xff0c…

Docker续1:

一、打包传输 1.打包 [rootlocalhost ~]# systemctl start docker [rootlocalhost ~]# docker save -o centos.tar centos:latest [rootlocalhost ~]# ls anaconda-ks.cfg centos.tar 2.传输 [rootlocalhost ~]# scp centos.tar root192.168.1.100:/root 3.删除镜像 [r…

场外个股期权杠杆率是多少如何计算倍数?

今天带你了解场外个股期权杠杆率是多少如何计算倍数&#xff1f;场外个股期权的杠杆大小不是固定的&#xff0c;而是取决于期权合约的价值和标的资产的价值之间的比例&#xff0c;一般来说场外个股期权的杠杆率大概在5-30倍甚至更高左右。 场外个股期权杠杆率是多少&#xff1…

罗德与施瓦茨RS SMW200A 最实用的一款矢量信号发生器

Rohde & Schwarz SMW200A 是一款适用于最苛刻应用的矢量信号发生器。由于其灵活性、性能和直观的操作&#xff0c;它是生成复杂、高质量数字调制信号的完美工具。 罗德与施瓦茨 SMW200A 是开发新型宽带通信系统、验证 3G 和 4G 基站或航空航天和国防领域所需的数字调制信号…

【软考】cpu的组成

目录 1. 说明2. cpu结构图3. 运算器3.1 说明3.2 主要功能3.3 算术逻辑单元3.4 累加寄存器3.5 数据缓冲寄存器DR3.6 状态条件寄存器PSW 4. 控制器4.1 说明4.2 指令寄存器(IR)4.3 程序计数器(PC)4.4 地址寄存器(AR)4.5 指令译码器(DD) 5. 寄存器组6. 例题6.1 例题1 1. 说明 1.cp…

Lighthouse ApexZ 尘埃粒子计数器审计追踪 数据完整性

在大型制药企业中&#xff0c;高效、准确且安全的样本处理与数据管理至关重要。这些企业不仅需要确保产品质量符合严格的监管要求&#xff0c;还需要优化流程以提高生产效率和降低成本。结合您提到的LIMS&#xff08;实验室信息管理系统&#xff09;和Lighthouse ApexZ便携式空…

行星搅拌炒锅的优点有哪些?

1、容积大&#xff0c;产量高。 2、火力大&#xff0c;独特的燃烧装置&#xff0c;升温快&#xff0c;温度高&#xff0c;炒出的物料色泽鲜艳&#xff0c;口味纯正。 3、不糊锅&#xff0c;独特的搅拌装置&#xff0c;可以覆盖锅体的每一个角落&#xff0c;使物料不糊锅&…

《黑神话 悟空》大火,通关后部分景区可免门票,72处《黑神话 悟空》取景地汇总!

重要提醒&#xff01;打通关的天命人们 免门票了&#xff01;72处《黑神话 悟空》取景地汇总。 8月20日&#xff0c;首个国产3A大作《黑神话:悟空》上线&#xff0c;这几天&#xff0c;大家基本很难不刷到这个热点。在这个游戏中&#xff0c;去了全国多个景区取景&#xff0c;…

城乡燃气安全监管平台 打造城市安全防护网

随着城市化进程的不断加快&#xff0c;燃气已成为现代生活中不可或缺的重要能源。然而&#xff0c;传统燃气管理方式的局限性逐渐显现&#xff0c;难以应对日益增长的安全监管需求。为此&#xff0c;旭华智能基于其在智慧城市领域的深厚积累&#xff0c;推出了燃气安全监管物联…

Spring Cloud + Easy Excel导出表格

在现代应用开发中&#xff0c;数据的导出和处理是一个非常常见的需求。Spring Cloud 和 Easy Excel 是两个强大的工具&#xff0c;可以帮助我们高效地完成这个任务。本文将介绍如何将这两个工具结合起来&#xff0c;实现表格数据的导出功能。 1.环境准备 在开始之前&#xff0…