PyTorch模型创建与nn.Module

news2024/11/30 6:40:51

文章和代码已经归档至【Github仓库:https://github.com/timerring/dive-into-AI 】或者公众号【AIShareLab】回复 pytorch教程 也可获取。

文章目录

  • 模型创建与nn.Module
      • nn.Module
    • 总结

模型创建与nn.Module

创建网络模型通常有2个要素:

  • 构建子模块
  • 拼接子模块

class LeNet(nn.Module):
	# 子模块创建
    def __init__(self, classes):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, classes)
	# 子模块拼接
    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

调用net = LeNet(classes=2)创建模型时,会调用__init__()方法创建模型的子模块。

训练调用outputs = net(inputs)时,会进入module.pycall()函数中:

    def __call__(self, *input, **kwargs):
        for hook in self._forward_pre_hooks.values():
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)
        ...
        ...
        ...

最终会调用result = self.forward(*input, **kwargs)函数,该函数会进入模型的forward()函数中,进行前向传播。

torch.nn中包含 4 个模块,如下图所示。

本次重点就在于nn.Model的解析:

nn.Module

nn.Module 有 8 个属性,都是OrderDict(有序字典)的结构。在 LeNet 的__init__()方法中会调用父类nn.Module__init__()方法,创建这 8 个属性。

    def __init__(self):
        """
        Initializes internal Module state, shared by both nn.Module and ScriptModule.
        """
        torch._C._log_api_usage_once("python.nn_module")

        self.training = True
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()
        self._load_state_dict_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
  • _parameters 属性:存储管理 nn.Parameter 类型的参数
  • _modules 属性:存储管理 nn.Module 类型的参数
  • _buffers 属性:存储管理缓冲属性,如 BN 层中的 running_mean
  • 5 个 ***_hooks 属性:存储管理钩子函数

LeNet 的__init__()中创建了 5 个子模块,nn.Conv2d()nn.Linear()都继承于nn.module,即一个 module 都是包含多个子 module 的。

class LeNet(nn.Module):
	# 子模块创建
    def __init__(self, classes):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, classes)
        ...
        ...
        ...

当调用net = LeNet(classes=2)创建模型后,net对象的 modules 属性就包含了这 5 个子网络模块。

下面看下每个子模块是如何添加到 LeNet 的_modules 属性中的。以self.conv1 = nn.Conv2d(3, 6, 5)为例,当我们运行到这一行时,首先 Step Into 进入 Conv2d的构造,然后 Step Out。右键Evaluate Expression查看nn.Conv2d(3, 6, 5)的属性。

上面说了Conv2d也是一个 module,里面的_modules属性为空,_parameters属性里包含了该卷积层的可学习参数,这些参数的类型是 Parameter,继承自 Tensor。

此时只是完成了nn.Conv2d(3, 6, 5) module 的创建。还没有赋值给self.conv1 nn.Module里有一个机制,会拦截所有的类属性赋值操作(self.conv1是类属性),进入到__setattr__()函数中。我们再次 Step Into 就可以进入__setattr__()

   def __setattr__(self, name, value):
        def remove_from(*dicts):
            for d in dicts:
                if name in d:
                    del d[name]

        params = self.__dict__.get('_parameters')
        if isinstance(value, Parameter):
            if params is None:
                raise AttributeError(
                    "cannot assign parameters before Module.__init__() call")
            remove_from(self.__dict__, self._buffers, self._modules)
            self.register_parameter(name, value)
        elif params is not None and name in params:
            if value is not None:
                raise TypeError("cannot assign '{}' as parameter '{}' "
                                "(torch.nn.Parameter or None expected)"
                                .format(torch.typename(value), name))
            self.register_parameter(name, value)
        else:
            modules = self.__dict__.get('_modules')
            if isinstance(value, Module):
                if modules is None:
                    raise AttributeError(
                        "cannot assign module before Module.__init__() call")
                remove_from(self.__dict__, self._parameters, self._buffers)
                modules[name] = value
            elif modules is not None and name in modules:
                if value is not None:
                    raise TypeError("cannot assign '{}' as child module '{}' "
                                    "(torch.nn.Module or None expected)"
                                    .format(torch.typename(value), name))
                modules[name] = value
            ...
            ...
            ...

在这里判断 value 的类型是Parameter还是Module,存储到对应的有序字典中。

