第三章 模型篇:模型与模型的搭建

news2025/1/12 5:57:54

写在前面的话
这部分只解释代码,不对线性层(全连接层),卷积层等layer的原理进行解释。
尽量写的比较全了,但是自身水平有限,不太确定是否有遗漏重要的部分。

教程参考:
https://pytorch.org/tutorials/
https://github.com/TingsongYu/PyTorch_Tutorial
https://github.com/yunjey/pytorch-tutorial


文章目录

  • 模型的定义
    • nn.Module()
    • nn.Parameters()
    • module中的register是如何实现的
    • Module()的一些方法
      • add_module()
      • children() 和 named_children()
      • parameters() 和 named_parameters()
      • apply(fn)
      • modules() and get_submodule(target)
      • state_dict()
  • 模型的搭建
    • 设定训练设备
    • 定义自己的网络类
    • 一些相关的方法
      • nn.Sequential()
      • nn.ModuleList()
      • nn.ModuleDict()

模型的定义

模型,也就是我们常说的神经网络。它由大量相连连同的节点组成,形成类似于人体内神经的结构,所以被称为神经网络。在使用时,数据会通过网络中一层一层的节点,经过运算后得到一个结果。
神经网络,就由在数据上执行计算操作的layers(层)或者modules(模块)组成。torch.nn提供了我们组成一个神经网络需要的所有单位原件,我们可以使用torch.nn下的各个class,来组成我们的神经网络。

nn.Module()

在pytorch中所有的module都继承了nn.Module(),都是它的子类。一个神经网络也是一个module,只不过它本身包含了其它别的module。
源码链接:https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module

nn.Parameters()

源码链接:https://pytorch.org/docs/stable/_modules/torch/nn/parameter.html#Parameter
nn.Parameter()它并没有继承nn.Module(), 而是继承了torch.Tensor()。它同样被放到torch.nn这个模块下面,是因为它在用于nn.Module()时会呈现出和一般的tensor不一样的特征。
它会天然地被添加到Module.parameters中去,作为一个可训练的参数使用。

module中的register是如何实现的

在构建网络时,我们基于nn.Module()定义我们自己的模型的class,并在初始化的过程中使用多个不同的layer或者module来组成我们的模型。这些layer和module都会被register到网络中,方便我们使用参数名进行访问。
比如说我们现在定义一个非常简单的网络。这个网络在初始化时定义了三个变量,分别是self.t1:一个普通的tensor,self.p1:一个parameter和self.conv1:一个卷积层。这个卷积层,同样也继承了nn.Module(),它会被储存在Module._modules中。
在这里插入图片描述

self._modules会在你构建网络的过程中进行更新,更具体的讲,在你执行obj.name = value的命令时,一个名为 __setattr__的函数会起作用,判断你所构建的变量的类型。
比如说你的变量的类型是"Parameter",那么它就会被加到self._parameters中去;如果你的变量的类型是"Module",那么它就会被加到self._modules中去。下方举例了添加module的代码,具体可以参考源码链接。

[docs]    def add_module(self, name: str, module: Optional['Module']) -> None:
        r"""Adds a child module to the current module.

        The module can be accessed as an attribute using the given name.

        Args:
            name (str): name of the child module. The child module can be
                accessed from this module using the given name
            module (Module): child module to be added to the module.
        """
        if not isinstance(module, Module) and module is not None:
            raise TypeError("{} is not a Module subclass".format(
                torch.typename(module)))
        elif not isinstance(name, str):
            raise TypeError("module name should be a string. Got {}".format(
                torch.typename(name)))
        elif hasattr(self, name) and name not in self._modules:
            raise KeyError("attribute '{}' already exists".format(name))
        elif '.' in name:
            raise KeyError("module name can't contain \".\", got: {}".format(name))
        elif name == '':
            raise KeyError("module name can't be empty string \"\"")
        for hook in _global_module_registration_hooks.values():
            output = hook(self, name, module)
            if output is not None:
                module = output
        self._modules[name] = module

在上方,我们已经使用net = Model()实例化了我们的网络,现在来看一下网络里面的参数情况。
可以看到,我们的’p1’,因为类别是"Parameter",所以它被添加到了net._parameters中去;我们的’conv1’,类别是"Module",所以它被添加到了net._modules中去;而我们的t1,因为啥也不是,所以单独地被放到了net.t1。和普通的class中的属性没有什么区别。
在这里插入图片描述

Module()的一些方法

这里举例的方法并不是很全,主要是介绍了一些我认为可以去了解的函数。更多的细节还是要自己查询文档。
https://pytorch.org/docs/stable/generated/torch.nn.Module.html

add_module()

