YOLOv8由pt文件中读取模型信息

news2024/9/25 17:18:36

Pytorch的pt模型文件中保存了许多模型信息,如模型结构、模型参数、任务类型、批次、数据集等
在先前的YOLOv8实验中,博主发现YOLOv8在预测时并不需要指定任务类型,因为这些信息便保存在pt模型中,那么,今天我们便来看看,其到底是如何加载这些参数的。

我们首先对pt文件进行一个简单介绍:

pt文格式

pt格式文件是PyTorch中用于保存张量数据的文件格式。与pth文件类似,pt文件也常用于模型的保存和加载,但更侧重于保存单个张量或一组张量数据。通过pt文件,我们可以方便地将张量数据持久化,并在需要时重新加载使用。

张量(Tensor)是PyTorch中的核心数据结构,用于表示多维数组。在深度学习中,张量常用于存储模型的参数、输入数据、中间结果等。因此,掌握pt文件的保存和加载方法对于PyTorch的使用者来说至关重要。

pt文件与pth的区别

pt.pth都是PyTorch模型文件的扩展名,但是它们的区别在于.pt文件是保存整个PyTorch模型的,而.pth文件只保存模型的参数。(其实现在似乎并没有区别了)
因此,如果要加载一个,pth文件,需要先定义模型的结构,然后再加载参数;而如果要加载一个,pt文件,则可以直接加载整个模型。

如何保存pt格式文件

PyTorch中,我们可以使用torch.save()函数将张量数据保存到pt文件中。

下面是一个简单的示例:

import torch
# 创建一个张量
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 将张量保存到pt文件中
torch.save(tensor, 'tensor.pt')

在上面的代码中,我们首先创建了一个二维张量tensor,然后使用torch.save()函数将其保存到名为tensor.pt的文件中。保存的文件将包含张量的数据和元数据,以便在加载时能够准确地恢复张量的结构和内容。

除了保存单个张量外,我们还可以保存多个张量到一个pt文件中。这可以通过将多个张量放入一个字典或列表中,然后将整个字典或列表保存到文件中实现。

例如:

# 创建多个张量
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([[4, 5], [6, 7]])

# 将张量放入字典中
tensors_dict = {'tensor1': tensor1, 'tensor2': tensor2}

# 将字典保存到pt文件中
torch.save(tensors_dict, 'tensors_dict.pt')

如何加载pt格式文件

加载pt文件同样使用torch.load()函数。

下面是一个加载pt文件的示例:

# 加载单个张量的pt文件
loaded_tensor = torch.load('tensor.pt')
print(loaded_tensor)

# 加载包含多个张量的字典的pt文件
loaded_dict = torch.load('tensors_dict.pt')
print(loaded_dict['tensor1'])
print(loaded_dict['tensor2'])

在加载单个张量的pt文件时,我们直接调用torch.load()函数并传入文件名即可。加载得到的loaded_tensor将是一个与原始张量结构和内容相同的张量对象。
当加载包含多个张量的字典的pt文件时,我们同样使用torch.load()函数。加载得到的loaded_dict将是一个字典对象,其中包含了我们在保存时放入的所有张量。我们可以通过字典的键来访问这些张量。

强烈建议只保存模型参数,而非保存整个网络。PyTorch 官方也是这么建议的。

torch.save(net.state_dict(),path2)#只保留模型参数

(只保存模型参数)是官方推荐的方法,运行速度快,且占空间较小。需要注意的是 net.state_dict() 是将网络参数保存为字典形式(OrderedDict)load_state_dict() 加载的并不是网络参数的pth文件,而是字典。

pt文件保存神经网络

在评估时,记住一定要使用model.eval()来固定dropout和归一化层,否则每次推理会生成不同的结果。