这里nn.Conv2d(3, 6, 5)的类型是Module,因此会执行modules[name] = value,key 是类属性的名字conv1,value 就是nn.Conv2d(3, 6, 5)

总结

  • 一个 module 里可包含多个子 module。比如 LeNet 是一个 Module,里面包括多个卷积层、池化层、全连接层等子 module
  • 一个 module 相当于一个运算,必须实现 forward() 函数
  • 每个 module 都有 8 个字典管理自己的属性

参考文章:

https://www.cnblogs.com/zhangxiann/p/13579624.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/741186.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Redis 优惠卷秒杀(二) 异步秒杀、基于Stream的消息队列处理

目录 基于Stream的消息队列 Redis优化秒杀 登录头 改进秒杀业务,调高并发性能 Redis消息队列实现异步秒杀 ​编辑基于List结构模拟消息队列 基于PuSub的消息队列 ​编辑 基于Stream的消息队列 Redis消息队列 基于Stream的消息队列 Redis优化秒杀 登录头 改…

skywalking安装

目录 skywalking部署示意图 server安装 裸机安装 docker单节点安装 docker集群安装 k8s安装 helm安装(官方) k8s yaml安装 动态配置安装 client agent安装 skywalking部署示意图 skywalking ui - web界面管理程序oap server - skywalking服务程序nacos - skywalking集…

数字孪生水务系统可视化管理平台有效缓解城市供水压力

针对传统自来水厂供水水质安全隐患大,运行管理落后等问题,基于数字孪生技术构建全厂三维立体模型,在电脑前就可以掌握全厂管线、设备运行情况,遇到预案中的突发事件还可以给出辅助决策方案。从根本上有效提高水厂运行管理效率,增强对水质变化的应对能力,…

分析shein独立站成功的原因

近年来,Shein独立站在快时尚领域声名鹊起,成为许多时尚消费者的首选网站。面对激烈的竞争,它依然能够站稳脚跟并不断壮大。那么,Shein独立站成功的原因是什么呢? Shein独立站——以消费者为中心的运营模式 Shein独立站…

【Python】Locust持续优化:InfluxDB与Grafana实现数据持久化与可视化分析

在进行性能测试时,我们需要对测试结果进行监控和分析,以便于及时发现问题并进行优化。 Locust在内存中维护了一个时间序列数据结构,用于存储每个事件的统计信息。 这个数据结构允许我们在Charts标签页中查看不同时间点的性能指标&#xff0c…

java中使用HttpRequest发送请求调用自己的接口

(539条消息) java中使用HttpRequest发送请求_java httprequest_thankful_chn的博客-CSDN博客 <dependency><groupId>com.github.kevinsawicki</groupId><artifactId>http-request</artifactId><version>5.6</version></dependenc…

华为云-hcip笔记-网络服务规划

华为云-hcip笔记-网络服务规划 网络服务规划 安全组和网络ACL 网络ACL对子网进行防护&#xff0c;安全组是对ECS进行防护。 对等连接VPC peering 两个vpc之间的网络连接&#xff0c;用户可以使用私有ip地址在两个vpc之间进行通信。 同账号中对等连接自动接受&#xff0c;跨…

【JavaEE】JVM的组成及类加载过程

博主简介&#xff1a;想进大厂的打工人博主主页&#xff1a;xyk:所属专栏: JavaEE初阶 本文我们主要讲解一下面试中常见的问题&#xff0c;如果想深入了解&#xff0c;请看一下《Java虚拟机规范》这本书 目录 文章目录 一、JVM简介 二、JVM整体组成 2.1 运行时数据区组成 2.2…

【LeetCode周赛】2022上半年题目精选集——数学

文章目录 2183. 统计可以被 K 整除的下标对数目⭐⭐⭐⭐⭐思路——数论&#xff08;一个数乘上另一个数x是k的倍数&#xff0c;x最小是多少&#xff1f;&#xff09;代码1——统计每个数的因子代码2——统计k的因子 2245. 转角路径的乘积中最多能有几个尾随零思路&#xff08;因…

探索全球市场:初创品牌海外营销策略解析

