[PyTorch][chapter 39][nn.Module]

news2025/1/22 1:39:15

前言:

 

        pytorch.nn是专门为神经网络设计的模块化接口. nn构建于autograd之上,可以用来定义和运行神经网络.是所有类的父类.

       

 目录:

  1.     基本结构
  2.     常用模块
  3.    container(容器)
  4.     CPU,GPU 部署
  5.     train-test 环境切换
  6.     flatten
  7.     MyLinear

    


一 基本结构

   

 

   

     1   继承 nn.Module

     2   super

          super类的作用是继承的时候,调用含super的各个的基类__init__函数,

          如果不使用super,就不会调用这些类的__init__函数,除非显式声明。

         而且使用super可以避免基类被重复调用。

    3  forward 

          前向传播

   nn.Module nested in Module

  可以通过嵌套方式,构建更加复杂的模型

  


二  常用模块

   nn.Module 包含了深度学习里面常用的一些函数

  2.1  torch.nn.Linear(in_features, # 输入的神经元个数
                                 out_features, # 输出神经元个数
                                 bias=True      # 是否包含偏置)

        功能:

         y=XW^T+b

   2.2 torch.nn.BatchNorm2d(num_features,

                                              eps=1e-05, momentum=0.1,

                                              affine=True, 

                                              track_running_stats=True,

                                              device=None, dtype=None)

       功能: 

        对于所有的batch中样本的同一个channel的数据元素进行标准化处理,即如果有C个通道,无论batch中有多少个样本,都会在通道维度上进行标准化处理,一共进行C次。

   

 2.3 nn.Conv2d

           torch.nn.Conv2d(
                   in_channels, 
                  out_channels, 
                  kernel_size, 
                  stride=1, 
                  padding=0, 
                   dilation=1, 
                   groups=1, 
                  bias=True, 
                  padding_mode='zeros', 
                  device=None, 
                  dtype=None)

  功能:

                图像卷积操作


三  container(容器)

    通过容器功能,PyTorch 可以像搭积木一样的方式,组合各种模型.

  • 模型容器

    作用

    nn.Sequential

    顺序性,各网络层之间严格按顺序执行,常用于block构建

     nn.ModuleList

    迭代性,常用于大量重复网构建,通过for循环实现重复构建

    nn.ModuleDict

    索引性,常用于可选择的网络层

       3.1 nn.Sequential

              一个序列容器,用于搭建神经网络的模块被按照被传入构造器的顺序添加到nn.Sequential()容器中。除此之外,一个包含神经网络模块的OrderedDict也可以被传入nn.Sequential()容器中。利用nn.Sequential()搭建好模型架构,模型前向传播时调用forward()方法,模型接收的输入首先被传入nn.Sequential()包含的第一个网络模块中。然后,第一个网络模块的输出传入第二个网络模块作为输入,按照顺序依次计算并传播,直到nn.Sequential()里的最后一个模块输出结果。

       例子

# -*- coding: utf-8 -*-
"""
Created on Fri Jun  9 13:46:07 2023

@author: chengxf2
"""
import torch
from torch import nn
from collections import OrderedDict


# Using Sequential to create a small model. When `model` is run,
# input will first be passed to `Conv2d(1,20,5)`. The output of
# `Conv2d(1,20,5)` will be used as the input to the first
# `ReLU`; the output of the first `ReLU` will become the input
# for `Conv2d(20,64,5)`. Finally, the output of
# `Conv2d(20,64,5)` will be used as input to the second `ReLU`
model = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU(),
          nn.Conv2d(20,64,5),
          nn.ReLU()
        )

# Using Sequential with OrderedDict. This is functionally the
# same as the above code
model = nn.Sequential(OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ]))

常用技巧

list(net.named_parameters())[0]
dict(net.named_parameters()).items()
optimezer = optim.SGD(net.parameters(),lr=1e-3)

  3.2  nn.ModuleList

            nn.ModuleList,它是一个存储不同module,并自动将每个module的parameters添加到网络之中的容器。但nn.ModuleList并没有定义一个网络,它只是将不同的模块储存在一起,这些模块之间并没有什么先后顺序可言

