深入理解全连接层:从线性代数到 PyTorch 中的 nn.Linear 和 nn.Parameter

news2025/1/12 23:26:28

文章目录

这篇文章会从基础的一个数学概念到对应的代码实现,你将了解到:

  • 为什么nn.Parameter()接受 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)作为参数?
  • 为什么不是torch.matmul(self.weight, x) + self.bias
  • 如何使用torch.matmul()@F.linear() 去等价地实现nn.Linear()的输出。

数学概念(全连接层,线性层)

线性变化是数学中一个基础的概念,它描述了如何通过线性变换将输入映射到输出。在线性代数中,线性变化通常表示为矩阵乘法。在神经网络中,线性层的核心就是实现这样的矩阵运算。

数学公式:

给定一个输入向量 x ∈ R n \mathbf{x} \in \mathbb{R}^n xRn 和一个输出向量 y ∈ R m \mathbf{y} \in \mathbb{R}^m yRm,线性变化通过矩阵 W ∈ R m × n \mathbf{W} \in \mathbb{R}^{m \times n} WRm×n 和偏置项 b ∈ R m \mathbf{b} \in \mathbb{R}^m bRm 进行变换,其公式为:
y = W x + b \mathbf{y} = \mathbf{W} \mathbf{x} + \mathbf{b} y=Wx+b

  • W \mathbf{W} W:是权重矩阵,维度为 m × n m \times n m×n,它决定了输入向量如何线性变换到输出空间;
  • x \mathbf{x} x:是输入向量,维度为 n n n,表示特征数据;
  • b \mathbf{b} b:是偏置向量,维度为 m m m,用来调整线性变换的输出;
  • y \mathbf{y} y:是输出向量,维度为 m m m,是变换后的结果。

例子:

如果输入向量 x \mathbf{x} x 有 3 个特征,输出向量 y \mathbf{y} y 有 2 个特征,则权重矩阵 W \mathbf{W} W 的形状为 2 × 3 2 \times 3 2×3。假设:
W = [ 1 2 3 4 5 6 ] , x = [ 1 2 3 ] , b = [ 0 1 ] \mathbf{W} = \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix}, \quad \mathbf{x} = \begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix}, \quad \mathbf{b} = \begin{bmatrix} 0 \\ 1 \end{bmatrix} W=[142536],x= 123 ,b=[01]
线性变换计算为:
y = W x + b = [ 1 2 3 4 5 6 ] [ 1 2 3 ] + [ 0 1 ] = [ 14 32 ] + [ 0 1 ] = [ 14 33 ] \mathbf{y} = \mathbf{W} \mathbf{x} + \mathbf{b} = \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} \begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix} + \begin{bmatrix} 0 \\ 1 \end{bmatrix} = \begin{bmatrix} 14 \\ 32 \end{bmatrix} + \begin{bmatrix} 0 \\ 1 \end{bmatrix} = \begin{bmatrix} 14 \\ 33 \end{bmatrix} y=Wx+b=[142536] 123 +[01]=[1432]+[01]=[1433]
矩阵运算过程:
[ 1 2 3 4 5 6 ] [ 1 2 3 ] = [ ( 1 × 1 ) + ( 2 × 2 ) + ( 3 × 3 ) ( 4 × 1 ) + ( 5 × 2 ) + ( 6 × 3 ) ] = [ 14 32 ] \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} \begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix} = \begin{bmatrix} (1 \times 1) + (2 \times 2) + (3 \times 3) \\ (4 \times 1) + (5 \times 2) + (6 \times 3) \end{bmatrix} = \begin{bmatrix} 14 \\ 32 \end{bmatrix} [142536] 123 =[(1×1)+(2×2)+(3×3)(4×1)+(5×2)+(6×3)]=[1432]

nn.Linear()

nn.Linear() 会自动创建一个权重矩阵(Weight)和偏置项(Bias),并将它们应用到输入上。

代码示例:

import torch
import torch.nn as nn

# 定义一个输入为3,输出为2的线性层
linear_layer = nn.Linear(3, 2)

# 打印权重矩阵和偏置项
print("权重矩阵 W:")
print(linear_layer.weight)

print("偏置项 b:")
print(linear_layer.bias)

# 模拟输入向量
input_vector = torch.tensor([1.0, 2.0, 3.0])
output_vector = linear_layer(input_vector)
print("输出向量 y:")
print(output_vector)

image-20240912221728559

在这里,nn.Linear(3, 2) 创建了一个 2×3 的权重矩阵和一个 2 维的偏置向量。通过 linear_layer(input_vector),可以直接获得输入向量经过线性变换后的输出。

