8、深入剖析PyTorch的state_dict、parameters、modules源码

news2024/11/25 20:09:51

文章目录

  • 1. 重要类
  • 2. 保存模型
  • 3. 代码测试

1. 重要类

  • container.py
  • nn.sequential
  • nn.modulelist
  • save_state_dict

2. 保存模型

pytorch官网教程

3. 代码测试

比较急,后续完善

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName  :ToTest01.py
# @Time      :2024/11/24 10:37
# @Author    :Jason Zhang
import torch
from torch import nn
from torch.nn import Module


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear1 = nn.Linear(2, 3)
        self.linear2 = nn.Linear(3, 4)
        self.batch_norm4 = nn.BatchNorm2d(4)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


if __name__ == "__main__":
    run_code = 0
    input_x = torch.randn((1, 2))
    test_model = MyModel()
    y = test_model(input_x)
    model_modules = test_model._modules
    print(f"*"*50)
    print(f"model_modules=\n{model_modules}")
    print(f"*"*50)
    linear1 = model_modules['linear1']
    print(f"*"*50)
    print(f"linear1={linear1}")
    print(f"*"*50)
    print(f"linear1.weight=\n{linear1.weight}")
    print(f"*"*50)
    print(f"linear1.weight.dtype={linear1.weight.dtype}")
    print(f"*"*50)
    test_model.to(torch.double)
    print(f"linear1.weight.dtype={linear1.weight.dtype}")
    print(f"*"*50)
    test_model.to(torch.float32)
    print(f"linear1.weight.dtype={linear1.weight.dtype}")
    print(f"*"*50)
    model_parameters = test_model._parameters
    print(f"model_parameters={model_parameters}")
    print(f"*"*50)
    model_buffers = test_model._buffers
    print(f"model_buffer={model_buffers}")
    print(f"*"*50)
    model_state_dict = test_model.state_dict()
    print(f"model_state_dict=\n{model_state_dict}")
    print(f"*"*50)
    model_state_dict_linear2 = test_model.state_dict()['linear2.weight']
    print(f"model_state_dict_linear2=\n{model_state_dict_linear2}")
    print(f"*"*50)
    model_named_para =list(test_model.named_parameters())
    print(f"model_named_para=\n{model_named_para}")
    print(f"*"*50)
    model_named_modules =list(test_model.named_modules())
    print(f"model_named_modules=\n{model_named_modules}")
    print(f"*"*50)
    model_named_buffers =list(test_model.named_buffers())
    print(f"model_named_buffers=\n{model_named_buffers}")
    print(f"*"*50)
    model_named_children =list(test_model.named_children())
    print(f"model_named_children=\n{model_named_children}")


  • 结果:
**************************************************
model_modules=
OrderedDict([('linear1', Linear(in_features=2, out_features=3, bias=True)), ('linear2', Linear(in_features=3, out_features=4, bias=True)), ('batch_norm4', BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))])
**************************************************
**************************************************
linear1=Linear(in_features=2, out_features=3, bias=True)
**************************************************
linear1.weight=
Parameter containing:
tensor([[-0.5518,  0.0687],
        [-0.7013,  0.4869],
        [-0.1157, -0.1287]], requires_grad=True)
**************************************************
linear1.weight.dtype=torch.float32
**************************************************
linear1.weight.dtype=torch.float64
**************************************************
linear1.weight.dtype=torch.float32
**************************************************
model_parameters=OrderedDict()
**************************************************
model_buffer=OrderedDict()
**************************************************
model_state_dict=
OrderedDict([('linear1.weight', tensor([[-0.5518,  0.0687],
        [-0.7013,  0.4869],
        [-0.1157, -0.1287]])), ('linear1.bias', tensor([-0.2915, -0.4807,  0.0071])), ('linear2.weight', tensor([[ 0.4185,  0.1556,  0.1371],
        [ 0.4751,  0.2029, -0.0679],
        [ 0.1264, -0.0288, -0.3661],
        [ 0.4423, -0.5370,  0.3930]])), ('linear2.bias', tensor([ 0.2746, -0.1798,  0.0218,  0.5465])), ('batch_norm4.weight', tensor([1., 1., 1., 1.])), ('batch_norm4.bias', tensor([0., 0., 0., 0.])), ('batch_norm4.running_mean', tensor([0., 0., 0., 0.])), ('batch_norm4.running_var', tensor([1., 1., 1., 1.])), ('batch_norm4.num_batches_tracked', tensor(0))])