在上方我们提供了add_module()的源码,它起到的就是register_module()的作用。
除了使用obj.name = value类似的命令定义网络中的module以外,我们也可以使用self.add_module(name, value)的方法,两者是等价的。需要注意的是,这里的value必须是一个module。
如下图,可以看到,我们用两种方法,都成功将一个卷积层加入到了net._modules中去。
在这里插入图片描述
当我们想使用add_module()方法加入一个非module类型的变量时,则会出现报错。从方便的角度讲,一般我们也不会使用这样的方法来构建我们的模型,还是obj.name = value更为简单常用。
在这里插入图片描述

children() 和 named_children()

children()和named_children()都返回了一个迭代器,两者也是很好区分。children()只返回了定义的模型中的module,而named_children()在返回module的同时,还返回了module的名字。
在这里插入图片描述

parameters() 和 named_parameters()

类似于上面的children()和named_children(),parameters()和named_parameters()同样也返回了一个迭代器,只不过迭代器中的内容不再是module和它的名字,而是换成了module._parameters。
我们在这里使用Linear层做例子。
在这里插入图片描述

apply(fn)

apply()的作用是在你模型的所有module上执行同一个函数,因此输入参数是一个函数,在使用时,它会对你的self.children()的结果进行遍历,并在每个结果上递归地都执行传入的函数。

def apply(self, fn)
	for module in self.children():
	   module.apply(fn)
	 fn(self)
	 return self

比如说,在对模型进行权重初始化时,就可以使用这个函数。在tutorial文档中也给出了相应的例子。下方的代码给出了一个初始化权重的方法,假如module是一个线性层,就将它的权重的数值全部别为1。我们在上方定义的模型上使用这个函数。

@torch.no_grad()
def init_weights(m):
    print(m)
    if type(m) == nn.Linear:
        m.weight.fill_(1.0)
        print(m.weight)

可以看到在我们的结果中,net中的l1和l2的weights都受到了影响,但是bias没有发生变化。你也可以使用类似的方法改变它的bias或者别的module中的权重值。
在这里插入图片描述

modules() and get_submodule(target)

modules()函数也可以返回我们的网络中的module,来看一下它和children()的区别。使用我们之前构建的有两个卷积层的网络,可以看到net.modules()除了返回它的submodule外,还返回了它本身。
在这里插入图片描述
而get_submodule(target)中的target,代表你想获得的module的name,使用name可以获得对应的module。
在这里插入图片描述

state_dict()

state_dict()是一个比较重要的方法,它可以orderdict的形式返回我们的模型中各个模块的权重和权重名。
以我们定义的包含两个线性层的模型为例子,state_dict()返回了l1的weight和bias以及l2的weight和bias。并且我们可以通过名称来检测对应的权重值。

模型的搭建

设定训练设备

假设我们现在有一个搭建好的模型net,我们可以将模型放到我们希望使用的设备上,从而利用设备的加速能力。
在pytorch tutorial同样给出了代码样例。

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")
net.to(device)

在确定设备后,我们使用.to()函数,就可以把网络放到对应的设备上。

定义自己的网络类

定义模型的三要素:

  • 继承nn.Module()
  • 在__init__中定义组件
  • 在forward()中确定组件使用的顺序
    我们要基于nn.Module()类来构建我们自己的网络,并且在__init__中进行初始化,我们可以使用各种各样的组件来完成我们的网络,并在forward中决定我们的输入在各个组件中传递的顺序。
    需要注意的是,这个顺序不是随便决定的,我们要考虑我们使用的组件的输入维度和输出维度。
    比如说下面这个例子,我们定义了两个forward,其中第一个forward()会在__call__()中被调用,所以我们可以使用net(x)直接调用第一个forward(),第二个forward2()则需要用函数名调用。
    在这个例子中,我们定义了两个线性层,其中l1的输入大小为5,输出大小为2。l2的输入大小为2,输出大小为2。而我们创建的输入变量的大小是(3,5),相当于batchsize = 3,channel=5,因此在先使用l1后使用l2时,代码可以成功执行。但是反过来后代码就会报错。
    在这里插入图片描述

一些相关的方法

nn.Sequential()

nn.Sequential()方法也继承了nn.Module(),它的作用是作为一个container,把组件按照入参时的顺序添加进来,并且在forward()时,传入的数据也会按照顺序通过这些组件。
nn.Sequential()的传入参数有两种形式,第一种是OrderedDict[str, Module],其中有序字典的key代表的是你给要传入的module起的名字。如果使用的不是有序字典作为输入,而是直接使用的Module,那么这个方法会按从0开始的index给组件命名。
具体可以直接看源码:
可以看到,在__init__()函数中,该方法对输入的组件进行了遍历,并且使用add_module()进行register。

def __init__(self, *args):
        super().__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

我们给出一个非常简单的样例,我们使用nn.Sequential()构建一个简单的网络,模型里只有两个线形层。这个定义好的网络是直接可以使用的。