nn.Parameter()

在 PyTorch 中,nn.Linear() 自动处理了权重和偏置项的初始化和更新,但有时你可能希望对这些参数自定义一些操作,比如 LoRA。这时,我们可以使用 nn.Parameter() 来自定义权重和偏置,其实 nn.Linear() 本身就是使用的nn.Parameter(),感兴趣的话可以看官方源码。

以自定义线性层为例:

class CustomLinearLayer(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(CustomLinearLayer, self).__init__()
        # 使用 nn.Parameter 手动定义权重和偏置
        self.weight = nn.Parameter(torch.randn(output_dim, input_dim))
        self.bias = nn.Parameter(torch.randn(output_dim))

    def forward(self, x):
        # 手动实现线性变换 y = Wx + b
        return torch.matmul(x, self.weight.T) + self.bias

# 使用自定义的线性层
custom_layer = CustomLinearLayer(3, 2)
output = custom_layer(input_vector)
print(output)

image-20240912222625609

在看完代码后,你可能会产生两个疑惑:

Q

1. 为什么 self.weight 的权重矩阵 shape 使用 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)而不是 ( in_features , out_features ) (\text{in\_features}, \text{out\_features}) (in_features,out_features)?

这正是我写这篇博客的原因,接下来我们详细解释这个问题。

让我们重新使用 in_features \text{in\_features} in_features out_features \text{out\_features} out_features来重现之前的数学定义:

对于输入向量 x ∈ R in_features \mathbf{x} \in \mathbb{R}^{\text{in\_features}} xRin_features,全连接层的输出为:

y = W x + b \mathbf{y} = W\mathbf{x} + \mathbf{b} y=Wx+b

其中:

  • W ∈ R out_features × in_features W \in \mathbb{R}^{\text{out\_features} \times \text{in\_features}} WRout_features×in_features 是权重矩阵,
  • b ∈ R out_features \mathbf{b} \in \mathbb{R}^{\text{out\_features}} bRout_features 是偏置项。

在线性变换中,输入向量 x \mathbf{x} x 的维度是 in_features \text{in\_features} in_features,而输出向量 y \mathbf{y} y 的维度是 out_features \text{out\_features} out_features。根据矩阵乘法的规则,要将输入 x \mathbf{x} x 映射到输出 y \mathbf{y} y,权重矩阵 W W W 的形状应该是 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features),因为矩阵乘法中 W x W\mathbf{x} Wx的维度要求是:

( out_features × in_features ) × ( in_features × 1 ) = ( out_features × 1 ) (\text{out\_features} \times \text{in\_features}) \times (\text{in\_features} \times 1) = (\text{out\_features} \times 1) (out_features×in_features)×(in_features×1)=(out_features×1)

这保证了输出 y \mathbf{y} y 的维度是 out_features \text{out\_features} out_features

如果权重矩阵的形状是 ( in_features , out_features ) (\text{in\_features}, \text{out\_features}) (in_features,out_features),矩阵乘法的维度将不匹配,无法实现线性变换。

现在是不是感觉清晰了?不要 nn.Linear(in_feature, out_feature) 用多了就将权重矩阵当作是 ( in_features , out_features ) (\text{in\_features}, \text{out\_features}) (in_features,out_features)遗忘了线性代数的概念,数学才是这一切的基石。

2. 为什么是torch.matmul(x, self.weight.T) + self.bias 而不是torch.matmul(self.weight, x) + self.bias?

主要原因还是在于 输入张量 x 的形状矩阵乘法规则

一般来说,模型的输入 x 实际上并不是 ( in_features , 1 ) (\text{in\_features}, 1) (in_features,1),而是 ( batch_size , in_features ) (\text{batch\_size}, \text{in\_features}) (batch_size,in_features),而权重矩阵 self.weight 的形状是 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)​,我们需要实现的线性变换是:
y = W x + b y = W x + b y=Wx+b
根据矩阵乘法规则,第一个矩阵的列数必须等于第二个矩阵的行数,这意味着我们不能直接计算 torch.matmul(self.weight, x),因为这样会导致维度不匹配:

  • self.weight 形状为 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)x 形状为 ( batch_size , in_features ) (\text{batch\_size}, \text{in\_features}) (batch_size,in_features)
  • torch.matmul(self.weight, x) 的维度计算规则将要求 x 的形状为 ( in_features , batch_size ) (\text{in\_features}, \text{batch\_size}) (in_features,batch_size),但这与模型的输入不匹配。