import torch, glob, cv2
from torchvision import transforms
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):  # 神经网络部分用你自己的
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 2, 1)  # nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.conv2 = nn.Conv2d(32, 64, 3, 2, 1)
        self.conv3 = nn.Conv2d(64, 128, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(6272, 128)  # 6272=128*7*7
        self.fc2 = nn.Linear(128, 8)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        self.output = F.log_softmax(x, dim=1)
        out1 = x
        return self.output,out1

def predict_mine():
    model=Net()
    model.load_state_dict(torch.load("model.pt"))
    print(model)
    images=torch.rand((1,1,64,64))
    x=model(images)
    print(x)
def torch_script_save():
    model=Net()
if __name__ == '__main__':
    save_model()
    predict_mine()

在这里插入图片描述

可以看到,我们可以通过pt文件读取出来下面的信息:

在这里插入图片描述

同时,我们也看到,我们虽然可以使用pt文件保存模型结构,但我们在推理时,依旧需要我们能够生成Net对象才能加载其数据,这其实很不方便,那么,有什么办法可以真正的将模型结构保存进去,让我们在推理过程中不需要再定义相关的类与对象呢,先前博主所使用的ONNX便是其中的一种,但它其实是另一种文件结构了,pt文件真的就不能摆脱环境吗,答案是否定的,TorchScript模型便解决了这个问题。

TorchScript模型

事实上,PyTorch提供了两种主要的模型保存和加载机制,一种是基于Python的序列化,另一种是TorchScript

普通的PyTorch模型(基于Python的序列化):

  • 保存: 使用torch.save(model.state_dict(), 'model_path.pth'),它保存了模型的权重和参数,但不保存模型的结构。(当然也是可以保存的,但我们需要处理一下才能用,比如定义好Net类)
  • 加载: 首先,您需要有模型的类定义。 创建该类的一个实例。 使用model.load_state_dict(torch.load('model_path.pth'))来加载权重。
    特点:
  • 需要Python环境和模型的原始代码来加载和运行模型。 保存的文件是Python特定的,并且依赖于特定的类结构。 主要用于继续训练或在Python环境中进行推断。

TorchScript模型:

TorchScript是PyTorch的一个子集,它创建了一个可以独立于Python运行的序列化模型。 生成方法:

  • Tracing: 使用torch.jit.trace方法。这涉及到通过模型运行一个输入示例,从而跟踪模型的执行路径。
  • Scripting: 使用torch.jit.script方法。这转化Python代码到TorchScript,允许更复杂的模型和控制流。
  • 保存: 使用torch.jit.save(traced_model, 'model_path.pt')
  • 加载: 使用torch.jit.load(‘model_path.pt’)。注意,加载不需要原始的模型类定义。

特点:

  • 可以在没有Python运行时的环境中运行,如C++
  • 提供了一种方法,将模型从Python转移到其他平台或部署环境。
  • 包含模型的完整定义,包括结构、权重和参数。

Tracing方法:

example_input = torch.randn(1, 10)
traced_model = torch.jit.trace(model, example_input)
torch.jit.save(traced_model, 'traced_model.pt')

Scripting方法:

scripted_model = torch.jit.script(model)
torch.jit.save(scripted_model, 'scripted_model.pt')

加载模型:

loaded_model = torch.jit.load('model_path.pt')

例程:

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 10)

    def forward(self, x):
        return self.fc(x)

model = SimpleModel()

# Tracing
example_input = torch.randn(1, 10)
traced_model = torch.jit.trace(model, example_input)
torch.jit.save(traced_model, 'traced_simple_model.pt')

# Scripting
scripted_model = torch.jit.script(model)
torch.jit.save(scripted_model, 'scripted_simple_model.pt')

# 加载模型
loaded_model = torch.jit.load('traced_simple_model.pt')

我们采用TorchScript结构去执行先前的Net

def torch_script_save():
    model=Net()
    example_input =torch.rand((1,1,64,64))
    traced_model = torch.jit.trace(model, example_input)
    torch.jit.save(traced_model, 'traced_simple_model.pt')

    # Scripting
    scripted_model = torch.jit.script(model)
    torch.jit.save(scripted_model, 'scripted_simple_model.pt')
def predict_script():
    model1=torch.jit.load("traced_simple_model.pt")
    image =torch.rand((1,1,64,64))
    print(model1)
    model1.eval()
    x=model1(image)
    print(x)
    model2=torch.jit.load("scripted_simple_model.pt")
    image =torch.rand((1,1,64,64))
    print(model2)
    model2.eval()
    x=model2(image)
    print(x)

在这里插入图片描述

