4-3 nn.functional和nn.Module

news2025/1/23 3:59:38

一,nn.functional 和 nn.Module

前面我们介绍了Pytorch的张量的结构操作和数学运算中的一些常用API。利用这些张量的API我们可以构建出神经网络相关的组件(如激活函数,模型层,损失函数)。
其实:Pytorch和神经网络相关的功能组件大多都封装在** torch.nn **模块下。
这些功能组件的绝大部分既有函数形式实现,也有类形式实现。
其中nn.functional(一般引入后改名为F)有各种功能组件的函数实现。例如:
激活函数:
F.relu
F.sigmoid
F.tanh
F.softmax
模型层:
F.linear
F.conv2d
F.max_pool2d
F.dropout2d
F.embedding
损失函数:
F.binary_cross_entropy
F.mse_loss
F.cross_entropy
为了便于对参数进行管理,一般通过继承 nn.Module 转换成为类的实现形式,并直接封装在 nn 模块下。例如:
激活函数:
nn.ReLU
nn.Sigmoid
nn.Tanh
nn.Softmax
模型层:
nn.Linear
nn.Conv2d
nn.MaxPool2d
nn.Dropout2d
nn.Embedding
损失函数:
nn.BCELoss
nn.MSELoss
nn.CrossEntropyLoss
实际上nn.Module除了可以管理其引用的各种参数,还可以管理其引用的子模块,功能十分强大。
简单举例:
image.png

二,使用nn.Module来管理参数(配合nn.Parameter使用)

在Pytorch中,模型的参数是需要被优化器训练的,因此,通常要设置参数为 requires_grad = True 的张量。
同时,在一个模型中,往往有许多的参数,要手动管理这些参数并不是一件容易的事情。
Pytorch一般将参数用nn.Parameter来表示,并且用nn.Module来管理其结构下的所有参数。

requires_grad = True

手动设置:
image.png
nn.Parameter 具有 requires_grad = True 属性:
image.png

nn.ParameterList

列表形式
image.png

nn.ParameterDict

字典形式
image.png

Module管理

image.png
image.png

三、nn.Module构建模块类

实践当中,一般通过继承nn.Module来构建模块类,并将所有含有需要学习的参数的部分放在构造函数中。
以下范例为Pytorch中nn.Linear的源码的简化版本
可以看到它将需要学习的参数放在了__init__构造函数中,并在forward中调用F.linear函数来实现计算逻辑。

class Linear(nn.Module):
    __constants__ = ['in_features', 'out_features']

    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)

    def forward(self, input):
        return F.linear(input, self.weight, self.bias)

四、使用nn.Module来管理子模块

一般情况下,我们都很少直接使用 nn.Parameter来定义参数构建模型,而是通过拼装一些常用的模型层来构造模型。
这些模型层也是继承自nn.Module的对象,本身也包括参数,属于我们要定义的模块的子模块。
nn.Module提供了一些方法可以管理这些子模块。
children() 方法: 返回生成器,包括模块下的所有子模块。
named_children()方法:返回一个生成器,包括模块下的所有子模块,以及它们的名字。
modules()方法:返回一个生成器,包括模块下的所有各个层级的模块,包括模块本身
named_modules()方法:返回一个生成器,包括模块下的所有各个层级的模块以及它们的名字,包括模块本身。
其中chidren()方法和named_children()方法较多使用。
modules()方法和named_modules()方法较少使用,其功能可以通过多个named_children()的嵌套使用实现。

