7、深入剖析PyTorch nn.Module源码

news2024/11/25 19:08:52

文章目录

  • 1. 重要类
  • 2. add_modules
  • 3. Apply(fn)
  • 4. register_buffer
  • 5. nn.Parameters®ister_parameters
  • 6. 后续测试

1. 重要类

  • nn.module --> 所有神经网络的父类,自定义神经网络需要继承此类,并且自定义__init__,forward函数即可:
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName  :MyModelNet.py
# @Time      :2024/11/20 13:38
# @Author    :Jason Zhang
import torch
from torch import nn


class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork,self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


if __name__ == "__main__":
    run_code = 0
    x_row = 28
    x_column = 28
    x_total = x_row * x_column
    x = torch.arange(x_total, dtype=torch.float).reshape((1, x_row, x_column))
    my_net = NeuralNetwork()
    y = my_net(x)
    print(f"y.shape={y.shape}")
    print(my_net)
  • 结果:
y.shape=torch.Size([1, 10])
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)

2. add_modules

通过add_modules在旧的网络里面添加新的网络

  • 重点: 用nn.ModuleList自带的insert,新的网络继承自老网络中,直接用按位置插入
  • python
import torch
from torch import nn
from pytorch_model_summary import summary

torch.manual_seed(2323)


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.flatten = nn.Flatten()
        self.block = nn.ModuleList([
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        ])

    def forward(self, x):
        x = self.flatten(x)
        for layer in self.block:
            x = layer(x)
        return x


class MyNewNet(MyModel):
    def __init__(self):
        super(MyNewNet, self).__init__()
        self.block.insert(2, nn.Linear(512, 256))  # 插入新层
        self.block.insert(3, nn.ReLU())  # 插入新的激活函数
        self.block.insert(4, nn.Linear(256, 512))  # 插入另一层
        self.block.insert(5, nn.ReLU())  # 插入激活函数


if __name__ == "__main__":
    # 测试原始模型
    my_model = MyModel()
    print("Original Model:")
    print(summary(my_model, torch.ones((1, 28, 28))))

    # 测试新模型
    my_new_model = MyNewNet()
    print("\nNew Model:")
    print(summary(my_new_model, torch.ones((1, 28, 28))))
  • 结果:
Original Model:
-----------------------------------------------------------------------
      Layer (type)        Output Shape         Param #     Tr. Param #
=======================================================================
         Flatten-1            [1, 784]               0               0
          Linear-2            [1, 512]         401,920         401,920
            ReLU-3            [1, 512]               0               0
          Linear-4             [1, 10]           5,130           5,130
=======================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
-----------------------------------------------------------------------

New Model:
-----------------------------------------------------------------------
      Layer (type)        Output Shape         Param #     Tr. Param #
=======================================================================
         Flatten-1            [1, 784]               0               0
          Linear-2            [1, 512]         401,920         401,920
            ReLU-3            [1, 512]               0               0
          Linear-4            [1, 256]         131,328         131,328
            ReLU-5            [1, 256]               0               0
          Linear-6            [1, 512]         131,584         131,584
            ReLU-7            [1, 512]               0               0
          Linear-8             [1, 10]           5,130           5,130
=======================================================================
Total params: 669,962
Trainable params: 669,962
Non-trainable params: 0
-----------------------------------------------------------------------

3. Apply(fn)

模型权重weight,bias 的初始化

  • python
import torch.nn as nn
import torch


class MyAwesomeModel(nn.Module):
    def __init__(self):
        super(MyAwesomeModel, self).__init__()
        self.fc1 = nn.Linear(3, 4)
        self.fc2 = nn.Linear(4, 5)
        self.fc3 = nn.Linear(5, 6)


# 定义初始化函数
@torch.no_grad()
def init_weights(m):
    print(m)
    if type(m) == nn.Linear:
        m.weight.fill_(1.0)
        print(m.weight)


# 创建神经网络实例
model = MyAwesomeModel()

# 应用初始化权值函数到神经网络上
model.apply(init_weights)
  • 结果:
Linear(in_features=3, out_features=4, bias=True)
Parameter containing:
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], requires_grad=True)
Linear(in_features=4, out_features=5, bias=True)
Parameter containing:
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]], requires_grad=True)
Linear(in_features=5, out_features=6, bias=True)
Parameter containing:
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]], requires_grad=True)
MyAwesomeModel(
  (fc1): Linear(in_features=3, out_features=4, bias=True)
  (fc2): Linear(in_features=4, out_features=5, bias=True)
  (fc3): Linear(in_features=5, out_features=6, bias=True)
)