# -*- coding: utf-8 -*-
"""
Created on Fri Jun  9 14:02:48 2023

@author: chengxf2
"""

"""
Created on Fri Jun  9 13:46:07 2023

@author: chengxf2
"""
import torch
from torch import nn

class MyNet(nn.Module):
    
    def __init__(self):
        
        super(MyNet, self).__init__()
        
        self.linears = nn.ModuleList([nn.Linear(3,4)  for i in range(3)]  )
        
    
    def forward(self, x):
        
        for m in self.linears:
            
            x = m(x)
        
        return x
    
net = MyNet()
print("\n net ")
print(net)
print("\n parameters ")
print(list(net.parameters()))

输出:通过ModuleList 构建了一个小模型,该模型由三个 线性层 组成

 如下 相对于Sequential, ModuleList  模块之间并没有什么先后顺序可言

# -*- coding: utf-8 -*-
"""
Created on Fri Jun  9 14:02:48 2023

@author: chengxf2
"""

"""
Created on Fri Jun  9 13:46:07 2023

@author: chengxf2
"""
import torch
from torch import nn

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10,20), nn.Linear(20,30), nn.Linear(5,10)])
    def forward(self, x):
        x = self.linears[2](x)
        x = self.linears[0](x)
        x = self.linears[1](x) 
        return x
 
net = MyNet()
print(MyNet)
# net3(
#   (linears): ModuleList(
#     (0): Linear(in_features=10, out_features=20, bias=True)
#     (1): Linear(in_features=20, out_features=30, bias=True)
#     (2): Linear(in_features=5, out_features=10, bias=True)
#   )
# )
input = torch.randn(32, 5)
print(net(input).shape)
# torch.Size([32, 30])
    


3.3 nn.ModuleDict

        将所有的子模块放到一个字典中。

   ModuleDict 可以像常规 Python 字典一样进行索引,但它包含的模块已正确注册,所有 Module 方法都可以看到。ModuleDict 是一个有序字典。

参数:

modules (iterable, optional) – 一个(string: module)映射(字典)或者可迭代的键值对。

方法:

 clear():清空ModuleDict
• items():返回可迭代的键值对(key-value pairs)
• keys():返回字典的键(key)
• values():返回字典的值(value)
• pop():返回一对键值,并从字典中删除

# -*- coding: utf-8 -*-
"""
Created on Fri Jun  9 14:02:48 2023

@author: chengxf2
"""

"""
Created on Fri Jun  9 13:46:07 2023

@author: chengxf2
"""
import torch
from torch import nn

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.choices = nn.ModuleDict({
                'conv': nn.Conv2d(in_channels=3,out_channels= 10, kernel_size=3),
                'pool': nn.MaxPool2d(3)
        })
        self.activations = nn.ModuleDict([
                ['relu', nn.LeakyReLU()],
                ['prelu', nn.PReLU()]
        ])

    def forward(self, x, choice, act):
        x = self.choices[choice](x)
        x = self.activations[act](x)
        return x

net = MyModule()
img = torch.randn((1, 3, 8, 8))

output = net(img, 'conv', 'relu')

print(output.shape)

一个完整的例子

# -*- coding: utf-8 -*-
"""
Created on Mon Jun 12 14:17:19 2023

@author: chengxf2
"""

import  torch
import  torch.nn as nn

class BasicNet(nn.Module):
    
    def __init__(self):
        super(BasicNet,self).__init__()
        
        self.net = nn.Linear(in_features =4, out_features=3)
        
    
    def forward(self,x):
        
        out = self.net(x)
        return out
    

class Net(nn.Module):
    
    def __init__(self):
        
        super(Net, self).__init__()
        
        self.net = nn.Sequential(BasicNet(),
                                nn.ReLU(),
                                nn.Linear(in_features=3, out_features=2)
                                 )
        
    def forward(self,x):
        
        out = self.net(x)
        
        return out
    

