PyTorch之nn.Module与nn.functional用法区别

news2024/10/7 10:17:00

文章目录

  • 1. nn.Module
  • 2. nn.functional
    • 2.1 基本用法
    • 2.2 常用函数
  • 3. nn.Module 与 nn.functional
    • 3.1 主要区别
    • 3.2 具体样例:nn.ReLU() 与 F.relu()
  • 参考资料

1. nn.Module

在PyTorch中,nn.Module 类扮演着核心角色,它是构建任何自定义神经网络层、复杂模块或完整神经网络架构的基础构建块。通过继承 nn.Module 并在其子类中定义模型结构和前向传播逻辑(forward() 方法),开发者能够方便地搭建并训练深度学习模型。

关于 nn.Module 的更多介绍可以参考博客:PyTorch之nn.Module、nn.Sequential、nn.ModuleList使用详解

这里,我们基于nn.Module创建一个简单的神经网络模型,实现代码如下:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MyModel, self).__init__()
        self.layer1 = nn.Linear(input_size, hidden_size)
        self.layer2 = nn.Linear(hidden_size, output_size)
        
    def forward(self, x):
        x = torch.relu(self.layer1(x))
        x = self.layer2(x)
        return x

2. nn.functional

nn.functional 是PyTorch中一个重要的模块,它包含了许多用于构建神经网络的函数。与 nn.Module 不同,nn.functional 中的函数不具有可学习的参数。这些函数通常用于执行各种非线性操作、损失函数、激活函数等。

2.1 基本用法

如何在神经网络中使用nn.functional?

在PyTorch中,你可以轻松地在神经网络中使用 nn.functional 函数。通常,你只需将输入数据传递给这些函数,并将它们作为网络的一部分。

以下是一个简单的示例,演示如何在一个全连接神经网络中使用ReLU激活函数:

import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(64, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

在上述示例中,我们首先导入nn.functional 模块,然后在网络的forward 方法中使用F.relu 函数作为激活函数。

nn.functional 的主要优势是它的计算效率和灵活性,因为它允许你以函数的方式直接调用这些操作,而不需要创建额外的层。

2.2 常用函数

(1)激活函数

激活函数是神经网络中的关键组件,它们引入非线性性,使网络能够拟合复杂的数据。以下是一些常见的激活函数:

  • ReLU(Rectified Linear Unit)
    ReLU是一种简单而有效的激活函数,它将输入值小于零的部分设为零,大于零的部分保持不变。它的数学表达式如下:
output = F.relu(input)
  • Sigmoid
    Sigmoid函数将输入值映射到0和1之间,常用于二分类问题的输出层。它的数学表达式如下:
output = F.sigmoid(input)
  • Tanh(双曲正切)
    Tanh函数将输入值映射到-1和1之间,它具有零中心化的特性,通常在循环神经网络中使用。它的数学表达式如下:
output = F.tanh(input)

(2)损失函数

损失函数用于度量模型的预测与真实标签之间的差距。PyTorch的nn.functional 模块包含了各种常用的损失函数,例如:

  • 交叉熵损失(Cross-Entropy Loss)
    交叉熵损失通常用于多分类问题,计算模型的预测分布与真实分布之间的差异。它的数学表达式如下:
loss = F.cross_entropy(input, target)
  • 均方误差损失(Mean Squared Error Loss)
    均方误差损失通常用于回归问题,度量模型的预测值与真实值之间的平方差。它的数学表达式如下:
loss = F.mse_loss(input, target)
  • L1 损失
    L1损失度量预测值与真实值之间的绝对差距,通常用于稀疏性正则化。它的数学表达式如下:
loss = F.l1_loss(input, target)

(3)非线性操作

nn.functional 模块还包含了许多非线性操作,如池化、归一化等。

  • 最大池化(Max Pooling)
    最大池化是一种用于减小特征图尺寸的操作,通常用于卷积神经网络中。它的数学表达式如下:
output = F.max_pool2d(input, kernel_size)
  • 批量归一化(Batch Normalization)
    批量归一化是一种用于提高训练稳定性和加速收敛的技术。它的数学表达式如下:
output = F.batch_norm(input, mean, std, weight, bias)

3. nn.Module 与 nn.functional

3.1 主要区别

nn.Module 与 nn.functional 的主要区别在于:

  • nn.Module实现的layers是一个特殊的类,都是由class Layer(nn.Module)定义,会自动提取可学习的参数;
  • nn.functional中的函数更像是纯函数,由def function(input)定义。

注意:

  1. 如果模型有可学习的参数时,最好使用nn.Module。
  2. 激活函数(ReLU、sigmoid、Tanh)、池化(MaxPool)等层没有可学习的参数,可以使用对应的functional函数。
  3. 卷积、全连接等有可学习参数的网络建议使用nn.Module。
  4. dropout没有可学习参数,但建议使用nn.Dropout而不是nn.functional.dropout。

3.2 具体样例:nn.ReLU() 与 F.relu()

nn.ReLU() :

import torch.nn as nn
'''
nn.ReLU()

F.relu():

import torch.nn.functional as F
'''
out = F.relu(input)

其实这两种方法都是使用relu激活,只是使用的场景不一样,F.relu()是函数调用,一般使用在foreward函数里。而nn.ReLU()是模块调用,一般在定义网络层的时候使用。

当用print(net)输出时,nn.ReLU()会有对应的层,而F.ReLU()是没有输出的。

import torch.nn as nn
import torch.nn.functional as F

class NET1(nn.Module):
    def __init__(self):
        super(NET1, self).__init__()
        self.conv = nn.Conv2d(3, 16, 3, 1, 1)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()  # 模块的激活函数

    def forward(self, x):
        out = self.conv(x)
        x = self.bn(x)
        out = self.relu()
        return out


class NET2(nn.Module):
    def __init__(self):
        super(NET2, self).__init__()
        self.conv = nn.Conv2d(3, 16, 3, 1, 1)
        self.bn = nn.BatchNorm2d(16)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        out = F.relu(x)  # 函数的激活函数
        return out


net1 = NET1()
net2 = NET2()
print(net1)
print(net2)

在这里插入图片描述

参考资料

  • PyTorch的nn.Module类的详细介绍
  • PyTorch nn.functional 模块详解:探索神经网络的魔法工具箱
  • pytorch:F.relu() 与 nn.ReLU() 的区别

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1889517.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

这次发现的开源版本我愿意称之为最具学习价值的商城系统|商城源码点击进入

这是一款我发现的强大、灵活、易用的商城系统,成为我的的首选商城框架,让我的商城开发事半功倍!这款开源商城项目具有多元的商业模式满足了任何使用场景的需求。 有S2B2C供应链商城、B2B2C多商户商城、O2O外卖商城、B2C单商户商城、社区团购、…

全网最详细,零基础学会AI绘画Stable Diffusion,学不会来打我!

前言 什么是Stable Diffusion 自从去年10月份Stable Diffusion开源以来,仅过了半年的时间,如今它已经能够创作出精美细致的二次元插画,媲美真人的赛博Coser,以及具有独特风格的AI动画。 无论你只是感兴趣,还是想了解…

旅游管理系统-计算机毕业设计源码16021

摘 要 本文旨在设计和实现一个基于Spring Boot框架的旅游管理系统。该系统通过利用Spring Boot的快速开发特性和丰富的生态系统,提供了一个高效、可靠和灵活的解决方案。系统将实现旅游景点信息的管理、线路规划、跟团游玩、旅游攻略、酒店信息管理、订单管理和用户…

有哪些手持小风扇品牌推荐?五大手持小风扇诚意推荐!

在炎炎夏日,一款便携且高效的手持小风扇无疑是消暑的必备神器。为了帮助大家轻松应对酷暑,我们精心挑选了五大手持小风扇品牌进行诚意推荐。这些品牌不仅拥有出色的降温效果,更在外观设计、便携性、续航能力及操作便捷性上表现卓越。接下来&a…

第三方软件测试公司分享:软件渗透测试的测试内容和注意事项

软件渗透测试是一种通过模拟攻击的方式来评估软件系统的安全性和漏洞,以发现并修复系统中的安全弱点。保护用户的数据和信息不被恶意攻击者利用,也是软件产品开发流程中重要的环节,可以帮助开发团队完善产品质量,提高用户满意度。…

代码随想录-二叉搜索树①

目录 二叉搜索树的定义 700. 二叉搜索树中的搜索 题目描述: 输入输出示例: 思路和想法: 98. 验证二叉搜索树 题目描述: 输入输出示例: 思路和想法: 530. 二叉搜索树的最小绝对差 题目描述&#x…

03:Spring MVC

文章目录 一:Spring MVC简介1:说说自己对于Spring MVC的了解?1.1:流程说明: 一:Spring MVC简介 Spring MVC就是一个MVC框架,Spring MVC annotation式的开发比Struts2方便,可以直接代…

c/c++语言的一种日志的编写办法

今日分享一下,从某源码中看到这种日志编写方式,很强。可以借鉴。 这个函数调用的日志函数是不一样的,仔细观看: 这几种日志输出函数,背后都调用了相同的调用。 与之对应的区别就是,函数名称的差异取决于…

【云原生监控】Prometheus 普罗米修斯从搭建到使用详解

目录 一、前言 二、服务监控概述 2.1 什么是微服务监控 2.2 微服务监控指标 2.3 微服务监控工具 三、Prometheus概述 3.1 Prometheus是什么 3.2 Prometheus 特点 3.3 Prometheus 架构图 3.3.1 Prometheus核心组件 3.3.2 Prometheus 工作流程 3.4 Prometheus 应用场景…

【Python基础篇】一篇文章入门Python,进入Python的世界

文章目录 0.前言1.打印(Hello,World)2.创建变量3.打印升级3.1 打印一句话中间加变量3.2 sep设置分隔符3.3 end和换行 4. 注释 0.前言 大家好,我是小辰,前几天做了个重大的决定,学习python。 首先&#xff0…

wine烧录stm32教程

前言 使用环境 ubuntu22.04 因为stlnk的线太短了并且容易扯断开,想使用串口进行烧录,但是又不想每次烧录代码都拔下短接帽(暂时不知道stm32flash怎么支持ISP一键下载),故写下此教程步骤一:安装wine 首选我们要下载wine,由于国内下…

跨国企业与IP地址定位的商业策略

随着经济全球化的发展,许多企业都选择拓宽国际市场,而跨国企业需要在全球范围内进行高效的市场运营和管理,以应对不同市场的需求和竞争。IP地址定位技术能够通过识别用户的地理位置,为企业提供重要的数据支持,帮助他们…

记录搭建一台可域名访问的HTTPS服务器

一、背景 近期公司业务涉及到微信小程序,即将开发完成需要按照微信小程序平台的要求提供带证书的域名请求服务器。 资源背景介绍如下: 1、域名 公司已有一个二级域名,再次申请新的二级域名并且实现ICP备案不仅需要花重金重新购买,…

深入浅出:进程管理的艺术

目录 进程的定义 进程的特征 进程的状态 进程与程序的区别 进程的控制和管理 进程的特点 1. 虚拟内存空间的分配 2. 时间片轮转调度 图解: 进程段 数据段(Data Segment) 正文段(Text Segment) 堆栈段&…

十二、【源码】Spring整合AOP

源码地址:https://github.com/spring-projects/spring-framework 仓库地址:https://gitcode.net/qq_42665745/spring/-/tree/12-spring-aop Spring整合AOP 核心类: DefaultAdvisorAutoProxyCreator:用于在Spring框架中自动为符…

华为交换机基本命令配置(创建vlan、配置telnet登录)

<HUAWEI>system-view 进入系统视图 [HUAWEI]sysname SW1 交换机命名为SW1 [SW1]undo info-center enable 关闭消息中心 [SW1]quit 退出当前视图 <SW1>display vlan 查看vlan详情 <SW1>system-view 进入系统视图 [SW1]vlan 5 …

帮找Java Bug,面试,项目,解决Java问题

本人是个Java老程序员&#xff0c;如果你有解决不了的问题&#xff0c;或者面试的时候需要人帮助&#xff0c;或者求职就业上任何问题都可以联系我&#xff0c;下面是我微信&#xff0c;欢迎联系我&#xff01;

慧哥Saas充电桩开源平台 V2.5.5

文章目录 原地址&#xff1a;https://gitee.com/chouleng/cdzkjjh&#xff0c;更换新的地址如下 [点击此链接 https://gitee.com/chouleng/huili-cloud](https://gitee.com/chouleng/huili-cloud)一、产品功能部分截图1.手机端&#xff08;小程序、安卓、ios&#xff09;2.PC端…

豪车视频改字,节假日祝福视频改字小程序制作搭建开发

目录 前言&#xff1a; 一、视频改字小程序功能介绍 二、怎么对短视频模板进行改字&#xff1f; 三、这个短视频改字的项目怎么样&#xff1f; 总结&#xff1a; 前言&#xff1a; 现在很多豪车改字的短视频&#xff0c;节假日祝福的小视频&#xff0c;有不少直播在弄这个…

【Java学习笔记】java图形界面编程

在前面的章节中&#xff0c;我们开发运行的应用程序都没有图形界面&#xff0c;但是很多应用软件&#xff0c;如Windows下的Office办公软件、扑克牌接龙游戏软件、企业进销存ERP系统等&#xff0c;都有很漂亮的图形界面。素以需要我们开发具有图形界面的软件。 Java图形界面编程…