Process finished with exit code 0

4. register_buffer

将模型中添加常数项。比如加1

  • python:
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName  :RegisterBuffer.py
# @Time      :2024/11/23 19:21
# @Author    :Jason Zhang
import torch
from torch import nn


class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.register_buffer("my_buffer_a", torch.ones(2, 3))

    def forward(self, x):
        x = x + self.my_buffer_a
        return x


if __name__ == "__main__":
    run_code = 0
    my_test = MyNet()
    in_x = torch.arange(6).reshape((2, 3))
    y = my_test(in_x)
    print(f"x=\n{in_x}")
    print(f"y=\n{y}")
  • 结果:
x=
tensor([[0, 1, 2],
        [3, 4, 5]])
y=
tensor([[1., 2., 3.],
        [4., 5., 6.]])

5. nn.Parameters&register_parameters

  • python
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName  :ParameterTest.py
# @Time      :2024/11/23 19:37
# @Author    :Jason Zhang
import torch
from torch import nn


class MyModule(nn.Module):
    def __init__(self, in_size, out_size):
        self.in_size = in_size
        self.out_size = out_size
        super(MyModule, self).__init__()
        self.test = torch.rand(self.in_size, self.out_size)
        self.linear = nn.Linear(self.in_size, self.out_size)

    def forward(self, x):
        x = self.linear(x)
        return x


class MyModuleRegister(nn.Module):
    def __init__(self, in_size, out_size):
        self.in_size = in_size
        self.out_size = out_size
        super(MyModuleRegister, self).__init__()
        self.test = torch.rand(self.in_size, self.out_size)
        self.linear = nn.Linear(self.in_size, self.out_size)

    def forward(self, x):
        x = self.linear(x)
        return x


class MyModulePara(nn.Module):
    def __init__(self, in_size, out_size):
        self.in_size = in_size
        self.out_size = out_size
        super(MyModulePara, self).__init__()
        self.test = nn.Parameter(torch.rand(self.in_size, self.out_size))
        self.linear = nn.Linear(self.in_size, self.out_size)

    def forward(self, x):
        x = self.linear(x)
        return x


if __name__ == "__main__":
    run_code = 0
    test_in = 4
    test_out = 6
    my_test = MyModule(test_in, test_out)
    my_test_para = MyModulePara(test_in, test_out)
    test_list = list(my_test.named_parameters())
    test_list_para = list(my_test_para.named_parameters())
    my_test_register = MyModuleRegister(test_in, test_out)
    para_register = nn.Parameter(torch.rand(test_in, test_out))
    my_test_register.register_parameter('para_add_register', para_register)
    test_list_para_register = list(my_test_register.named_parameters())

    print(f"*" * 50)
    print(f"test_list=\n{test_list}")
    print(f"*" * 50)
    print(f"*" * 50)
    print(f"test_list_para=\n{test_list_para}")
    print(f"*" * 50)
    print(f"*" * 50)
    print(f"test_list_para_register=\n{test_list_para_register}")
    print(f"*" * 50)
  • 结果:
**************************************************
test_list=
[('linear.weight', Parameter containing:
tensor([[ 0.3805, -0.3368,  0.2348,  0.4525],
        [-0.4557, -0.3344,  0.1368, -0.3471],
        [-0.3961,  0.3302,  0.1904, -0.0111],
        [ 0.4542, -0.3325, -0.3782,  0.0376],
        [ 0.2083, -0.3113, -0.3447, -0.1503],
        [ 0.0343,  0.0410, -0.4216, -0.4793]], requires_grad=True)), ('linear.bias', Parameter containing:
tensor([-0.3465, -0.4510,  0.4919,  0.1967, -0.1366, -0.2496],
       requires_grad=True))]
**************************************************
**************************************************
test_list_para=
[('test', Parameter containing:
tensor([[0.1353, 0.9934, 0.0462, 0.2103, 0.3410, 0.0814],
        [0.7509, 0.2573, 0.8030, 0.0952, 0.1381, 0.5360],
        [0.1972, 0.1241, 0.5597, 0.2691, 0.3226, 0.0660],
        [0.3333, 0.8031, 0.9226, 0.4290, 0.3660, 0.6159]], requires_grad=True)), ('linear.weight', Parameter containing:
tensor([[-0.0633, -0.4030, -0.4962,  0.1928],
        [-0.1707,  0.2259,  0.0373, -0.0317],
        [ 0.4523,  0.2439, -0.1376, -0.3323],
        [ 0.3215,  0.1283,  0.0729,  0.3912],
        [ 0.0262, -0.1087,  0.4721, -0.1661],
        [-0.1055, -0.2199, -0.4974, -0.3444]], requires_grad=True)), ('linear.bias', Parameter containing:
tensor([ 0.3702, -0.0142, -0.2098, -0.0910, -0.2323, -0.0546],
       requires_grad=True))]
