Pytorch学习:神经网络模块torch.nn.Module和torch.nn.Sequential

news2025/1/21 15:36:28

文章目录

    • 1. torch.nn.Module
      • 1.1 add_module(name,module)
      • 1.2 apply(fn)
      • 1.3 cpu()
      • 1.4 cuda(device=None)
      • 1.5 train()
      • 1.6 eval()
      • 1.7 state_dict()
    • 2. torch.nn.Sequential
      • 2.1 append
    • 3. torch.nn.functional.conv2d

1. torch.nn.Module

官方文档:torch.nn.Module
CLASS torch.nn.Module(*args, **kwargs)

  • 所有神经网络模块的基类。
  • 您的模型也应该对此类进行子类化。
  • 模块还可以包含其他模块,允许将它们嵌套在树结构中。您可以将子模块分配为常规属性:
  • training(bool)-布尔值表示此模块是处于训练模式还是评估模式。

定义一个模型

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
  • 以这种方式分配的子模块将被注册,并且当您调用 to() 等时也将转换其参数。
    • to(device=None,dtype=None,non_blocking=False)
      device ( torch.device) – 该模块中参数和缓冲区所需的设备
    • to(dtype ,non_blocking=False)
      dtype ( torch.dtype) – 该模块中参数和缓冲区所需的浮点或复杂数据类型
    • to(tensor,non_blocking=False)
      张量( torch.Tensor ) – 张量,其数据类型和设备是该模块中所有参数和缓冲区所需的数据类型和设备

引用上面定义的模型,将模型转移到GPU上

# 创建模型
model = Model()

# 定义设备 gpu1
gpu1 = torch.device("cuda:1")
model = model.to(gpu1)

1.1 add_module(name,module)

将子模块添加到当前模块。
可以使用给定的名称作为属性访问模块。

add_module(name,module)
主要参数:

  • name(str)-子模块的名称。可以使用给定的名称从此模块访问子模块。
  • module(Module)-要添加到模块的子模块。

在这里插入图片描述
添加一个卷积层

model.add_module("conv3", nn.Conv2d(20, 20, 5))

在这里插入图片描述

1.2 apply(fn)

将 fn 递归地应用于每个子模块(由 .children() 返回)以及self。
典型的用法包括初始化模型的参数(另请参见torch.nn.init)。

apply(fn)
主要参数:

  • fn( Module -> None)-应用于每个子模块的函数

将所有线性层的权重置为1

import torch
from torch import nn


@torch.no_grad()
def init_weights(m):
    print(m)
    if type(m) == nn.Linear:
        m.weight.fill_(1.0)
        print(m.weight)


net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2,2))
net.apply(init_weights)

在这里插入图片描述

1.3 cpu()

将所有模型参数和缓冲区移动到CPU。

device = torch.device("cpu")
model = model.to(device)

1.4 cuda(device=None)

将所有模型参数和缓冲区移动到GPU。

这也使关联的参数和缓冲区成为不同的对象。因此,如果模块在优化时将驻留在GPU上,则应在构造优化器之前调用该函数。

cuda(device=None)
主要参数:

  • device(int,可选)-如果指定,所有参数将被复制到该设备

转移到GPU包括以下参数:

  1. 模型
  2. 损失函数
  3. 输入输出
# 创建模型
model = Model()

# 将模型转移到GPU上
model = model.cuda()

# 将损失函数转移到GPU上
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.cuda()

# 将输入输出转移到GPU上
imgs, targets = data
imgs = imgs.cuda()
targets = targets.cuda()

另一种表示形式(通过 to(device) 来表示)

# 创建模型
model = Model()

# 定义设备:如果有GPU,则在GPU上训练, 否则在CPU上训练
device = torch.device("cuda" if torch.cuda.is_available else "cpu")

# 将模型转移到GPU上
model = model.to(device)

# 将损失函数转移到GPU上
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)

# 将输入输出转移到GPU上
imgs, targets = data
imgs = imgs.to(device)
targets = targets.to(device)

1.5 train()

将模块设置为训练模式。

这只对某些模块有任何影响。如受影响,请参阅特定模块在培训/评估模式下的行为详情,例如: Dropout 、 BatchNorm 等。

train(mode=True)
主要参数:

  • mode(bool)-是否设置训练模式( True )或评估模式( False )。默认值: True 。

1.6 eval()

将模块设置为评估模式。

这只对某些模块有任何影响。如受影响,请参阅特定模块在培训/评估模式下的行为详情,例如: Dropout 、 BatchNorm 等。

在进行模型测试的时候会用到。

1.7 state_dict()

返回一个字典,其中包含对模块整个状态的引用。

返回模型的关键字典。

model = Model()
print(model.state_dict().keys())

在这里插入图片描述
在保存模型的时候我们也可以直接保存模型的 state_dict()

model = Model()

# 保存模型
# 另一种方式:torch.save(model, "model.pth")
torch.save(model.state_dict(), "model.pth")

# 加载模型
model.load_state_dict(torch.load("model.pth"))

2. torch.nn.Sequential