if __name__ == "__main__":
    
    data = torch.rand((2,4))
    
    model = Net()
    
    out = model(data)
    print(out)
    
    parame = list(model.parameters())
    
    print("\n parameters \n",parame)
    
    name_parame = list(model.named_parameters())
    print("\n name_parame \n",name_parame)

四  CPU,GPU 部署


五 save and load

  为了防止意外情况,我们每训训练一些次数后,需要把当前的参数

保存到本地磁盘中,后面再次训练后,可以通过磁盘文件直接加载

#保存已经训练好的参数到 net.mdl
torch.save(net.state_dict(), 'net.mdl')

#模型开始的时候加载,通过net.mdl 里面的值初始化网络参数
net.load_state_dict(torch.load('net.mdl'))

  5.1 通过torch.save保存模型,

     torch.save函数将序列化的对象保存到磁盘。此函数使用Python的pickle进行序列化。通过pickle可以保存各种对象的模型、张量和字典。

 5.2  torch.load加载模型

   torch.save和torch.load函数的实现在torch/serialization.py文件中。

    torch.load函数使用pickle的unpickling将pickle对象文件反序列化到内存中

torch.nn.Module的state_dict函数:

 在PyTorch中,torch.nn.Module模型的可学习参数(即weights和biases)包含在模型的参数中(通过model.parameters函数访问)。

  state_dict只是一个Python字典对象,它将每一层映射到其参数张量(tensor)。

  注意:只有具有可学习参数的层(卷积层,线性层等)和注册缓冲区(batchnorm’s running_mean)在模型的state_dict中有条目( Note that only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers (batchnorm’s running_mean) have entries in the model’s state_dict)

    优化器对象(torch.optim)也有一个state_dict,其中包含有关优化器状态的信息,以及使用的超参数。因为state_dict对象是Python字典,所以它们可以很容易地保存、更新、更改和恢复。


六  train-test 环境切换

       train 结束后,test 之前 必须加上 torch.eval

       model的eval方法主要是针对某些在train和predict两个阶段会有不同参数的层。

      比如Dropout层和BN层

Dropout:

  train阶段: 随机选择神经元, 

  test 阶段 : 使用全部神经元并且要乘一个补偿系数

BN层

     输出Y与输入X之间的关系是:Y = (X - running_mean) / sqrt(running_var + eps) * gamma + beta,其中gamma、beta为可学习参数(在pytorch中分别改叫weight和bias),

     train 阶段 :训练时通过反向传播更新;而running_mean、running_var则是在前向时先由X计算出mean和var,再由mean和var以动量momentum来更新running_mean和running_var。所以在训练阶段,running_mean和running_var在每次前向时更新一次;

    test 阶段,则通过net.eval()固定该BN层的running_mean和running_var,此时这两个值即为训练阶段最后一次前向时确定的值,并在整个测试阶段保持不变。


六  flatten

 

     我们在通过某些网络后得到的张量需要打平,得到一个[1,n] 维的张量, 输入到全连接网络里面去训练.

可以通过flatten 处理

实现原理如下:

     假设类型为 torch.tensor 的张量 a 的形状如下所示:(2,4,3,5,6),则 torch.flatten(a, 1, 3).shape 的结果为 (2, 60, 6)。 将索引为 start_dim 和 end_dim 之间(包括该位置)的数量相乘,其余位置不变。

也可以通过如下方式,进行Flatten.

# -*- coding: utf-8 -*-
"""
Created on Fri Jun  9 14:59:06 2023

@author: chengxf2
"""
import torch
from torch import nn

class Flatten(nn.Module):
    
    def __init__(self):
        
        super(Flatten, self).__init__()
        
    def forward(self, input):
        
        return input.view(input.size(0),-1) #[b,n]
    