因此,正确的矩阵乘法应该是 torch.matmul(x, self.weight.T),其中 self.weight.T 表示 self.weight 的转置矩阵,此时的形状为 ( in_features , out_features ) (\text{in\_features}, \text{out\_features}) (in_features,out_features)

这样,torch.matmul(x, self.weight.T) 的维度计算为:

( batch_size , in_features ) × ( in_features , out_features ) = ( batch_size , out_features ) (\text{batch\_size}, \text{in\_features}) \times (\text{in\_features}, \text{out\_features}) = (\text{batch\_size}, \text{out\_features}) (batch_size,in_features)×(in_features,out_features)=(batch_size,out_features)

这就得到了正确的输出形状 ( batch_size , out_features ) (\text{batch\_size}, \text{out\_features}) (batch_size,out_features)

3. 为什么不直接设置self.weight = nn.Parameter(torch.randn(input_dim, output_dim))

这样不就可以不转置直接使用torch.matmul(x, self.weight)了吗?的确如此,或许是因为 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features) 对于矩阵运算 W x W\mathbf{x} Wx 来讲更符合直觉吧。

计算过程的细分:torch.matmul() vs @ 运算符

在 PyTorch 中,torch.matmul() 用于实现矩阵乘法,而 @ 是其简洁的符号形式,是 Python 的语法糖,二者在功能上是等价的。

示例代码:

import torch

# 定义权重矩阵 W 和输入向量 input_vector
W = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
input_vector = torch.tensor([1.0, 2.0, 3.0])

# 使用 torch.matmul 实现矩阵乘法
result1 = torch.matmul(W, input_vector)

# 使用 @ 运算符
result2 = W @ input_vector

print("使用 torch.matmul 计算的结果:")
print(result1)

print("使用 @ 运算符计算的结果:")
print(result2)

结果:

image-20240912233355773

使用 F.linear()

PyTorch 提供了 F.linear() 作为函数式接口,它与 nn.Linear() 类似,但不需要创建一个线性层对象。F.linear() 可以接受线性层的权重和偏置作为输入。

示例代码:

import torch.nn.functional as F

# 使用 F.linear 进行线性变换
output = F.linear(input_vector, linear_layer.weight, linear_layer.bias)
print(output)

image-20240912233501651

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2131037.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【人工智能】OpenAI最新发布的GPT-o1模型,和GPT-4o到底哪个更强?最新分析结果就在这里!

在人工智能的快速发展中,OpenAI的每一次新模型发布都引发了广泛的关注与讨论。2023年9月13日,OpenAI正式推出了名为o1的新模型,这一模型不仅是其系列“推理”模型中的首个代表,更是朝着类人人工智能迈进的重要一步。本文将综合分析…

10款超好用的电脑文件加密软件推荐|2024文件加密软件排行榜

在数字时代,数据安全已成为个人和企业不可忽视的重要议题。加密软件作为守护数据安全的坚固防线,其重要性不言而喻。以下是2024年备受推荐的十款电脑文件加密软件。 1.安秉网盾 安秉网盾以其全面的数据保护和安全防护功能备受企业青睐。它支持多种加密…

C语言内存函数(21)