yaml文件内容如下:

{'nc': 1000, 
'scales': {'n': [0.33, 0.25, 1024], 's': [0.33, 0.5, 1024], 'm': [0.67, 0.75, 1024], 'l': [1.0, 1.0, 1024], 'x': [1.0, 1.25, 1024]}, 
'backbone': [[-1, 1, 'Conv', [64, 3, 2]], [-1, 1, 'Conv', [128, 3, 2]], [-1, 3, 'C2f', [128, True]], [-1, 1, 'Conv', [256, 3, 2]], [-1, 6, 'C2f', [256, True]], [-1, 1, 'Conv', [512, 3, 2]], [-1, 6, 'C2f', [512, True]], [-1, 1, 'Conv', [1024, 3, 2]], [-1, 3, 'C2f', [1024, True]]],
 'head': [[-1, 1, 'Classify', ['nc']]], 
 'scale': 'n',
 'yaml_file': 'yolov8n-cls.yaml', 
 'ch': 3}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1989080.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SpringBoot外部配置文件来修改jar包属性

在jar包所在的文件夹内创建application.yml配置文件: 在yml文件内部添加想要修改的属性值就可以了。 随后输入下面命令来运行jar包: java -jar Big-Deal-Boot-0.0.1-SNAPSHOT.jar 下图是优先级顺序,从上往下依次变高:

Linux Shell编程--变量

前言:本博客仅作记录学习使用,部分图片出自网络,如有侵犯您的权益,请联系删除 变量: bash作为程序设计语言和其它高级语言一样也提供使用和定义变量的功能 预定义变量、环境变量、自定义变量、位置变量 一、自定义变…

【Java 第十二篇章】SpringMVC 呜呜,为啥现在面试会问呢

一、简介 Spring MVC 是 Spring 框架的一个模块,用于构建 Web 应用程序,它遵循模型 - 视图 - 控制器(MVC)设计模式。 二、Spring MVC 的核心组件 1、DispatcherServlet 这是 Spring MVC 的前端控制器,它是整个框架…

Spring Boot获取Bean的三种方式

​ 博客主页: 南来_北往 系列专栏:Spring Boot实战 引言 在Spring Boot中,Bean是一个由Spring IoC容器管理的对象。 Spring Bean是在Spring IoC容器中被实例化、组装和管理的对象,可以视为Spring应用的构建块。它通过提供一套丰富的注…

Centos7安装Zabbix5.0的yum安装失败的解决方案

目前由于Centos7停服以及Zabbix官方限制了其5.0版本在Centos7上安装服务版本,因此可能会导致安装Zabbix5.0的一些组件无法正常安装。 zabbix5.0安装参考:一、zabbix 5.0 部署_zabbix5.0部署-CSDN博客 问题现象 当安装到zabbix的GUI包时报如下错误&…

护眼灯真的可以护眼吗?五款专业护眼灯品牌在线分析

很多新手小白在选购护眼台灯前,都会思考哪个护眼台灯的效果比较好这个问题,因为有的无良商家因为想要降低成本,使用一些廉价低劣的处理器,台灯的电压和功率都难以保证,有的甚至会产生有害的辐射,对人体的健…

Unity入门3——脚本入门

本文使用的代码编辑器为VSCode 安装接口有: 通过将变量设置为public,可以直接在unity的Inspector面板中看到相关变量。此时可直接将需要的素材拖拽到变量处。 Awake()方法 只要物体被加到场景就会执行一次

【vue3】【elementPlus】【国际化】

1.如需从0-1开始,请参考 https://blog.csdn.net/Timeguys/article/details/140995569 2.使用 vue-i18n 模块: npm i vue-i18n3.在 src 目录下创建 locales 目录,里面创建文件:en.js、zh-cn.js、index.js 语言js文件:…

ICC2:检查漏tree的脚本