顺序容器。模块将按照它们在构造函数中传递的顺序添加到它。

Sequential 的 forward() 方法接受任何输入并将其转发到它包含的第一个模块。然后,它将输出“链接”到每个后续模块的输入,最后返回最后一个模块的输出。

官方文档:torch.nn.Sequential
CLASS torch.nn.Sequential(*args: Module)

import torch
from torch import nn

# 使用 Sequential 创建一个小型模型。运行 `model` 时、
# 输入将首先传递给 `Conv2d(1,20,5)`。输出
# `Conv2d(1,20,5)`的输出将作为第一个
# 第一个 `ReLU` 的输出将成为 `Conv2d(1,20,5)` 的输入。
# `Conv2d(20,64,5)` 的输入。最后
# `Conv2d(20,64,5)` 的输出将作为第二个 `ReLU` 的输入
model = nn.Sequential(
            nn.Conv2d(1, 20, 5),
            nn.ReLU(),
            nn.Conv2d(20, 64, 5),
            nn.ReLU()
        )

在这里插入图片描述

2.1 append

append 在末尾追加给定块。

  • append(module)
    在末尾追加给定模块。
    在这里插入图片描述
def append(self, module):
    self.add_module(str(len(self)), module)
    return self


append(model, nn.Conv2d(64, 64, 5))
append(model, nn.ReLU())
print(model)

在这里插入图片描述

3. torch.nn.functional.conv2d

对由多个输入平面组成的输入图像应用2D卷积。
卷积神经网络详解:csdn链接

官方文档:torch.nn.functional.conv2d
torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)
主要参数:

  • input:形状的输入张量,(minibatch, inchannels, iH, iW)。
  • weigh:卷积核权重,形状为 (out_channels, inchannels / groups, kH, kW)

默认参数:

  • bias:偏置,默认值: None。
  • stride:步幅,默认值:1。
  • padding:填充,默认值:0。
  • dilation :内核元素之间的间距。默认值:1。
  • groups:将输入拆分为组,in_channels 应被组数整除。默认值:1。

在这里插入图片描述
对上图卷积操作进行代码实现

import torch.nn.functional as F

input = torch.tensor([[0, 1, 2],
                      [3, 4, 5],
                      [6, 7, 8]], dtype=float32)
kernel = torch.tensor([[0, 1],
                       [2, 3]], dtype=float32)


# F.conv2d 输入维数为4维
# torch.reshape(input, shape)
# reshape(样本数,通道数,高度,宽度)
input = torch.reshape(input, (1, 1, 3, 3))
kernel = torch.reshape(kernel, (1, 1, 2, 2))

output = F.conv2d(input, kernel, stride=1)
print(input.shape)
print(kernel.shape)
print(input)
print(kernel)
print(output)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/950826.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

python的可哈希对象