class Net(nn.Module):
    
    def __init__(self):
        super(Net, self).__init__()
        
        self.embedding = nn.Embedding(num_embeddings = 10000,embedding_dim = 3,padding_idx = 1)
        self.conv = nn.Sequential()
        self.conv.add_module("conv_1",nn.Conv1d(in_channels = 3,out_channels = 16,kernel_size = 5))
        self.conv.add_module("pool_1",nn.MaxPool1d(kernel_size = 2))
        self.conv.add_module("relu_1",nn.ReLU())
        self.conv.add_module("conv_2",nn.Conv1d(in_channels = 16,out_channels = 128,kernel_size = 2))
        self.conv.add_module("pool_2",nn.MaxPool1d(kernel_size = 2))
        self.conv.add_module("relu_2",nn.ReLU())
        
        self.dense = nn.Sequential()
        self.dense.add_module("flatten",nn.Flatten())
        self.dense.add_module("linear",nn.Linear(6144,1))
        
    def forward(self,x):
        x = self.embedding(x).transpose(1,2)
        x = self.conv(x)
        y = self.dense(x)
        return y
    
net = Net()

children

image.png

named_children

image.png

modules

image.png
image.png

冻结参数

下面我们通过named_children方法找到embedding层,并将其参数设置为不可训练(相当于冻结embedding层)。
image.png
image.png
image.png

参考:https://github.com/lyhue1991/eat_pytorch_in_20_days

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1010126.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

中小企业数字化转型难?为什么不试试“企业级”无代码平台

首先,让我们思考一下,中小企业为什么要进行数字化转型?随着全球经济的数字化趋势日益明显,中小企业作为经济的重要组成部分,其数字化转型已成为推动经济高质量发展的关键。数字技术可以帮助中小企业提高生产效率、降低…

ctfshow-web-红包题 辟邪剑谱

0x00 前言 CTF 加解密合集CTF Web合集网络安全知识库溯源相关 文中工具皆可关注 皓月当空w 公众号 发送关键字 工具 获取 0x01 题目 0x02 Write Up 这道题主要是考察mysql查询绕过的问题。 首先访问后看到是一个登录页面,测试注册等无果 扫描目录,发…

Packet Tracer的使用介绍

直接访问 Packet Tracer 的帮助页面、教程视频和在线资源对于了解该软件会更加方便。 单击菜单工具栏右上角的问号图标。单击“帮助”菜单,然后选择“内容”。 b. 通过单击“帮助”>“教程”来访问 Packet Tracer 的教程视频。 菜单栏:提供文件、编辑…

SpringBoot运行原理

