PyTorch使用教程(6)一文讲清楚torch.nn和torch.nn.functional的区别

news2025/1/18 12:29:49

torch.nn torch.nn.functional 在 PyTorch 中都是用于构建神经网络的重要组件,但它们在设计理念、使用方式和功能上存在一些显著的区别。以下是关于这两个模块的详细区别:

1. 继承方式与结构

torch.nn

  • torch.nn 中的模块大多数是通过继承 torch.nn.Module 类来实现的。这些模块都是 Python 类,包含了神经网络的各种层(如卷积层、全连接层等)和其他组件(如损失函数、优化器等)。
  • torch.nn 中的模块可以包含可训练参数,如权重和偏置,这些参数在训练过程中会被优化。

torch.nn.functional

  • torch.nn.functional 中的函数是直接调用的,无需实例化。这些函数通常用于执行各种非线性操作、损失函数计算、激活函数应用等。
  • torch.nn.functional 中的函数没有可训练参数,它们只是执行操作并返回结果。

2. 实现方式与调用方式

torch.nn

  • torch.nn 中的模块是基于面向对象的方法实现的。开发者需要创建类的实例,并在类的 forward 方法中定义数据的前向传播路径。
  • torch.nn 中的模块通常需要先创建模型实例,再将输入数据传入模型中进行前向计算。

torch.nn.functional

  • torch.nn.functional 中的函数是基于函数式编程实现的。它们提供了灵活的接口,允许开发者以函数调用的方式轻松定制和扩展神经网络架构。
  • torch.nn.functional 中的函数可以直接调用,只需要将输入数据传入函数中即可进行前向计算。

3. 使用场景与优势

torch.nn

  • torch.nn 更适合用于定义有状态的模块,如包含可训练参数的层。
  • 当定义具有变量参数的层时(如卷积层、全连接层等),torch.nn 会帮助初始化好变量,并且模型类本身就是 nn.Module 的实例,看起来会更加协调统一。
  • torch.nn 可以结合 nn.Sequential 来简化模型的构建过程。

torch.nn.functional

  • torch.nn.functional 中的函数相比 torch.nn 更偏底层,封装性不高但透明度很高。开发者可以在其基础上定义出自己想要的功能。
  • 使用 torch.nn.functional 可以更方便地进行函数组合、复用等操作,适合那些喜欢使用函数式编程风格的开发者。当激活函数只需要在前向传播中使用时,使用 torch.nn.functional 中的激活函数会更加简洁。

4. 权重与参数管理

torch.nn

  • torch.nn 中的模块会自动管理权重和偏置等参数,这些参数可以通过 model.parameters() 方法获取,并用于优化算法的训练。

torch.nn.functional

  • torch.nn.functional 中的函数不直接管理权重和偏置等参数。如果需要使用这些参数,开发者需要在函数外部定义并初始化它们,然后将它们作为参数传入函数中。

5.举例说明

例子1:定义卷积层

使用 torch.nn

import torch.nn as nn