**************************************************
**************************************************
test_list_para_register=
[('para_add_register', Parameter containing:
tensor([[0.2428, 0.1388, 0.6612, 0.4215, 0.0215, 0.2618],
        [0.4234, 0.0160, 0.8947, 0.4784, 0.4403, 0.4800],
        [0.8845, 0.1469, 0.6894, 0.7050, 0.5911, 0.7702],
        [0.7694, 0.0491, 0.3583, 0.4451, 0.2282, 0.4293]], requires_grad=True)), ('linear.weight', Parameter containing:
tensor([[ 0.1358, -0.4704, -0.4181, -0.4504],
        [ 0.0903,  0.3235, -0.3164, -0.4163],
        [ 0.1342,  0.3108,  0.0612, -0.2910],
        [ 0.3527,  0.3397, -0.0414, -0.0408],
        [-0.4877,  0.1925, -0.2912, -0.2239],
        [-0.0081, -0.1730,  0.0921, -0.4210]], requires_grad=True)), ('linear.bias', Parameter containing:
tensor([-0.2194,  0.2233, -0.4950, -0.3260, -0.0206, -0.0197],
       requires_grad=True))]
**************************************************

6. 后续测试

  • register_module
  • get_submodule
  • get_parameter

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2247452.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何选择服务器

如何选择服务器 选择服务器时应考虑以下几个关键因素: 性能需求。根据网站的预期流量和负载情况,选择合适的处理器、内存和存储容量。考虑网站是否需要处理大量动态内容或高分辨率媒体文件。 可扩展性。选择一个可以轻松扩展的服务器架构,以便…

C++共享智能指针

C中没有垃圾回收机制,必须自己释放分配的内存,否则就会造成内存泄漏。解决这个问题最有效的方式是使用智能指针。 智能指针是存储指向动态分配(堆)对象指针的类,用于生存期的控制,能够确保在离开指针所在作用域时,自动…

python Flask指定IP和端口

from flask import Flask, request import uuidimport json import osapp Flask(__name__)app.route(/) def hello_world():return Hello, World!if __name__ __main__:app.run(host0.0.0.0, port5000)

虚幻引擎---初识篇

一、学习途径 虚幻引擎官方文档:https://dev.epicgames.com/documentation/zh-cn/unreal-engine/unreal-engine-5-5-documentation虚幻引擎在线学习平台:https://dev.epicgames.com/community/unreal-engine/learning哔哩哔哩:https://www.b…

Java开发经验——SpringRestTemplate常见错误

摘要 本文分析了在使用Spring框架的RestTemplate发送表单请求时遇到的常见错误。主要问题在于将表单参数错误地以JSON格式提交,导致服务器无法正确解析参数。文章提供了错误案例的分析,并提出了修正方法。 1. 表单参数类型是MultiValueMap RestControl…

oracle会话追踪

一 跟踪当前会话 1.1 查看当前会话的SID,SERIAL# #在当前会话里执行 示例: SQL> select distinct userenv(sid) from v$mystat; USERENV(SID) -------------- 1945 SQL> select distinct sid,serial# from v$session where sid1945; SID SERIAL# …

数据可视化复习2-绘制折线图+条形图(叠加条形图,并列条形图,水平条形图)+ 饼状图 + 直方图

目录 目录 一、绘制折线图 1.使用pyplot 2.使用numpy ​编辑 3.使用DataFrame ​编辑 二、绘制条形图(柱状图) 1.简单条形图 2.绘制叠加条形图 3.绘制并列条形图 4.水平条形图 ​编辑 三、绘制饼状图 四、绘制散点图和直方图 1.散点图 2…

postgresql按照年月日统计历史数据