我正在「拾陆楼」和朋友们讨论有趣的话题,你⼀起来吧? 拾陆楼知识星球入口 前面写了innovus检查clock 漏tree的脚本,ICC2的脚本也相差不多,只需要替换少部分命令就行。原理就是检查clock pin有没有clock 定义。 foreach pin [ge…

JavaSE之常用API大全

API大全 一、Object toString 返回这个对象的字符串表示形式 当输入一个引用类型的时候,会自动调用该对象的toString方法 默认的toString方法是: 包名.类名十六进制值 Equals 用于比较两个对象是否相同,默认比较内存地址 “”:比较基本类型的时候,比较的是值的大小,而比较引用…

光伏气象站会对环境产生影响吗?

在探讨光伏气象站对环境的影响时,我们首先要明确其核心功能和运作原理。光伏气象站,作为集光伏发电与气象监测于一体的设备,其主要作用在于为光伏电站提供精准的气象数据支持,并辅助电站优化运行,提高发电效率。 从环境…

互联网之光与人工智能之光交相辉映,如何抓住5G人工智能红利

一、互联网之光闭幕 第六届世界互联网大会“互联网之光”虽然已经闭幕!“科学与技术”“产业与经济”“人文与社会”“合作与治理”等4大板块20个分论坛,为5G人工智能时代提出了一个新的问题:5GAI 交相辉映,抓住5G人工智能红利&am…

George Danezis谈Mysticeti的吞吐量和低延迟

Sui的新共识引擎Mysticeti已经在主网上开始分阶段推出。Mysten Labs联合创始人兼首席科学家George Danezis在采访中,解释了吞吐量和延迟的区别,以及Sui上的Mysticeti如何结合这两者。 采访视频:https://youtu.be/A4vtyE8obXQ 中文译文&…

【前缀异或和】力扣2588. 统计美丽子数组数目

给你一个下标从 0 开始的整数数组nums 。每次操作中&#xff0c;你可以&#xff1a; 选择两个满足 0 < i, j < nums.length 的不同下标 i 和 j 。 选择一个非负整数 k &#xff0c;满足 nums[i] 和 nums[j] 在二进制下的第 k 位&#xff08;下标编号从 0 开始&#xff0…

1.数据加载时 暂无数据会晃一下再显示数据 2.判断图片加载失败后渲染占位图

工作中问题小记 这种问题正常来说都没有记录的意义 但是我是强迫症 hhh 1.在正常数据渲染时 如果为空我们会渲染(暂无数据占位图)来提示用户 通常是用数据长度来判断 但是他在刷新的时候会先弹出 <暂无数据> 的提示再显示那个数据 解决方法: 搞个标识符 必须等他请求完接…

【电控笔记z14z16】增加霍尔元件分辨率

霍尔传感器用的不多?实际增量编码器更好 z14 假设60度内速度不变 z16(更简单的方法)BLDC

vue前端项目--路由vue-router

1. 路由介绍 我们可以总结一下从早期网站开发到现代单页应用(SPA)的发展过程及其关键概念&#xff1a; 早期的服务器端渲染 (SSR): 早期的网站开发中&#xff0c;服务器负责生成完整的 HTML 页面&#xff0c;并将其发送给客户端展示。 每个 URL 对应一个特定的控制器(Control…

学习笔记 韩顺平 零基础30天学会Java(2024.8.7)

P481 Math方法 利用random返回一个[2,7]之间的随机数&#xff1a; 因为random只能返回[0,1)之间的随机数&#xff0c;因此做一下处理&#xff1a;[(int)(a), (int) (aMath.random()*(b-a1))]&#xff0c;对于Math.random()*(b-a1)&#xff0c;其中b-a1&#xff0c;它乘上[0,1)相…

PFC+DAB原理介绍

三、PFCDAB原理介绍 1、PFC工作原理 三相交流电网的一个公认拓扑是三相全桥 PFC。此拓扑也称为 B6 或“三段桥”。如图显示此拓扑仅使用三相交流输入运行。如果需要单相工作模式。可以轻松地通过增加中性线实现 三相电源包含三个交流相位&#xff0c;通常用 L1、L2 和 L3 表…

linux使用ssh连接一直弹出密码框问题

1.查看ssh服务的状态 输入以下命令&#xff1a; sudo service sshd status 小编已经安装了。 如果出现 Loaded: error (Reason: No such file or directory) 提示的话&#xff0c;说名没有安装ssh服务&#xff0c;按照第二步&#xff1a;安装ssh服务。 如果出现 Active: in…