深入理解 `torch.nn.Linear`:维度变换的过程详解与实践(附图、公式、代码)

news2024/11/17 11:21:58

在深度学习中,线性变换是最基础的操作之一。PyTorch 提供了 torch.nn.Linear 模块,用来实现全连接层(Fully Connected Layer)。在使用时,理解维度如何从输入映射到输出,并掌握其具体的变换过程,是至关重要的。本文将从线性变换的原理出发,结合图示、公式和代码,详细解析 torch.nn.Linear 的维度变化过程,帮助你深入理解这个关键模块。


1. 什么是 torch.nn.Linear

torch.nn.Linear 是 PyTorch 提供的一个线性变换模块,通常用于神经网络中的全连接层。在一个全连接层中,输入向量通过权重矩阵和偏置项进行线性变换,从而得到输出向量。其数学公式为:

[ \mathbf{y} = \mathbf{W} \mathbf{x} + \mathbf{b} ]

其中:

  • ( \mathbf{x} ) 是输入向量,
  • ( \mathbf{W} ) 是权重矩阵,
  • ( \mathbf{b} ) 是偏置向量,
  • ( \mathbf{y} ) 是输出向量。

torch.nn.Linear 将输入的特征维度映射到输出的特征维度,常用于神经网络的最后一层或者中间层的线性计算。


2. torch.nn.Linear 的维度定义

在创建 torch.nn.Linear 实例时,我们需要定义两个重要参数:

  • in_features: 输入的特征数量,即输入向量的维度。
  • out_features: 输出的特征数量,即输出向量的维度。
import torch
import torch.nn as nn

# 创建线性变换层:从 4 维输入映射到 2 维输出
linear_layer = nn.Linear(in_features=4, out_features=2)

在上述代码中,in_features=4out_features=2 表示输入是 4 维的,输出将被线性变换为 2 维。


3. 线性变换过程中的维度变化

为了更好地理解维度的变化,我们可以通过一个具体的例子来说明。假设我们有一个形状为 (batch_size, in_features) 的输入张量,其维度为 batch_size = 3in_features = 4

  1. 输入维度:假设输入张量 x 的维度是 (3, 4),即有 3 个样本,每个样本有 4 个特征。
  2. 权重矩阵的维度:权重矩阵 W 的维度是 (out_features, in_features),即 (2, 4),表示它将 4 维的输入映射到 2 维的输出。
  3. 偏置向量的维度:偏置向量 b 的维度是 (out_features),即 (2)

在执行 y = W * x + b 之后,输出张量 y 的维度将变为 (batch_size, out_features),即 (3, 2)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

4. 公式推导

线性变换的数学公式为:

[ \mathbf{y}_i = \mathbf{W} \mathbf{x}_i + \mathbf{b} ]

其中:

  • ( \mathbf{x}_i ) 是输入的第 ( i ) 个样本,形状为 (in_features)
  • ( \mathbf{W} ) 是权重矩阵,形状为 (out_features, in_features)
  • ( \mathbf{b} ) 是偏置向量,形状为 (out_features)
  • ( \mathbf{y}_i ) 是输出,形状为 (out_features)

5. 实际代码示例

我们通过具体的代码来验证上述维度变换过程。

import torch
import torch.nn as nn

# 创建线性层,将输入 4 维映射为 2 维
linear_layer = nn.Linear(in_features=4, out_features=2)

# 打印权重和偏置的形状
print("权重矩阵的形状:", linear_layer.weight.shape)  # (2, 4)
print("偏置向量的形状:", linear_layer.bias.shape)    # (2)

# 构造输入张量,形状为 (3, 4)
x = torch.randn(3, 4)
print("输入张量的形状:", x.shape)  # (3, 4)

# 进行线性变换
output = linear_layer(x)
print("输出张量的形状:", output.shape)  # (3, 2)

# 打印输入和输出
print("输入张量:\n", x)
print("输出张量:\n", output)

执行上述代码,输出结果如下:

权重矩阵的形状: torch.Size([2, 4])
偏置向量的形状: torch.Size([2])
输入张量的形状: torch.Size([3, 4])
输出张量的形状: torch.Size([3, 2])
输入张量:
 tensor([[-0.3451,  1.2234, -0.4567,  0.9876],
         [ 0.1234, -0.5432,  1.4567, -1.1234],
         [ 0.8765,  0.4567, -0.8765,  1.2345]])
输出张量:
 tensor([[ 0.2334, -0.5432],
         [ 0.9876, -1.1234],
         [ 1.2234,  0.8765]])

从上面的输出结果可以看出,输入 (3, 4) 被映射为输出 (3, 2),符合预期的维度变换。