文章目录 前言一、memcpy的使用和模拟实现二、memmove的使用和模拟实现三、memset函数的使用四、memcmp函数的使用总结 前言 正文开始,发车! 一、memcpy的使用和模拟实现 函数模型:void* memcpy(void* destination, const void* source, size…

深入Redis:分布式锁

在一个分布式的系统中,会涉及到多个节点访问同一个公共资源的情况。此时就需要通过锁来做互斥控制,避免出现类似于“线程安全”的问题。 Java中的synchronize只能在当前线程中生效,在分布式的这种多个进程多个主机的场景下就无能为力了。此时…

原型模式详细介绍和代码实现

🎯 设计模式专栏,持续更新中, 欢迎订阅:JAVA实现设计模式 🛠️ 希望小伙伴们一键三连,有问题私信都会回复,或者在评论区直接发言 Java实现原型模式 介绍: 原型模式(Prototype Patte…

C4D2025来了!亮眼的新功能一览

C4D2025新功能亮点,同步上新的Redshift 2025.0.2。等我体验了再给大家讲详细的 成都渲染101云渲染支持对应软件渲染,3090等显卡,云渲码6666 渲染101云渲码6666 Mograph增强 引入线性域标签,用于精细控制对象参数。 为追踪器对象新…

安泰功率放大器有哪些特点呢

功率放大器是电子设备中的重要组成部分,其作用是将输入信号的电功率放大到足够的水平,以驱动负载,如扬声器或天线。功率放大器有一些独特的特点,这些特点对于各种应用至关重要。下面将详细介绍功率放大器的特点,以更好…

【Vue】移动端访问Vue项目页面无数据,但是PC访问有数据

问题: Vue项目,PC访问时下拉列表有数据,移动端访问时下拉列表没有数据; 思路: 首先打开了Fiddler抓包工具,把抓到的url复制到PC浏览器进行访问,结果发现PC上访问这个页面是有数据的&#xff…

利用Leaflet.js创建交互式地图:绘制固定尺寸的长方形

在现代Web开发中,交互式地图已成为展示地理位置数据的重要工具。Leaflet.js是一个轻量级、功能丰富的开源JavaScript库,用于构建移动友好的交互式地图。在本文中,我们将探讨如何利用Leaflet.js在地图上绘制一个固定尺寸的长方形,扩…

堆+堆排序+topK问题

目录 堆: 1、堆的概念 2、堆的结构 3、堆的实现 3.1、建堆 3.1.1、向上调整建堆(用于堆的插入) 3.1.2、向下调整建堆 3.2、堆的删除 3.3、堆的代码实现 3.3.1、Heap.h 3.3.2、Heap.c 堆排序:(O(N*log(N))) 1、排序如何…

接口测试用例的编写

🍅 点击文末小卡片,免费获取软件测试全套资料,资料在手,涨薪更快 1、接口测试发现的典型问题 接口测试经常遇到的bug和问题,如下: 传入参数处理不当,导致程序crash类型溢出,导…

下载chromedriver驱动

首先进入关于ChromeDriver最新下载地址:Chrome for Testing availability 进入之后找到与自己所匹配的,在浏览器中查看版本号,下载版本号需要一致。 下载即可,解压,找到 直接放在pycharm下即可 因为在环境变量中早已配…

严重干扰的验证码识别系统源码分享

严重干扰的验证码识别检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Comp…

领夹麦克风哪个品牌好?大疆、西圣、博雅无线麦克风在线测评

​随着科技的不断进步,越来越多的专业音频设备出现在大家的视野中,无线领夹麦克风就是其中之一,并且很多人在视频创作、直播等场景中都会进行选购。但是近些年来无线领夹麦克风市场较为复杂,很多质量不佳的产品混杂其中&#xff0…

安全生产许可证的重要性

在现代社会,安全生产许可证对于企业来说,不仅仅是一种法律要求,更是一种社会责任和企业形象的体现。本文将深入探讨安全生产许可证的重要性,以及它如何影响企业的长期发展和社会责任。 一、法律合规性的重要性 安全生产许可证是企…

Windows上指定盘符-安装WSL虚拟机(机械硬盘)

参考来自于教程1:史上最全的WSL安装教程 - 知乎 (zhihu.com)https://zhuanlan.zhihu.com/p/386590591#%E4%B8%80%E3%80%81%E5%AE%89%E8%A3%85WSL2.0 教程2:Windows 10: 将 WSL Linux 实例安装到 D 盘,做成移动硬盘绿色版也不在话下 - 知乎 (z…

那些网站需要使用OV SSL证书?怎么申请?

🔒 需要OV SSL证书的网站 🏢 企业网站 公司官网:展示公司信息、产品服务,增强信任度。 电子商务网站:处理在线交易,保护用户数据安全。 🏦 金融网站 银行网站:处理金融交易&#x…

SMB流量分析

SMB协议流量主要分为以下几个阶段 1、 tcp三次握手 2、会话建立(Negotiate Protocol): 客户端发送Negotiate Protocol Request ,其中包含客户端支持的smb协议版本列表以及SMB Header信息 服务端发送Negotiate Protocol Response,包含服务…

Springboot中自定义监听器

一、监听器模式图 二、监听器三要素 广播器:用来发布事件 事件:需要被传播的消息 监听器:一个对象对一个事件的发生做出反应,这个对象就是事件监听器 三、监听器的实现方式 1、实现自定义事件 自定义事件需要继承ApplicationEv…

Git学习尚硅谷(005 idea集成git)

尚硅谷Git入门到精通全套教程(涵盖GitHub\Gitee码云\GitLab) 总时长 4:52:00 共45P 此文章包含第27p-第p32的内容 文章目录 忽略特定文件在家目录里创建这个文件在.gitconfig文件里配置这个文件 配置IDEA定位到git程序进行添加文件初始化本地库添加单个…