​随着全球化进程的不断推进&#xff0c;越来越多的初创品牌意识到海外市场的巨大潜力&#xff0c;并希望能够将自己的品牌推广到更广阔的国际舞台上。然而&#xff0c;对于初创品牌来说&#xff0c;进军海外市场并开展品牌营销是一项具有挑战性的任务。本文Nox聚星将介绍一些初…

百变探影器 - 是一款很多人都在用的剪辑软件

有没有一款剪辑软件&#xff0c;它不仅颜值高&#xff0c;不用花时间学习就会剪&#xff0c;还自带丰富转场、片头片尾、字幕模板呢&#xff1f;那不得不说的就是一款超级能打的国产剪辑软件—百变探影器软件。 Pr这些比较专业的剪辑软件&#xff0c;基本都需要拥有一定的剪辑…

【实验八】多线程

1、完成书上268页习题第7题和实验题第1、2题 &#xff08;1&#xff09;第7题 import java.awt.*; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import javax.swing.*;public class RollWords extends JFrame{ static RollWords.MyThread thre…

西门子S7-200西门子200plc以太网配置向导

方案摘要&#xff1a; 西门子 S7200 系列 PLC通过MatrikonOPC实现以太网连接&#xff0c;捷米特ETH-S7200-JM01以太网模块为 PLC200转换出以太网通讯接口。 功能简介 MatrikonOPC是世界上最大的OPC开发商和供应商&#xff0c;我们的产品涵盖了OPC服务器、客户端、应用程序、O…

为D1定义一个f()函数,重做练习1-3,并解释其结果

运行代码&#xff1a; //为D1定义一个f()函数&#xff0c;重做练习1-3,并解释其结果 #include"std_lib_facilities.h" //---------------------------------------------------------------------- //定义B1类。 class B1 { public:virtual void vf() { cout<<…

第四章 数学知识(二)——欧拉函数,快速幂,扩展欧与中国剩余定理

文章目录 欧拉函数线性筛求欧拉函数欧拉定理 快速幂逆元 扩展欧几里得中国剩余定理扩展中国剩余定理 欧拉函数练习题873. 欧拉函数874. 筛法求欧拉函数 快速幂练习题875. 快速幂876. 快速幂求逆元 扩展欧练习题877. 扩展欧几里得算法878. 线性同余方程 中国剩余定理练习题204. …

Linux进程信号(一)

信号产生 1.信号基础知识2.初步认识信号3.signal函数4.技术应用角度的信号5.调用系统函数向进程发信号6.由软件条件产生的信号7.硬件异常产生信号8.core &#x1f31f;&#x1f31f;hello&#xff0c;各位读者大大们你们好呀&#x1f31f;&#x1f31f; &#x1f680;&#x1f…

从CoCo到喜茶,新茶饮品牌领悟出海的“九阴真经”了吗?

炎炎夏日里&#xff0c;一杯冰凉的奶茶和果茶受到了更多追捧。但是&#xff0c;中国新茶饮品牌却站在了一个十字路口。 随着新茶饮迈入“万店时代”&#xff0c;国内市场已经出现了明显的内卷现象&#xff0c;头部品牌之间的竞争日趋激烈&#xff0c;中小品牌的生存空间被挤压…

OpenGL的学习记录(一)(一些基本概念)

1.OpenGL是什么&#xff1f; OpenGL是一组各个GPU厂家一起遵循的约定。 2.GLFW&#xff0c;GLAD分别是什么&#xff1f; GLFW解决系统层面的不同&#xff0c;是我们与系统之间的隔离&#xff0c;如&#xff08;创建窗口&#xff0c;定义上下文&#xff0c;处理用户输入&#x…

数据结构--树和森林的遍历

数据结构–树和森林的遍历 树的先根遍历 void PreOrder(TreeNode* R) {if (R ! NULL){visit(R);while (R还有下一个子树T)PreOrder(T);} }树和二叉树的转化后》 树的先根遍历序列与这棵树相应二叉树的先序序列相同。 \color{red}树的先根遍历序列与这棵树相应二叉树的先序序列相…

txt文本筛选—python操作

需求&#xff1a;若文档中某行最后一列内容为0&#xff0c;则删除该行&#xff0c;否则保留该行内容&#xff0c;并将筛选后的内容保存到新的文本文档中。 # 读取原始txt文件 with open(depth_values.txt, r) as file:lines file.readlines()# 过滤掉第三列内容为0的行 filter…