**************************************************
model_state_dict_linear2=
tensor([[ 0.4185,  0.1556,  0.1371],
        [ 0.4751,  0.2029, -0.0679],
        [ 0.1264, -0.0288, -0.3661],
        [ 0.4423, -0.5370,  0.3930]])
**************************************************
model_named_para=
[('linear1.weight', Parameter containing:
tensor([[-0.5518,  0.0687],
        [-0.7013,  0.4869],
        [-0.1157, -0.1287]], requires_grad=True)), ('linear1.bias', Parameter containing:
tensor([-0.2915, -0.4807,  0.0071], requires_grad=True)), ('linear2.weight', Parameter containing:
tensor([[ 0.4185,  0.1556,  0.1371],
        [ 0.4751,  0.2029, -0.0679],
        [ 0.1264, -0.0288, -0.3661],
        [ 0.4423, -0.5370,  0.3930]], requires_grad=True)), ('linear2.bias', Parameter containing:
tensor([ 0.2746, -0.1798,  0.0218,  0.5465], requires_grad=True)), ('batch_norm4.weight', Parameter containing:
tensor([1., 1., 1., 1.], requires_grad=True)), ('batch_norm4.bias', Parameter containing:
tensor([0., 0., 0., 0.], requires_grad=True))]
**************************************************
model_named_modules=
[('', MyModel(
  (linear1): Linear(in_features=2, out_features=3, bias=True)
  (linear2): Linear(in_features=3, out_features=4, bias=True)
  (batch_norm4): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)), ('linear1', Linear(in_features=2, out_features=3, bias=True)), ('linear2', Linear(in_features=3, out_features=4, bias=True)), ('batch_norm4', BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))]