需要注意的是,nn.Sequential()在forward()中进行数据的传递时是按照组件传入的顺序进行的,因此你的组件顺序不对,仍然会出现报错。

nn.ModuleList()

nn.ModuleList()方法也继承了nn.Module(),它和nn.Sequential()一样,也是一个container,但是两者也存在一些区别。
nn.ModuleList()中没有实现forward()的方法,它只是把传入的组件放到了一个类似于python中list的序列中。
nn.ModuleList()中也不可以使用OrderedDict作为输入。
nn.ModuleList()中传入组件时不需要考虑组件的顺序。
以下给入了一个使用的例子。

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, l in enumerate(self.linears):
            x = self.linears[i // 2](x) + l(x)
        return x

要注意的是我们不能使用python的list来代替nn.Module(),因为之前提到过在进行register时会判断你所创建的obj.name = value的value的类别,假如这个类不是Module,则不会被add_module()添加到self._modules中去。

nn.ModuleDict()

nn.ModuleDict()与nn.ModuleList()类似,不同的是它传入的是一个dict。这也弥补了nn.ModuleList()中不能给组件起名字的缺点,传入的dict中的key就代表了对应组建的名字。
这个方法同样也没有实现forward()函数。
下方给出一个使用的例子。在forward()中调用组件时,用的也不再是nn.ModuleList()中的index,而是dict中的key。

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.choices = nn.ModuleDict({
                'conv': nn.Conv2d(10, 10, 3),
                'pool': nn.MaxPool2d(3)
        })
        self.activations = nn.ModuleDict([
                ['lrelu', nn.LeakyReLU()],
                ['prelu', nn.PReLU()]
        ])

    def forward(self, x, choice, act):
        x = self.choices[choice](x)
        x = self.activations[act](x)
        return x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/643790.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

RK3588平台开发系列讲解(以太网篇)SGMII和RGMII接口特性

文章目录 一、MAC 与 PHY的连接二、MAC 与 PHY 在OSI 中位置2.1、网络层2.2、数据链路层2.3、物理层三、RGMII四、SGMII沉淀、分享、成长,让自己和他人都能有所收获!😄 一、MAC 与 PHY的连接 从硬件的角度看,以太网接口电路主要由MAC控制器和物理层PHY芯片两部分组成。 以…

Redis 五大数据类型/结构

Redis 五大数据类型/结构 操作文档 官方文档: https://redis.io/commands 中文文档: http://redisdoc.com/ Redis 数据存储格式 一句话: redis 自身是一个Map,其中所有的数据都是采用key : value 的形式存储 key 是字符串,value 是数据,数…

流媒体接入服务的一般模型

0x00 背景说明 媒体接入服务用来实现媒体资源(resource)的接收和发送,在有限范围内实现不同接入协议的转换。 0x01 一般模型 媒体传输通道的建立步骤通常分为两个阶段: 握手/协商媒体传输 其中,握手/协商操作通常包含: 媒体…

【GD32F303CCT6BlueBill开箱点灯教程】

【GD32F303CCT6BlueBill开箱点灯教程】 1. 搭建环境1.1 官方资料1.2 安装Keil 51.3 安装芯片选型插件pack包 2. 编译2.1 Keil4转换为Keil5工程2.2 选择芯片型号2.3 存储器类型2.4 选择下载器2.5 内存下载设置 3. 烧录3.1 Keil内烧录3.1.1 J-Link烧录3.1.2 ST-Link烧录3.1.3 CMS…

读书笔记:《远见:如何规划职业生涯3大阶段》

《远见:如何规划职业生涯3大阶段》,作者布赖恩. 费瑟斯通豪,豆瓣链接:https://book.douban.com/subject/27609489/ 主旨:描述职业生涯中3个截然不同但相互关联的阶段,教会我们如何不断储备职场燃…

【linux指南--命令大全】

系统的学习linux常用的命令,命令很全所以篇幅很长,可以作为你查阅命令的手册。也欢迎大佬们评论区补充。 文章目录 常见目录介绍配置文件系统操作帮助命令man 帮助help 帮助info 帮助 显示当前的目录名称文件查看建立目录删除空目录复制文件移动文件删除…

Qt下面窗口嵌套,嵌套窗口中包含:QGraphicsView、QGraphicsScene、QGraphicsIte

Qt系列文章目录 文章目录 Qt系列文章目录前言一、嵌套窗口二、注意事项 前言 我们有一个主窗口mainwindow,需要向其中放入新的界面,你可以自己定义里面内容。 Qt的嵌套布局由QDockWidget完成,用Qt Creator拖界面得到的dock布置形式比较固定,…

vmware设置centos客户机和windows宿主机共享文件夹

一、安装内核 kernel-devel 包 yum install gcc yum install kernel-devel-$(uname -r) 注意,如果自己修改过内核版本,需要确保 uname -r 显示的版本和实际使用的内核版本一致。 二、安装 vmware-tools 在vmware上点击菜单:虚拟机->安…

Android kotlin 实现仿京东多个item向左自动排队(横向、动手滑动、没有首尾滑动)功能

文章目录 一、实现效果二、引入依赖三、源码实现1、适配器2、视图实现一、实现效果 二、引入依赖 在app的build.gradle在添加以下代码 1、implementation com.github.CymChad:BaseRecyclerViewAdapterHelper:3.0.6,这个里面带的适配器,直接调用就即可 BaseRecyclerViewAdapt…

【图神经网络】图神经网络(GNN)学习笔记:Graph Embedding

图神经网络(GNN)学习笔记:Graph Embedding 为什么要进行图嵌入Graph embedding?Graph Embedding使用图嵌入的优势有哪些?图嵌入的方法有哪些?节点嵌入方法(Node Embeddings)1. DeepWalk2. LINE…

CTFShow-WEB入门篇命令执行详细Wp(29-40)

WEB入门篇--命令执行详细Wp 命令执行:Web29:Web30:Web31:web32:web33:web34:web35:web36:web37:web38:web39:web40: CTFSh…

【哈希表part02】| 454.四数相加、383.赎金信、15.三数之和、18.四数之和

目录 ✿LeetCode454.四数相加❀ ✿LeetCode383.赎金信❀ ✿LeetCode15.三数之和❀ ✿LeetCode18.四数之和❀ ✿LeetCode454.四数相加❀ 链接:454.四数相加 给你四个整数数组 nums1、nums2、nums3 和 nums4 ,数组长度都是 n ,请你计算有多…

Hive3安装

Mysql安装 卸载Centos7自带的mariadb rpm -qa|grep mariadb rpm -e mariadb-libs-5.5.64-1.el7.x86_64 --nodeps rpm -qa|grep mariadb 安装mysql mkdir /export/software/mysql 上传mysql-5.7.29-1.el7.x86_64.rpm-bundle.tar 到上述文件夹下后解压 tar xvf mysql-5.7.29-1…

微服务技术简介

微服务技术简介 服务架构的演变微服务架构的常见概念微服务常见的解决方案Spring CloudSpring Cloud Alibaba微服务技术对比常用的微服务组件 微服务架构图 服务架构的演变 单体架构:当一个系统业务量很小的时候,将业务的所有功能集中在一个项目中开发&…

红帽认证常见答疑(一):有效期、考试题型、考试对年龄和身份要求、英语水平等

红帽认证有效期 红帽的每个证书都有有效期,期限3年。RHCE过期前可以考下午的RHCE(EX294)或者考一门RHCA来延期3年。证书过期后在红帽官网上无法下载证书,但仍然可以查询到考试记录,不会影响到就业求职,如果…

2.6 TCP与UDP的可靠性传输

目录 一、TCP可靠性传输1、重传机制1.1、超时重传1.2、快速重传1.3、SACK1.4、Duplicate SACK 2、滑动窗口3、流量控制3.1 滑动窗口与流量控制3.2窗口关闭 4、拥塞控制4.1拥塞窗口4.2 慢启动4.3 拥塞避免4.4 拥塞发生4.5 快速恢复 二、UDP可靠性传输1、主要策略2、重传机制2.1 …

基础知识学习---牛客网C++面试宝典(六)操作系统--第二节

1、本栏用来记录社招找工作过程中的内容,包括基础知识学习以及面试问题的记录等,以便于后续个人回顾学习; 暂时只有2023年3月份,第一次社招找工作的过程; 2、个人经历: 研究生期间课题是SLAM在无人机上的应…

湖南大学CS-2017期末考试解析

【特别注意】 答案来源于@wolf 是我在备考时自己做的,仅供参考,若有不同的地方欢迎讨论。 【试卷评析】 这张卷子有点老了,部分题目可能有用。如果仔细研究应该会有所收获。 【试卷与答案】 一.(6 分,每空 0.5 分) 下表中%r1,%r2 为两个四位的寄存器,请仿照第一行…

考虑3D海底环境的风电场集电系统

摘要 风能是目前国内外应用较为广泛的一种绿色可再生能源,近几年我国风电产业的发展十分迅速。然后,越来越多的风力发电系统建并网,风力发电产生的电能受外界因素影响较大,具有一定的随机性和波动性,给并网后的电力系统…

代码随想录算法训练营第三十五天| 860.柠檬水找零、406.根据身高重建队列、452. 用最少数量的箭引爆气球

柠檬水找零 题目链接:力扣 这道题 我一开始用纯模拟的方法也能写出来,后来发现和卡哥给的答案差不多,其贪心的点在: 当账单是20的情况,优先消耗一个10和一个5因为美元10只能给账单20找零,而美元5可以给账…