目录 SpringBootApplication ComponentScan SpringBootConfiguration EnableAutoConfiguration 结论 SpringbootApplication(主入口) SpringBootApplication public class SpringbootConfigApplication {public static void main(String[] args) {…

Android动态片段

之前创建的片段都是静态的。一旦显示片段,片段的内容就不能改变了。尽管可以用一个新实例完全取代所显示的片段,但是并不能更新片段本身的内容。 之前已经创建过一个基础秒表应用,具体代码https://github.com/MADMAX110/Stopwatch。我们将这个…

发生以下的报错怎么办?

报错问题: 解决办法: 根据你提供的代码和错误信息,问题出在使用了nullptr。这个错误是因为你的编译器不支持C11标准。 nullptr是C11引入的空指针常量。为了解决这个问题,你可以尝试以下两种方法之一: 1. 将nullptr…

想要精通算法和SQL的成长之路 - 可以攻击国王的皇后

想要精通算法和SQL的成长之路 - 可以攻击国王的皇后 前言一. 可以攻击国王的皇后 前言 想要精通算法和SQL的成长之路 - 系列导航 一. 可以攻击国王的皇后 原题链接 这个题目其实并没有涉及到什么很难的算法,其实就是一个简单的遍历题目。核心思想: 以…

CRM系统销售自动化功能如何提高销售效率

销售效率对企业的盈利能力有着至关重要的联系。提高销售效率,就是要提高销售人员的工作效率和销售转化率。那么,企业如何提高销售效率呢?CRM销售自动化功能可以帮助企业实现这一目标。 一、线索管理 线索是指有潜在购买意向的客户&#xff…

kali必杀器之三剑客

Kali常见攻击手段 注意:仅用于教程和科普,切勿做违法之事,否则后果自负 1 网络攻击手段 请正确使用DDos和CC攻击,不要用来做违反当地法律法规的事情,否则后果自负 使用之前kali需要能够上网 参考:kali安装 1.1 DDos攻击…

新加坡打车软件平台Ryde Group申请1700万美元纳斯达克IPO上市

来源:猛兽财经 作者:猛兽财经 猛兽财经获悉,新加坡打车软件平台Ryde Group近期已向美国证券交易委员会(SEC)提交招股书,申请在纳斯达克IPO上市,股票代码为(RYDE)&#x…

学习javaEE初阶的第一堂课

学习金字塔 java发展简史 Java最初诞生的时候是用来写前端的!! 199x年 199x年,互联网还处在比较早期的阶段,当时主流的编程语言是 C/C, 有个大佬要搞个"智能面包机",觉得用C来做太难了 于是就基于C搞了个简单点的语言,Java 就诞生了~~ 遗憾的是项目流产了,没做成…

【SpringMVC】自定义注解

🎉🎉欢迎来到我的CSDN主页!🎉🎉 🏅我是Java方文山,一个在CSDN分享笔记的博主。📚📚 🌟在这里,我要推荐给大家我的专栏《Spring MVC》。&#x1f3…

MMrotate_dev 1.x训练自己的数据集

因为MMRotate dev 1.x 新增了PSC角度编码器以及RTMDet目标检测算法,而之前从官网下载的MMRotate是main分支,没有新增的东西,所以重新搞了一下,以此记录。 环境配置 1.创建虚拟环境 注意:如果之前安装了MMRotate的其…

基于小程序的理发店预约系统

一、项目背景及简介 现在很多的地方都在使用计算机开发的各种管理系统来提高工作的效率,给人们带来很多的方便。计算机技术从很大的程度上解放了人们的双手,并扩大了人们的活动范围,是人们足不出户就可以通过电脑进行各种事情的管理。信息系…

pycharm安装jupyter,用德古拉主题,但是输入行全白了,看不清,怎么办?

问题描述 今天换了以下pycharm主题,但是jupyter界面输入代码行太白了,白到看不清楚这行的字,更不知道写的是什么,写到哪了,这还是挺烦人的,其他都挺正常的。 问题分析 目前来看有两个原因: 1、…

深化产教融合,知了汇智助力高校数字化人才培养

随着数字经济的不断深入和发展,数字人才短缺的问题逐渐凸显,根据相关报告,目前我国数字人才缺口在2500万到3000万左右,且缺口仍在不断扩大。为了满足数字经济的发展需求,如何培养出具备创新型、复合型、应用型能力的数…

C++学习笔记一(重载、类)

C 1、函数重载2、类2.1、类的方法和属性2.2、类的方法的定义2.3、构造器和析构器2.4、类的实例化2.5、基类与子类2.6、类的public、protected、private继承2.7、类的方法的重载2.8、子类方法的覆盖2.9、继承中的构造函数和析构函数 1、函数重载 函数重载大概可以理解为&#x…

再次理解Android账号管理体系

目录 ✅ 0. 需求 📂 1. 前言 🔱 2. 使用 2.1 账户体系前提 2.2 创建账户服务 2.3 操作账户-增删改查 💠 3. 源码流程 ✅ 0. 需求 试想,自己去实现一个账号管理体系,该如何做呢? ——————————…

竞赛 基于大数据的时间序列股价预测分析与可视化 - lstm

文章目录 1 前言2 时间序列的由来2.1 四种模型的名称: 3 数据预览4 理论公式4.1 协方差4.2 相关系数4.3 scikit-learn计算相关性 5 金融数据的时序分析5.1 数据概况5.2 序列变化情况计算 最后 1 前言 🔥 优质竞赛项目系列,今天要分享的是 &…

2004-2020年中小企业板上市公司财务报表股票交易董事高管1200+变量数据及说明

2004-2020年中小企业板上市公司财务报表股票交易董事高管1200变量数据及说明 1、时间:2004-2020年 2、范围:中小企业板上市公司,具体名单参看下文链接内数据预览 3、指标:1200变量 变量说明、证券代码、证券代码-字符串、年份…