**************************************************
model_named_buffers=
[('batch_norm4.running_mean', tensor([0., 0., 0., 0.])), ('batch_norm4.running_var', tensor([1., 1., 1., 1.])), ('batch_norm4.num_batches_tracked', tensor(0))]
**************************************************
model_named_children=
[('linear1', Linear(in_features=2, out_features=3, bias=True)), ('linear2', Linear(in_features=3, out_features=4, bias=True)), ('batch_norm4', BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2247475.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

URL在线编码解码- 加菲工具

URL在线编码解码 打开网站 加菲工具 选择“URL编码解码” 输入需要编码/解码的内容,点击“编码”/“解码”按钮 编码: 解码: 复制已经编码/解码后的内容。

魔众题库系统 v10.0.0 客服条、题目导入、考试导航、日志一大批更新

魔众题库系统基于PHP开发,可以用于题库管理和试卷生成软件,拥有极简界面和强大的功能,用户遍及全国各行各业。 魔众题库系统发布v10.0.0版本,新功能和Bug修复累计30项,客服条、题目导入、考试导航、日志一大批更新。 …

深入解析 EasyExcel 组件原理与应用

✨深入解析 EasyExcel 组件原理与应用✨ 官方:EasyExcel官方文档 - 基于Java的Excel处理工具 | Easy Excel 官网 在日常的 Java 开发工作中,处理 Excel 文件的导入导出是极为常见的需求。 今天,咱们就一起来深入了解一款非常实用的操作 Exce…

本地部署 MaskGCT

本地部署 MaskGCT 0. 更新系统和安装依赖项1. 克隆代码2. 创建虚拟环境3. 安装依赖模块4. 运行 MaskGCT5. 访问 MaskGCT 0. 更新系统和安装依赖项 sudo apt update sudo apt install espeak-ng1. 克隆代码 git clone https://github.com/engchina/learn-maskgct.git; cd lear…

线程控制方法之wait和sleep的区别

线程控制方法之wait和sleep的区别 wait()和sleep()都是Java线程控制方法,但存在明显区别: 所属与调用:wait()属Object类,需synchronized调用;sleep()属Thread类,可随意调用。锁处理:wait()释放…

Fakelocation Server服务器/专业版 Centos7

前言:需要Centos7系统 Fakelocation开源文件系统需求 Centos7 | Fakelocation | 任务一 更新Centos7 (安装下载不再赘述) sudo yum makecache fastsudo yum update -ysudo yum install -y kernelsudo reboot//如果遇到错误提示为 Another app is curre…

探索 RocketMQ:企业级消息中间件的选择与应用

一、关于RocketMQ RocketMQ 是一个高性能、高可靠、可扩展的分布式消息中间件,它是由阿里巴巴开发并贡献给 Apache 软件基金会的一个开源项目。RocketMQ 主要用于处理大规模、高吞吐量、低延迟的消息传递,它是一个轻量级的、功能强大的消息队列系统&…

基于信创环境的信息化系统运行监控及运维需求及策略

随着信息技术的快速发展和国家对信息安全的日益重视,信创环境(信息技术应用创新环境)的建设已成为行业发展的重要趋势。本指南旨在为运维团队在基于信创环境的系统建设及运维过程中提供参考,确保项目顺利实施并满足各项技术指标和…

初学 flutter 问题记录

windows搭建flutter运行环境 一、运行 flutter doctor遇到的问题 Xcmdline-tools component is missingRun path/to/sdkmanager --install "cmdline-tools;latest"See https://developer.android.com/studio/command-line for more details.1)cmdline-to…

【虚拟机】VMWare的CentOS虚拟机断电或强制关机出现问题

VMware 虚拟机因为笔记本突然断电故障了,开机提示“Entering emergency mode. Exit the shell to continue.”,如下图所示: 解决方法:输入命令: xfs_repair -v -L /dev/dm-0 注:报 no such file or direct…

设计模式:6、装饰模式(包装器)

目录 0、定义 1、装饰模式包含的四种角色 2、装饰模式的UML类图 3、示例代码 0、定义 动态地给对象添加一些额外的职责。就功能来说装饰模式相比生成子类更为灵活。 1、装饰模式包含的四种角色 抽象组件(Component):抽象组件是一个抽象…

Java开发经验——Spring Test 常见错误

摘要 本文详细介绍了Java开发中Spring Test的常见错误和解决方案。文章首先概述了Spring中进行单元测试的多种方法,包括使用JUnit和Spring Boot Test进行集成测试,以及Mockito进行单元测试。接着,文章分析了Spring资源文件扫描不到的问题&am…

Java基于Spring Boot框架的房屋租赁系统,附源码

博主介绍:✌Java老徐、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇&…

单片机_简单AI模型训练与部署__从0到0.9

IDE: CLion MCU: STM32F407VET6 一、导向 以求知为导向,从问题到寻求问题解决的方法,以兴趣驱动学习。 虽从0,但不到1,剩下的那一小步将由你迈出。本篇主要目的是体验完整的一次简单AI模型部署流程&#x…

Java-08 深入浅出 MyBatis - 多对多模型 SqlMapConfig 与 Mapper 详细讲解测试

点一下关注吧!!!非常感谢!!持续更新!!! 大数据篇正在更新!https://blog.csdn.net/w776341482/category_12713819.html 目前已经更新到了: MyBatis&#xff…

IDEA使用tips(LTS✍)

一、查找项目中某个外部库依赖类的pom来源 1、显示图 2、导出Maven 项目依赖的可视化输出文件 3、点击要查找的目标类,项目中定位后复制依赖名称 4、在导出的依赖的可视化文件中搜索查找 5、综上得到,Around类来自于pom中的spring-boot-starter-aop:jar…

【LLM训练系列02】如何找到一个大模型Lora的target_modules

方法1:观察attention中的线性层 import numpy as np import pandas as pd from peft import PeftModel import torch import torch.nn.functional as F from torch import Tensor from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig from typ…

如何选择服务器

如何选择服务器 选择服务器时应考虑以下几个关键因素: 性能需求。根据网站的预期流量和负载情况,选择合适的处理器、内存和存储容量。考虑网站是否需要处理大量动态内容或高分辨率媒体文件。 可扩展性。选择一个可以轻松扩展的服务器架构,以便…

C++共享智能指针

C中没有垃圾回收机制,必须自己释放分配的内存,否则就会造成内存泄漏。解决这个问题最有效的方式是使用智能指针。 智能指针是存储指向动态分配(堆)对象指针的类,用于生存期的控制,能够确保在离开指针所在作用域时,自动…

python Flask指定IP和端口

from flask import Flask, request import uuidimport json import osapp Flask(__name__)app.route(/) def hello_world():return Hello, World!if __name__ __main__:app.run(host0.0.0.0, port5000)