一、介绍 在Python中,可哈希(hashable)是指一种对象类型,该类型的对象可以用作字典的键(keys)或集合(sets)的元素。可哈希的对象具有以下特点: 不可变性(Imm…

【Interaction交互模块】LinearTransformDrive线性变换驱动

文章目录 一、预设位置二、案例:建一个按下后可自动抬起的按钮三、留有疑问 一、预设位置 交互——可控制物体——变换——线性变换驱动 二、案例:建一个按下后可自动抬起的按钮 按钮的结构和设置如下图 为了让它碰触时,往下走——预设体…

【Go 基础篇】Go语言结构体基本使用

在Go语言中,结构体是一种重要的数据类型,用于定义和组织一组不同类型的数据字段。结构体允许开发者创建自定义的复合数据类型,类似于其他编程语言中的类。本文将深入探讨Go语言中结构体的定义、初始化、嵌套、方法以及与其他语言的对比&#…

Java八股文学习笔记day01

01.和equals区别 对于字符串变量来说,使用""和"equals"比较字符串时,其比较方法不同。""比较两个变量本身的值,即两个对象在内存中的首地址,"equals"比较字符串包含内容是否相同。 对于非…

VR司法法治教育平台,沉浸式课堂教学培养刑侦思维和能力

VR司法法治教育平台提供了多种沉浸式体验,通过虚拟现实(Virtual Reality,简称VR)技术让用户深度参与和体验法治知识。以下是一些常见的沉浸式体验: 1.罪案重现 VR司法法治教育平台可以通过重现真实案例的方式,让用户亲眼目睹罪案发…

基于天鹰算法优化的BP神经网络(预测应用) - 附代码

基于天鹰算法优化的BP神经网络(预测应用) - 附代码 文章目录 基于天鹰算法优化的BP神经网络(预测应用) - 附代码1.数据介绍2.天鹰优化BP神经网络2.1 BP神经网络参数设置2.2 天鹰算法应用 4.测试结果:5.Matlab代码 摘要…

Python 的画图函数 seaborn 简介

seaborn 简介 seanborn 是 Python 的另外一个常用工具包,它基于 matplotlib,但画出的图形更加美观些,并且与 Pandas 的数据类型结合地较好。 # Import seaborn import seaborn as sns import matplotlib.pyplot as plt# Apply the default …

选择一款可靠的仓库管理软件

在选择仓库管理软件时,数据的准确性、数据的安全性和软件的稳定性这三点是至关重要的。下面我们将详细讨论这些关键点。 第一,数据的准确性对于一个仓库管理软件来说是非常重要的。如果数据的准确性无法得到保证,那么这样的仓库入库出库管理…

华为OD七日集训第1期复盘 - 按算法分类,由易到难,循序渐进,玩转OD(文末送书)

目录 一、活动内容如下第1天、逻辑分析第2天、字符串处理第3天、数据结构第4天、双指针第5天、递归回溯第6天、二分查找第7天、贪心算法 && 二叉树 二、可观测性工程1、简介2、主要内容 大家好,我是哪吒。 最近一直在刷华为OD机试的算法题,坚持…

RT-Thread内核机制 线程栈

int flag;void cmp_val(int a,int b) {volatile int tmp[10];tmp[0] a;if(tmp[0] > b){flag 1;}else{flag 0;} }int main() {int a 1;int b ;cmp_val(a,b);return 0; }我们写好的程序会保存在Flash上。 其它类似汇编指令 SUB R0,R0,#4 R0 R0-4 B LR 放入LR寄存器 局…

Swift使用PythonKit调用Python

打开Xcode项目。然后选择“File→Add Packages”,然后输入软件包依赖链接: ​https://github.com/pvieito/PythonKit.git https://github.com/kewlbear/Python-iOS.git Python-iOS包允许在iOS应用程序中使用python模块。 用法: import Pyth…

超大屏画质奇迹 !TCL 115吋 QD-Mini LED 电视成豪宅顶配

有目共睹的是,现如今的电视屏幕越做越大。 尤其近年来,大屏电视的趋势愈发明显,原因无他,大屏电视所带来的震撼视觉体验,是其他电视所无法比拟的,这也正是电视品牌们不断突破新局限,往“大”了…

Python 画多个子图函数 subplot

子图函数 subplot 若要 pyplot 一次生成多个图形,一般要用到subplot函数,另外还有一个subplots函数。两个函数比较接近但略有区别,限于篇幅,我们只介绍 subplot函数,它的基本语法如下: ax plt.subplot(n…

【C语言基础】变量类型,Static关键字的使用

📢:如果你也对机器人、人工智能感兴趣,看来我们志同道合✨ 📢:不妨浏览一下我的博客主页【https://blog.csdn.net/weixin_51244852】 📢:文章若有幸对你有帮助,可点赞 👍…

【C/C++】#define宏替换高级用法

创作不易&#xff0c;本篇文章如果帮助到了你&#xff0c;还请点赞 关注支持一下♡>&#x16966;<)!! 主页专栏有更多知识&#xff0c;如有疑问欢迎大家指正讨论&#xff0c;共同进步&#xff01; &#x1f525;c系列专栏&#xff1a;C/C零基础到精通 &#x1f525; 给大…

载舟前行——2023跳槽涨薪,Android的1000道面试题

转眼没有口罩的一年&#xff0c;就来到下半年。比起之前几年今天愈发的艰难&#xff1b;今年的金九银十的来到&#xff0c;许多跳槽找工作的也来到了旺季。岗位的减少无疑造成的后果就是竞争大&#xff0c;所以面试优胜劣汰你需要在千百人中脱颖而出。 面试不容小觑&#xff0…

英文晨读记录(broken heart)

2023/8/28 mate 配偶;伙伴;朋友;(男人之间常 用)哥儿们&#xff0c;伙计&#xff0c;老兄;同伴;subject n. 主题;问题;学科;课程;科目;话题;题目;题材;表现对象; adj. 服从于;取决于;视…而定;易遭受…的;受…支配;可能受…影响的;受异族统治的 vt. 使臣服;使顺从;(尤指)压服 …

Go 面向对象(匿名字段)

概述 严格意义上说&#xff0c;GO语言中没有类(class)的概念,但是我们可以将结构体比作为类&#xff0c;因为在结构体中可以添加属性&#xff08;成员&#xff09;&#xff0c;方法&#xff08;函数&#xff09;。 面向对象编程的好处比较多&#xff0c;我们先来说一下“继承…

软件系统测试报告包括哪些内容?对软件产品起到什么作用?

软件系统测试报告是软件开发过程中非常重要的一环。它是一个详细记录了对系统进行测试的结果和总结的文档。通过系统测试报告&#xff0c;开发人员可以了解系统在测试过程中的表现&#xff0c;发现系统的问题和不足之处&#xff0c;从而采取相应的措施进行改进。 一、软件系统…

睿趣科技:抖音开小店大概多久可以做起来

随着移动互联网的快速发展&#xff0c;社交媒体平台成为了人们分享生活、交流信息的主要渠道之一。在众多社交平台中&#xff0c;抖音以其独特的短视频形式和强大的用户粘性受到了广泛关注。近年来&#xff0c;越来越多的人通过在抖音上开设小店来实现创业梦想&#xff0c;这种…