PyTorch深度学习实战(12)—— 神经网络工具箱nn.functional

news2024/9/20 16:50:27

1. nn.functional

torch.nn中还有一个很常用的模块:nn.functionaltorch.nn中的大多数layer,在functional中都有一个与之相对应的函数。nn.functional中的函数和nn.Module的主要区别在于:使用nn.Module实现的layer是一个特殊的类,其由class layer(nn.Module)定义,会自动提取可学习的参数;使用nn.functional实现的layer更像是纯函数,由def function(input)定义。

2. nn.functional与nn.Module的区别

下面举例说明functional的使用,并对比它与nn.Module的不同之处:

In: input = t.randn(2, 3)
    model = nn.Linear(3, 4)
    output1 = model(input)
    output2 = nn.functional.linear(input, model.weight, model.bias)
    output1.equal(output2)
 Out:True
 
 In: b1 = nn.functional.relu(input)
    b2 = nn.ReLU()(input)
    b1.equal(b2)
 Out:True

此时读者可能会问,应该什么时候使用nn.Module,什么时候使用nn.functional呢?答案很简单,如果模型具有可学习的参数,那么最好用nn.Module,否则既可以使用nn.functional,也可以使用nn.Module。二者在性能上没有太大差异,具体的选择取决于个人的喜好。由于激活函数(如ReLU、sigmoid、tanh)、池化(如MaxPool)等层没有可学习参数,可以使用对应的functional函数代替,对于卷积、全连接等具有可学习参数的层,建议使用nn.Module。另外,虽然dropout操作也没有可学习参数,但是建议使用nn.Dropout而不是nn.functional.dropout,因为dropout在训练和测试两个阶段的行为有所差异,使用nn.Module对象能够通过model.eval()操作加以区分。下面举例说明如何在模型中搭配使用nn.Modulenn.functional

In: from torch.nn import functional as F
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)
    
        def forward(self, x):
            x = F.pool(F.relu(self.conv1(x)), 2)
            x = F.pool(F.relu(self.conv2(x)), 2)
            x = x.view(-1, 16 * 5 * 5)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x


对于不具备可学习参数的层(如激活层、池化层等),可以将它们用函数代替,这样可以不用放置在构造函数__init__()中。对于有可学习参数的模块,也可以用functional代替,只不过实现起来较为烦琐,需要手动定义参数Parameter。例如,前面实现的全连接层,就可以将weight和bias两个参数单独拿出来,在构造函数中初始化为Parameter