class MyConvNet(nn.Module):
    def __init__(self):
        super(MyConvNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        return x

# 实例化模型
model = MyConvNet()

# 传入输入数据
input_tensor = torch.randn(1, 1, 32, 32)
output_tensor = model(input_tensor)

使用 torch.nn.functional

import torch.nn.functional as F

def my_conv_net(input_tensor, weight, bias=None):
    output_tensor = F.conv2d(input_tensor, weight, bias=bias, stride=1, padding=1)
    return output_tensor

# 定义卷积核的权重和偏置
weight = nn.Parameter(torch.randn(16, 1, 3, 3))
bias = nn.Parameter(torch.randn(16))

# 传入输入数据
input_tensor = torch.randn(1, 1, 32, 32)
output_tensor = my_conv_net(input_tensor, weight, bias)

在这个例子中,使用 torch.nn 定义了一个包含卷积层的模型类,而使用 torch.nn.functional 则是通过函数直接进行卷积操作。注意在使用 torch.nn.functional 时,需要手动定义和传递卷积核的权重和偏置。

例子2:应用激活函数

使用 torch.nn

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(x)
        return x

# 实例化模型
model = MyModel()

# 传入输入数据
input_tensor = torch.randn(1, 10)
output_tensor = model(input_tensor)

使用 torch.nn.functional

import torch.nn.functional as F

def my_model(input_tensor):
    output_tensor = F.relu(input_tensor)
    return output_tensor

# 传入输入数据
input_tensor = torch.randn(1, 10)
output_tensor = my_model(input_tensor)

在这个例子中,使用 torch.nn 定义了一个包含 ReLU 激活函数的模型类,而使用 torch.nn.functional 则是通过函数直接应用 ReLU 激活函数。

例子3:定义和计算损失

使用 torch.nn

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        x = self.linear(x)
        return x

# 实例化模型
model = MyModel()

# 定义损失函数
criterion = nn.CrossEntropyLoss()

# 传入输入数据和标签
input_tensor = torch.randn(1, 10)
target = torch.tensor()

# 前向传播和计算损失
output_tensor = model(input_tensor)
loss = criterion(output_tensor, target)

使用 torch.nn.functional

import torch.nn.functional as F

def my_model(input_tensor):
    output_tensor = torch.matmul(input_tensor, weight.t()) + bias
    return output_tensor

# 定义权重和偏置
weight = nn.Parameter(torch.randn(10, 2))
bias = nn.Parameter(torch.randn(2))

# 定义损失函数
criterion = nn.CrossEntropyLoss()

# 传入输入数据和标签
input_tensor = torch.randn(1, 10)
target = torch.tensor()

# 前向传播和计算损失
output_tensor = my_model(input_tensor)
loss = criterion(output_tensor, target)

在这个例子中,使用 torch.nn 定义了一个包含全连接层的模型类,并使用了 torch.nn 中的损失函数来计算损失。而使用 torch.nn.functional 则是通过函数直接进行线性变换,并使用 torch.nn 中的损失函数来计算损失。注意在使用 torch.nn.functional 时,需要手动定义和传递权重和偏置。

6. 小结

torch.nn 和 torch.nn.functional 在定义神经网络组件、应用激活函数和计算损失等方面存在显著的区别。torch.nn 提供了一种面向对象的方式来构建模型,而 torch.nn.functional 则提供了一种更灵活、更函数式的方式来执行相同的操作。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2278462.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

传统以太网问题与VLAN技术详解

传统以太网的问题 广播域:在网络中能接收同一广播信息的所有设备(计算机、交换机)等的集合 说明:在一个广播域内,当一个设备发送广播帧时,该域内的所有设备都能接收到这个广播帧。工作原理:在以…

OpenAI Whisper:语音识别技术的革新者—深入架构与参数

当下语音识别技术正以前所未有的速度发展,极大地推动了人机交互的便利性和效率。OpenAI的Whisper系统无疑是这一领域的佼佼者,它凭借其卓越的性能、广泛的适用性和创新的技术架构,正在重新定义语音转文本技术的规则。今天我们一起了解一下Whi…

WPS计算机二级•高效操作技巧

听说这里是目录哦 斜线表头 展示项目名称🍋‍🟩横排转竖排🍐批量删除表格空白行🍈方法一方法二建辅助列找空值 能量站😚 斜线表头 展示项目名称🍋‍🟩 选中单元格,单击右键➡️“设…

RabbitMQ实现延迟消息发送——实战篇

在项目中,我们经常需要使用消息队列来实现延迟任务,本篇文章就向各位介绍使用RabbitMQ如何实现延迟消息发送,由于是实战篇,所以不会讲太多理论的知识,还不太理解的可以先看看MQ的延迟消息的一个实现原理再来看这篇文章…

《Keras 3 在 TPU 上的肺炎分类》

Keras 3 在 TPU 上的肺炎分类 作者:Amy MiHyun Jang创建日期:2020/07/28最后修改时间:2024/02/12描述:TPU 上的医学图像分类。 (i) 此示例使用 Keras 3 在 Colab 中查看 GitHub 源 简介 设置 本教程将介…

1.17组会汇报

STRUC-BENCH: Are Large Language Models Good at Generating Complex Structured Tabular Data? STRUC-BENCH:大型语言模型擅长生成复杂的结构化表格数据吗?23年arXiv.org 1概括 这篇论文旨在评估大型语言模型(LLMs)在生成结构…

PyTorch使用教程(2)-torch包

1、简介 torch包是PyTorch框架最外层的包,主要是包含了张量的创建和基本操作、随机数生成器、序列化、局部梯度操作的上下文管理器等等,内容很多。我们基础学习的时候,只有关注张量的创建、序列化,随机数、张量的数学数学计算等常…

idea gradle compiler error: package xxx does not exist

idea 编译运行task时报项目内的包不存在,如果你试了网上的其它方法还不能解决,应该是你更新了新版idea,项目用的是旧版jdk,请在以下编译器设置中把项目JDK字节码版本设为8(jdk1.8,我这里是17请自行选择&…

Nmap之企业漏洞扫描(Enterprise Vulnerability Scanning for Nmap)

简介 Namp是一个开源的网络连接端扫描软件,主要用于网络发现和安全审核。‌它可以帮助用户识别网络上的设备、分析它们的服务、检测操作系统类型,甚至发现潜在的安全漏洞。Nmap由Fyodor开发,最初是为了满足网络管理员的需求,但随…

RabbitMQ前置概念

文章目录 1.AMQP协议是什么?2.rabbitmq端口介绍3.消息队列的作用和使用场景4.rabbitmq工作原理5.整体架构核心概念6.使用7.消费者消息推送限制(work模型)8.fanout交换机9.Direct交换机10.Topic交换机(推荐)11.声明队列…

RabbitMQ---TTL与死信

(一)TTL 1.TTL概念 TTL又叫过期时间 RabbitMQ可以对队列和消息设置TTL,当消息到达过期时间还没有被消费时就会自动删除 注:这里我们说的对队列设置TTL,是对队列上的消息设置TTL并不是对队列本身,不是说队列过期时间…

MySQL8数据库全攻略:版本特性、下载、安装、卸载与管理工具详解

大家好,我是袁庭新。 MySQL作为企业项目中的主流数据库,其5.x和8.x版本尤为常用。本文将详细介绍MySQL 8.x的特性、下载、安装、服务管理、卸载及管理工具,旨在帮助用户更好地掌握和使用MySQL数据库。 1.MySQL版本及下载 企业项目中使用的…

хорошо哈拉少wordpress俄语主题

хорошо哈拉少wordpress俄语主题 wordpress俄文网站模板,推荐做俄罗斯市场的外贸公司建俄语独立站使用。 演示 https://www.jianzhanpress.com/?p7360

【STM32-学习笔记-10-】BKP备份寄存器+时间戳

文章目录 BKP备份寄存器Ⅰ、BKP简介1. BKP的基本功能2. BKP的存储容量3. BKP的访问和操作4. BKP的应用场景5. BKP的控制寄存器 Ⅱ、BKP基本结构Ⅲ、BKP函数Ⅳ、BKP使用示例 时间戳一、Unix时间戳二、时间戳的转换(time.h函数介绍)Ⅰ、time()Ⅱ、mktime()…

Python毕业设计选题:基于python的酒店推荐系统_django+hadoop

开发语言:Python框架:djangoPython版本:python3.7.7数据库:mysql 5.7数据库工具:Navicat11开发软件:PyCharm 系统展示 管理员登录 管理员功能界面 用户管理 酒店客房管理 客房类型管理 客房预定管理 用户…

【c++继承篇】--继承之道:在C++的世界中编织血脉与传承

目录 引言 一、定义二、继承定义格式2.1定义格式2.2继承关系和访问限定符2.3继承后子类访问权限 三、基类和派生类赋值转换四、继承的作用域4.1同名变量4.2同名函数 五、派生类的默认成员构造函数5.1**构造函数调用顺序:**5.2**析构函数调用顺序:**5.3调…

Elasticsearch:Jira 连接器教程第二部分 - 6 个优化技巧

作者:来自 Elastic Gustavo Llermaly 将 Jira 连接到 Elasticsearch 后,我们现在将回顾最佳实践以升级此部署。 在本系列的第一部分中,我们配置了 Jira 连接器并将对象索引到 Elasticsearch 中。在第二部分中,我们将回顾一些最佳实…

【狂热算法篇】探秘图论之 Floyd 算法:解锁最短路径的神秘密码(通俗易懂版)

: 羑悻的小杀马特.-CSDN博客羑悻的小杀马特.擅长C/C题海汇总,AI学习,c的不归之路,等方面的知识,羑悻的小杀马特.关注算法,c,c语言,青少年编程领域.https://blog.csdn.net/2401_82648291?spm1010.2135.3001.5343 在本篇文章中,博主将带大家去学习所谓的…

npm的包管理

从哪里下载包 国外有一家 IT 公司,叫做 npm,Inc.这家公司旗下有一个非常著名的网站: https://www.npmjs.com/,它是全球最大的包共享平台,你可以从这个网站上搜索到任何你需要的包,只要你有足够的耐心!到目前位置,全球约…

GitLab:添加SSH密钥之前,您不能通过SSH来拉取或推送项目代码

1、查看服务器是否配置过 [rootkingbal-ecs-7612 ~]# cd .ssh/ [rootkingbal-ecs-7612 .ssh]# ls authorized_keys id_ed25519 id_ed25519.pub id_rsa id_rsa.pub2、创建密钥 $ ssh-keygen -t rsa -C kingbalkingbal.com # -C 后写你的邮箱 一路回车 3、复制密钥 [rootk…