class TestNet(nn.Module):
    
    def __init__(self):
        
        super(TestNet, self).__init__()
        
        self.net = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3),
        nn.MaxPool2d(2,2),
        Flatten(),
        nn.Linear(in_features=1*14*14, out_features=10))    
    
    def forward(self, x):
          
          return self.net(x)
    

 


七 MyLinear

 

      当自己定义某些张量时候,必须加载到nn.Parameter方法中管理,会自动的加上reuire_grad =True属性,可以被SGD 优化。

     如果不写nn.Parameter  ,必须加上require_grad=True,但是管理不方便.

 

 



原文链接:

参考:

pytorch小记:nn.ModuleList和nn.Sequential的用法以及区别_慕思侣的博客-CSDN博客

https://blog.csdn.net/fengbingchun/article/details/125706670

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/638593.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【kali】设置系统方式为中文

目录 1、右击终端 2、输入命令回车 3、需要输入当前账户密码 4、选择语言 5、再次确定 6、输入命令重启 1、右击终端 2、输入命令回车 sudo dpkg-reconfigure locales 3、需要输入当前账户密码 4、选择语言 依次选中en_US.UTF-8 / zh_CN.GBK(没找到&#xf…

野火启明RenesasRA4M2 UDS诊断bootloader 升级MCU

基于can总线的UDS软件升级 最近学习UDS诊断协议(ISO14229),是一项国际标准,为汽车电子系统中的诊断通信定义了统一的协议和服务。它规定了与诊断相关的服务需求,并没有设计通信机制。ISO14229仅对应用层和会话层做出了…

微信一天可以加多少个好友?

微信作为最大的私域流量池,几乎所有的人都会往微信引流,而微信每天加好友数量是有严格限制的。微信每天加多少人不会封号?微信每天加多少好友才不会被限制?微信频繁加好友被限制怎么办?请跟随小编的脚步一起往下看吧。…

IP地址定位原理

IP地址定位是一种通过IP地址来确定位置的技术,在互联网和移动网络的应用十分广泛。本文将介绍IP地址定位的原理和实现方式。 IP地址定位原理 IP地址是Internet Protocol(简称IP)的缩写,是互联网上的一个地址标识符用于识别连接到…

合成化学物:169219-08-3,Fmoc-Thr(Ac4Manα)-OH,一种甘露糖苏氨酸

Fmoc-Thr(Ac4Manα)-OH,甘露糖苏氨酸,供应商:陕西新研博美生物科技有限公司产品结构式: 产品规格: 1.CAS号:169219-08-3 2.分子式:C33H37NO14 3.分子量:671.65 4.包装规格&#xff1…

K8s in Action 阅读笔记——【14】Securing cluster nodes and the network

K8s in Action 阅读笔记——【14】Securing cluster nodes and the network 迄今为止,创建了 Pod 而不考虑它们允许消耗多少 CPU 和内存。但是,正如将在本章中看到的那样,设置 Pod 预期消耗和允许消耗的最大数量是任何 Pod 定义的重要部分。…

如何进行JMeter分布式压测?一个案例教你详细解读!

目录 引言 一、什么是压力测试? 二、什么是分布式测试? 三、为什么要使用分布式压力测试? 四、主流压力测试工具对比 五、Jmeter分布式压测原理 六、Jmeter分布式压测前的准备工作 七、阿里云服务器上进行分布式压测 八、系统架构学…

ATTCK v13版本战术介绍——凭证访问(二)

一、引言 在前几期文章中我们介绍了ATT&CK中侦察、资源开发、初始访问、执行、持久化、提权、防御规避战术,本期我们为大家介绍ATT&CK 14项战术中凭证访问战术第7-12种子技术,后续会介绍凭证访问其他子技术,敬请关注。 二、ATT&…

这两个小众的资源搜索工具其实很好用

01 小不点搜索是一个中国网络技术公司开发的网盘搜索引擎,该网站通过与多个主流网盘进行整合,为用户提供一种快速查找和下载文件的方式。小不点搜索因其高效性、便利性和实用性受到了广大用户的喜爱。 在技术实现上,小不点搜索拥有先进的搜…

C++项目打包成可调用dll文件python调用

目录 1.原项目如图 2.直接在项目对应地方新增dll.h,dll.cpp 3.改变工程的配置类型---动态库(.dll) 4.生成解决方案----可调用dll文件 5.查找dll依赖的其他dll 6.python调用dll 7.python调用dll打包成exe 相关dll要放一个文件夹 1.原项目如图 包括头文件uiaccess.h&#xff0…

Linux中Crontab(定时任务)命令详解及使用教程

Crontab介绍:Linux crontab是用来crontab命令常见于Unix和类Unix的操作系统之中,用于设置周期性被执行的指令。该命令从标准输入设备读取指令,并将其存放于“crontab”文件中,以供之后读取和执行。该词来源于希腊语 chronos(χρ?…

【云计算】Ubuntu多种安装docker方式

文章目录 前言一、docker官网二、安装docker1、第一种方式(官方)2、使用脚本安装(阿里云):3、使用官方脚本安装:拉取镜像(solo博客部署) 前言 Docker是一款开源的容器化平台&#x…

Misc(4)

RAR 打开是个加密压缩文件,给了提示说是4位纯数字加密 暴力破解为4位数字,获得flag QR 下载下来是一个二维码 很简单,利用工具一下就解出来了 镜子里的世界 打开后是一张图片,根据题目的提示,猜想可能是镜像翻转后得到…

java springboot整合MyBatis做数据库查询操作

首先 我们还是要搞清楚 MyBatis 工作中都需要那些东西 首先是基础配置 你要连哪里的数据 连什么类型的数据库 以什么权限去连 然后 以 注解还是xml方式去映射sql 好 我们直接上代码 我们先创建一个文件夹 然后打开idea 新建一个项目 然后 按我下图的操作配置一下 然后点下一…

小米秋招笔试题(强化基础)

1、已知const arr [A, B, C, D, E, F, G],下面可以获取数组最后一项的表达式有 A arr[6] B arr.pop() C arr.shift() D arr.unshift() 答案: AB 解析: shift() 方法用于把数组的第一个元素从其中删除,并返回第一个元素的值。…

Python自动化测试框架:Pytest和Unittest的区别

pytest和unittest是Python中常用的两种测试框架,它们都可以用来编写和执行测试用例,但两者在很多方面都有所不同。本文将从不同的角度来论述这些区别,以帮助大家更好地理解pytest和unittest。 1. 原理 pytest是基于Python的assert语句和Pyth…

leetcode 124.二叉树中的最大路径和

1.题目 二叉树中的 路径 被定义为一条节点序列,序列中每对相邻节点之间都存在一条边。同一个节点在一条路径序列中 至多出现一次 。该路径 至少包含一个 节点,且不一定经过根节点。 路径和 是路径中各节点值的总和。 给你一个二叉树的根节点 root &…

【Spring Cloud】Spring Cloud Alibaba-- 分布式事务Seata原理

文章目录 前言一、Seata 介绍1.1、Seata 简介1.2、Seata 的核心组件1.3、Seata 的整体执行流程 二、Seata 的 AT 模式原理2.1、AT 模式的整体执行流程2.2、AT 模式两阶段详细流程2.2.1、第一阶段的详细执行流程2.2.2、第二阶段提交的详细执行流程2.2.3、第二阶段回滚的详细执行…

独家揭秘:Kotlin K2编译器的前世今生

独家揭秘:Kotlin K2编译器的前世今生 也许您已经观看了最近的 KotlinConf 2023 主题演讲,关于 K2 编译器的更新。什么是 K2 编译器? 在搞清楚这个问题之前,我们需要了解Kotlin 使用的不同种类的编译器及其差异,以及编…

Python--数据类型

Python--数据类型 <font colorblue >一、数据的分类<font colorblue >二、数值类型<font colorblue >1、整型&#xff1a;int<font colorblue >2、浮点型&#xff1a;float<font colorblue >3、复数类型&#xff1a;complex <font colorblue …