6. torch.nn.Linear 的进阶使用

除了基本的线性变换,torch.nn.Linear 还可以结合其他 PyTorch 模块进行更加复杂的应用。以下是一个结合 ReLU 激活函数的例子:

import torch
import torch.nn as nn

# 创建线性层和 ReLU 激活函数
linear_layer = nn.Linear(4, 2)
activation = nn.ReLU()

# 输入张量
x = torch.randn(3, 4)

# 线性变换 + ReLU 激活
output = activation(linear_layer(x))

print("经过 ReLU 激活后的输出张量:\n", output)

此代码实现了线性变换后的激活操作,ReLU 函数将所有负值截断为零,保留正值。


7. 常见问题与调试技巧

在使用 torch.nn.Linear 时,有一些常见问题和调试技巧可以帮助开发者避免陷入错误:

  1. 输入与权重的维度不匹配:确保输入张量的特征维度与 in_features 匹配,否则会导致维度不一致的错误。
  2. 学习率调节:线性层的权重和偏置是需要通过反向传播来更新的,在训练过程中可以调节学习率,以提高模型的收敛速度。
  3. 多层线性层的堆叠:在神经网络中,通常会堆叠多个线性层,通过激活函数和非线性操作来提高模型的表达能力。

8. 总结

本文详细解析了 PyTorch 中 torch.nn.Linear 模块的维度变换过程,通过公式、代码和图示帮助读者理解其内部机制。在实际的深度学习应用中,线性层是最基本也是最重要的组成部分之一。希望通过本文的讲解,你能够更深入地掌握 torch.nn.Linear 的使用方法,并能在项目中灵活运用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2171466.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

更改远程访问端口

1、背景 在客户现场,由于安全限制,在内网的交换机中配置的某些限制,不允许使用22端口作为远程访问服务器的端口,此时就需要更改远程访问的端口。 2、前提 在修改默认的远程访问端口22时,可以需要在Linux服务器中支持…

三.python入门语法1

目录 1. 算数运算和关系运算 1.1. 算术运算符 1.2. 关系运算符 习题 2.赋值运算和逻辑运算 2.1. 赋值运算符 2.2. 逻辑运算符 3.位运算符 1)位与运算(A&B) 2)位或运算(A|B) 3)异或位…

uni-app运行到 Android 真机和Android studio模拟器

文章目录 1、运行到Android 真机2、运行到Android studio模拟器2.1、运行到Android studio模拟器Android studio的安装步骤2.2、安装android SDK2.3、新增虚拟设备2.4、项目运行 3、安装报错3.1、安卓真机调试提示检测不到手机【解决办法】3.2、Android Studio中缺少System Ima…

OpenCV与AI深度学习 | 实战 | 使用OpenCV和Streamlit搭建虚拟化妆应用程序(附源码)

本文来源公众号“OpenCV与AI深度学习”,仅用于学术分享,侵权删,干货满满。 原文链接:实战 | 使用OpenCV和Streamlit搭建虚拟化妆应用程序(附源码) 现看看demo演示。 本文将介绍如何使用Streamlit和OpenCV…

Excel锁定单元格,使其不可再编辑