In: class MyLinear(nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = nn.Parameter(t.randn(3, 4))
            self.bias = nn.Parameter(t.zeros(3))
        def forward(self):
            return F.linear(input, weight, bias)


关于nn.functional的设计初衷,以及它和nn.Module的比较说明,读者可参考PyTorch论坛的相关讨论和说明。

3. 采样函数

nn.functional中还有一个常用的函数:采样函数torch.nn.functional.grid_sample,它的主要作用是对输入的Tensor进行双线性采样,并将输出变换为用户想要的形状。下面以lena为例进行说明:

In: to_pil(lena.data.squeeze(0)) # 原始的lena数据

In: # lena的形状是1×1×200×200,(N,C,Hin,Win)
    # 进行仿射变换,对图像进行旋转
    angle = -90 * math.pi / 180
    theta = t.tensor([[math.cos(angle), math.sin(-angle), 0], \                     
                      [math.sin(angle), math.cos(angle), 0]], dtype=t.float)
    # grid形状为(N,Hout,Wout,2)
    # grid最后一个维度大小为2,表示输入中pixel的位置信息,取值范围在(-1,1)
    grid = F.affine_grid(theta.unsqueeze(0), lena.size())
 
 In: import torch
    from torch.nn import functional as F
    import warnings
    warnings.filterwarnings("ignore")
    
    out = F.grid_sample(lena, grid=grid, mode='bilinear')
    to_pil(out.data.squeeze(0))


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2039736.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【人工智能】Transformers之Pipeline(九):物体检测(object-detection)

目录​​​​​​​ 一、引言 二、物体检测(object-detection) 2.1 概述 2.2 技术原理 2.3 应用场景 2.4 pipeline参数 2.4.1 pipeline对象实例化参数 2.4.2 pipeline对象使用参数 2.4 pipeline实战 2.5 模型排名 三、总结 一、引言 pipel…

黑马头条vue2.0项目实战(八)——文章评论

目录 1. 展示文章评论列表 1.1 准备组件 1.2 获取文章评论数据并展示 1.3 展示文章评论总数量 1.4 文章评论项 2. 评论点赞 3. 发布文章评论 3.1 准备弹出层 3.2 封装发布文章评论组件 3.3 请求发布 4. 评论回复 4.1 准备回复弹层 4.2 封装内容组件 4.3 处理头部…

【深度学习】创建和训练Transformer神经网络模型,将葡萄牙语翻译成英语

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言1. 安装2. 数据处理2.1 下载数据集2.2 设置标记器2.3 使用tf.data设置数据管道 3. 测试数据集4. 定义组件4.1 嵌入和位置编码层4.2 添加并规范化4.3 基础注意力…

Android 12系统源码_多屏幕(二)模拟辅助设备功能开关实现原理

前言 上一篇我们通过为Android系统开启模拟辅助设备功能开关,最终实现了将一个Activity显示到多个屏幕的效果。 本篇文章我们具体来分析一下当我们开启模拟辅助设备功能开关的时候,Android系统做了什么哪些操作。 一、模拟辅助设备功能开关应用位置 …

Qt5编译qmqtt库使用MQTT协议连接华为云IOT完成数据上传与交互

一、前言 随着物联网技术的发展,越来越多的设备通过网络互相连接,形成了庞大的智能系统。这些系统能够收集、分析并响应各种数据,从而实现自动化控制和智能化管理。在这个背景下,MQTT 成为了一个广泛使用的轻量级消息传输协议,特别适用于资源受限的环境,如移动应用或远程…

WebSocket 实现:注解与原生方式对比

WebSocket 作为一种在单个长连接上进行全双工、双向通信的协议,已经成为现代Web应用中实现实时通信的重要技术。本文将探讨如何使用注解和原生方式来实现 WebSocket,并对这两种方法进行比较。 一、注解方式实现 WebSocket 在许多现代Java框架中&#x…

GBJ406-ASEMI无人机专用GBJ406

编辑:ll GBJ406-ASEMI无人机专用GBJ406 型号:GBJ406 品牌:ASEMI 封装:GBJ-4 批号:2024 现货:50000 最大重复峰值反向电压:600V 最大正向平均整流电流(Vdss):4A 功率(Pd)&am…

43.【C语言】指针(重难点)(F)

目录 15.二级指针 *定义 *演示 16.三级以及多级指针 *三级指针的定义 *多级指针的定义 17.指针数组 *定义 *代码 18.指针数组模拟二维数组 往期推荐 15.二级指针 *定义 之前讲的指针全是一级指针 int a 1; int *pa &a;//一级指针 如果写成 int a 1; int *pa &a…

MES生产执行系统源码,支持 SaaS 多租户,技术架构:springboot + vue-element-plus-admin

MES的定义与功能 MES是制造业中一种重要的管理信息系统,用于协调和监控整个生产过程。它通过收集、分析和处理各种生产数据,实现对生产流程的实时跟踪和监控,并为决策者提供准确的数据支持。MES涵盖了工厂运营、计划排程、质量管理、设备维护…

AI时代下的智慧体育, 用科技赋能体育创新

在科技飞速发展的今天,人工智能(AI)已成为推动各行各业创新的重要力量。体育,作为人类文明的重要组成部分,同样在AI的浪潮中迎来了新的变革机遇。AI时代下的智慧体育,不再局限于传统的运动模式,…

Spring Boot集成Devtools实现热更新?

1.什么Devtools? DevTools是开发者工具集,主要用于简化开发过程中的热部署问题。 热部署是指在开发过程中,当代码发生变化时,无需手动重启应用,系统能够自动检测并重新加载修改后的代码,大大提高了开发效率…

量化投资策略与技术学习PART2:量化选股之风格轮动

市场上的投资者是有偏好的,有时候偏好于价值股,有时候偏好于成长股,有时偏于大盘,有时又偏于小盘,由于投资者的这种不同的交易行为,形成了市场风格,本节主要研究如何判断市场风格,以…

MyBatis介绍(1)

前言 MyBatis 是一个半 ORM(对象关系映射)框架,它内部封装了 JDBC,开发时只需要关注 SQL 语句本身,不需要花费精力去处理加载驱动、创建连接、创建 statement 等繁杂的过程。程序员直接编写原生态 sql,可以…

【java报错已解决】error: metadata-generation-failed

🎬 鸽芷咕:个人主页 🔥 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想,就是为了理想的生活! 文章目录 一、问题描述1.1 报错示例1.2 报错分析1.3 解决思路二、解决方法2.1 方法一:检查环境变量2.2 步骤二&…

嵌入式学习Day30---Linux软件编程---进程间的通信

目录 一、Linux操作ipc对象(内存文件)的命令 1.1.查看命令 1.ipcs 2.ipcs -q(查看信息队列) 3.ipcs -m(查看共享内存) 4.ipcs -s(查看信号灯) 1.2.删除命令 1.ipcrm -q id 2.ipc…

conda虚拟环境中pip的混淆

在conda的虚拟环境中&#xff0c;会在<PATH>\Anaconda\envs\<ENV_NAME>\Scripts目录下存在 pip.exe 和pip3.exe. 如果存在多个虚拟环境是&#xff0c;加上conda自带的python版本&#xff0c;系统中存在多个pip和pip3指令&#xff0c;在执行安装的时候&#xff0c;…

【AI 绘画】 文生图图生图(基于diffusers)

AI 绘画- 文生图&图生图&#xff08;基于diffusers&#xff09; 1. 效果展示 本次测试主要结果展示如下&#xff1a; SDXL文生图 可爱Lora 2. 基本原理 模型基本原理介绍如下 stable diffusion首先训练一个自编码器&#xff0c;学习将图像数据压缩为低维表示。通过使…

VINS-Fusion的点云转换成ego-planner能用的点云

背景 2013年智在飞翔比赛&#xff1a; RoboMaster | 无人飞行器智能感知技术竞赛https://www.robomaster.com/zh-CN/robo/drone?djifromnav_drone 用vins-fusion来定位&#xff0c;他自己会生成点云数据。 进一步用ego-planner来路径规划和避障&#xff0c;需要用到vins-f…

mpls静态lsp实验

实验需求 R1、R2和R3之间已经部署了IGP协议&#xff0c;故192.168.10.0/24与192.168.20.0/24网络之间已经能够互访。现要求通过配置 静态LSP&#xff0c;使得这两个网络之间能基于MPLS进行互访&#xff0c;标签分配如图 组网图 实验思路 1、R1、R2和R3之间已经部署了IGP协议…

非科班出身的你,如何转行AI算法工程师?

想从其他行业转行到算法工程师的人&#xff0c;无外乎以下几个原因&#xff1a; 现在工资太低工作没有前景对现在的工作没有热情对算法工程师很感兴趣 那么&#xff0c;如何成功转行&#xff1f;给大家整理一些学习方式。 1&#xff09;数据结构和算法&#xff1a;推荐大家使…