1.按照日 SELECT a.time,COALESCE(b.counts,0) as counts from ( SELECT to_char ( b, YYYY-MM-DD ) AS time FROM generate_series ( to_timestamp ( 2024-06-01, YYYY-MM-DD hh24:mi:ss ), to_timestamp ( 2024-06-30, YYYY-MM-DD hh24:mi:ss ), 1 days ) AS b GROUP BY tim…

【JavaEE初阶 — 多线程】定时器的应用及模拟实现

目录 1. 标准库中的定时器 1.1 Timer 的定义 1.2 Timer 的原理 1.3 Timer 的使用 1.4 Timer 的弊端 1.5 ScheduledExecutorService 2. 模拟实现定时器 2.1 实现定时器的步骤 2.1.1 定义类描述任务 定义类描述任务 第一种定义方法 …

一文学会Golang里拼接字符串的6种方式(性能对比)

g o l a n g golang golang的 s t r i n g string string类型是不可修改的,对于拼接字符串来说,本质上还是创建一个新的对象将数据放进去。主要有以下几种拼接方式 拼接方式介绍 1.使用 s t r i n g string string自带的运算符 ans ans s2. 使用…

LeetCode 3244.新增道路查询后的最短距离 II:贪心(跃迁合并)-9行py(O(n))

【LetMeFly】3244.新增道路查询后的最短距离 II:贪心(跃迁合并)-9行py(O(n)) 力扣题目链接:https://leetcode.cn/problems/shortest-distance-after-road-addition-queries-ii/ 给你一个整数 n 和一个二维…

MyBatis中特殊SQL的执行

目录 1.模糊查询 2.批量删除 3.动态设置表名 4.添加功能获取自增的主键 1.模糊查询 List<User> getUserByLike(Param("username") String username); <select id"getUserByLike" resultType"com.atguigu.mybatis.pojo.User">&…

ES 基本使用与二次封装

概述 基本了解 Elasticsearch 是一个开源的分布式搜索和分析引擎&#xff0c;基于 Apache Lucene 构建。它提供了对海量数据的快速全文搜索、结构化搜索和分析功能&#xff0c;是目前流行的大数据处理工具之一。主要特点即高效搜索、分布式存储、拓展性强 核心功能 全文搜索:…

Azkaban部署

首先我们需要现在相关的组件&#xff0c;在这里已经给大家准备好了相关的安装包&#xff0c;有需要的可以自行下载。 只需要启动hadoop集群就可以&#xff0c;如果现在你的hive是打开的&#xff0c;那么请你关闭&#xff01;&#xff01;&#xff01; 如果不关会造成证书冲突…

Jmeter中的定时器

4&#xff09;定时器 1--固定定时器 功能特点 固定延迟&#xff1a;在每个请求之间添加固定的延迟时间。精确控制&#xff1a;可以精确控制请求的发送频率。简单易用&#xff1a;配置简单&#xff0c;易于理解和使用。 配置步骤 添加固定定时器 右键点击需要添加定时器的请求…

JavaEE初学07

JavaEE初学07 MybatisORMMybatis一对一结果映射一对多结果映射 Mybatis动态sqlif标签trim标签where标签set标签foreach标签补充 Mybatis Mybatis是一款优秀的持久层框架&#xff0c;他支持自定义SQL、存储过程以及高级映射。Mybatis几乎免除了所有的JDBC代码以及设置参数和获取…

【layui】table的switch、edit修改

<title>简单表格数据</title><div class"layui-card layadmin-header"><div class"layui-breadcrumb" lay-filter"breadcrumb"><a>系统设置</a><a>简单表格数据</a></div> </div>&…

工具使用_docker容器_crossbuild

1. 工具简介 2. 工具使用 拉取 multiarch/crossbuild 镜像&#xff1a; docker pull multiarch/crossbuild 创建工作目录和示例代码&#xff1a; mkdir -p ~/crossbuild-test cd ~/crossbuild-test 创建 helloworld.c &#xff1a; #include <stdio.h>int main() …

Android 天气APP(三十七)新版AS编译、更新镜像源、仓库源、修复部分BUG

上一篇&#xff1a;Android 天气APP&#xff08;三十六&#xff09;运行到本地AS、更新项目版本依赖、去掉ButterKnife 新版AS编译、更新镜像源、仓库源、修复部分BUG 前言正文一、更新镜像源① 腾讯源③ 阿里源 二、更新仓库源三、修复城市重名BUG四、地图加载问题五、源码 前…

基于Java Springboot海洋馆预约系统

一、作品包含 源码数据库设计文档万字PPT全套环境和工具资源部署教程 二、项目技术 前端技术&#xff1a;Html、Css、Js、Vue、Element-ui 数据库&#xff1a;MySQL 后端技术&#xff1a;Java、Spring Boot、MyBatis 三、运行环境 开发工具&#xff1a;IDEA/eclipse 数据…