‌在Excel中,锁定单元格后仍然可以编辑‌,这主要涉及到对特定单元格或区域的锁定与保护工作表的设置。以下是实现这一功能的具体步骤: ‌解除工作表的锁定状态‌:首先,需要全选表格(使用CtrlA快捷键&#x…

C语言进程

什么是进程 什么是程序 一组可以被计算机直接识别的 有序 指令 的集合。 通俗讲:C语言编译后生成的可执行文件就是一个程序。 那么程序是静态还是动态的? 程序是可以被存储在磁盘上的,所以程序是静态的。 那什么是进程 进程是程序的执行过…

VS code 使用 Jupyter Notebook 时显示 line number

VS code 使用 Jupyter Notebook 时显示 line number 引言正文引言 有些时候,我们在 VS code 中必须要使用 Jupyter Notebook,但是默认情况下,Jupyter Notebook 是不显示 Line number 的,这对于调试工作的定位是不友好的,这里我们将介绍如何让 Jupyter Notebook 显示 Line…

认识联合体和枚举

目录 一.联合体 1.联合体的声明 2.联合体的特点 (一)内存共享 (二)大小等于最大成员的大小 另一特殊情况: (三)一次只能使用一个成员 3.联合体相比较于结构体 (一)内存分配 …

c++反汇编逆向还原指令add sub imul idiv cdq

add 加法指令 比如add a,b 逆向还原为aab; sub 减法 比如sub a,b 逆向还原为aa-b; imul 乘法 比如sub a,b 逆向还原为aa*b; idiv 除法 比如sub a,b 逆向还原为aa/b; cdq 在x86 汇编中,用于扩展 eax 寄存器的符号位…

基于python深度学习遥感影像地物分类与目标识别、分割实践技术

我国高分辨率对地观测系统重大专项已全面启动,高空间、高光谱、高时间分辨率和宽地面覆盖于一体的全球天空地一体化立体对地观测网逐步形成,将成为保障国家安全的基础性和战略性资源。未来10年全球每天获取的观测数据将超过10PB,遥感大数据时…

优思学院:如何借助“六西格玛设计”流程确保产品创新成功?

六西格玛设计(DFSS, Design for Six Sigma)是一种专注于产品设计初期减少变异、确保高质量的方法。虽然六西格玛的核心目标是通过减少流程和产品变异来提升质量,但它对创新过程有着重要的支持作用。创新过程中,六西格玛设计能确保…

开源b2b2c商城系统流程 多用户商城系统流程图

在选择多用户商城系统时,服务质量至关重要。商淘云多用户商城系统凭借其卓越的功能和强大的客户支持,成为了许多企业的首选。下面我们一起分析多用户商城的特性及b2b2c商城系统思维导图,文中的图大家需要的可评论“666”领取。 首先&#xff…

【含文档】基于Springboot+Vue的学生宿舍管理系统(含源码+数据库+lw)

1.开发环境 开发系统:Windows10/11 架构模式:MVC/前后端分离 JDK版本: Java JDK1.8 开发工具:IDEA 数据库版本: mysql5.7或8.0 数据库可视化工具: navicat 服务器: SpringBoot自带 apache tomcat 主要技术: Java,Springboot,mybatis,mysql,vue 2.视频演示地址 3.功能 系统定…

tomcat 文件上传 (CVE-2017-12615)

漏洞描述: 当 Tomcat 运行在 Windows 主机上,且启用了 HTTP PUT 请求方法 影响范围: Apache Tomcat 7.0.0 - 7.0.79 漏洞复现: 创建vulfocus靶场容器 poc #CVE-2017-12615 POC import requests import optparse import ospar…

mysql索引 -- 全文索引介绍(如何创建,使用),explain关键字

目录 全文索引 引入 介绍 创建 使用 表数据 简单搜索 explain关键字 使用全文索引 mysql索引结构详细介绍 -- mysql索引 -- 索引的硬件理解(磁盘,磁盘与系统),软件理解(mysql,与系统io,buffer pool),索引结构介绍和理解(page内部,page之间,为什么是b树)-CSDN博客 全文…

UE5: Content browser工具编写02

DebugHeader.h 中的全局变量,已经在一个cpp file中被include了,如果在另一个cpp file中再include它,就会有一些conflicts。先全部给加一个static Add static keyword to debug functionsWrap all the functions inside of a namespaceprint …

Linux入门攻坚——34、nsswitch、pam、rsyslog和loganalyzer前端展示工具

nsswitch&#xff1a;network service switch 名称解析&#xff1a;name <---> id 认证服务&#xff1a;用户名、密码验证或token验证等 名称解析和认证服务都涉及查找位置&#xff0c;即保存在哪里。如linux认证&#xff0c;passwd、shadow&#xff0c;是在文件中&…

Linux标准IO(五)-I/O缓冲详解

1.简介 出于速度和效率的考虑&#xff0c;系统 I/O 调用&#xff08;即文件 I/O&#xff0c;open、read、write 等&#xff09;和标准 C 语言库 I/O 函数&#xff08;即标准 I/O 函数&#xff09;在操作磁盘文件时会对数据进行缓冲&#xff0c;本小节将讨论文件 I/O 和标准 I/…

20 vue3之自定义hooks

Vue3 自定义Hook的作用 主要用来处理复用代码逻辑的一些封装 Vue3 的 hook函数 相当于 vue2 的 mixin, 不同在与 hooks 是函数Vue3 的 hook函数 可以帮助我们提高代码的复用性, 让我们能在不同的组件中都利用 hooks 函数 这个在vue2 就已经有一个东西是Mixins mixins就是将…

8,STM32CubeMX配置SPI工程(读取norflash的ID)

1&#xff0c;前言 单片机型号&#xff1a;STM32F407 编程环境 &#xff1a;STM32CubeMX Keil v5 硬件连接 &#xff1a;SPI1&#xff0c;CS/SS--->PB14 注&#xff1a;本工程在1&#xff0c;STM32CubeMX工程基础&#xff08;配置Debug、时钟树&#xff09;基础上完…