深入解析全连接层:PyTorch 中的 nn.Linear、nn.Parameter 及矩阵运算

news2025/1/23 3:13:07

文章目录

这篇文章会从基础的一个数学概念到对应的代码实现,你将了解到:

  • 为什么nn.Parameter()接受 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)作为参数?
  • 为什么不是torch.matmul(self.weight, x) + self.bias
  • 如何使用torch.matmul()@F.linear() 去等价地实现nn.Linear()的输出。

数学概念(全连接层,线性层)

线性变化是数学中一个基础的概念,它描述了如何通过线性变换将输入映射到输出。在线性代数中,线性变化通常表示为矩阵乘法。在神经网络中,线性层的核心就是实现这样的矩阵运算。

数学公式:

给定一个输入向量 x ∈ R n \mathbf{x} \in \mathbb{R}^n xRn 和一个输出向量 y ∈ R m \mathbf{y} \in \mathbb{R}^m yRm,线性变化通过矩阵 W ∈ R m × n \mathbf{W} \in \mathbb{R}^{m \times n} WRm×n 和偏置项 b ∈ R m \mathbf{b} \in \mathbb{R}^m bRm 进行变换,其公式为:
y = W x + b \mathbf{y} = \mathbf{W} \mathbf{x} + \mathbf{b} y=Wx+b

  • W \mathbf{W} W:是权重矩阵,维度为 m × n m \times n m×n,它决定了输入向量如何线性变换到输出空间;
  • x \mathbf{x} x:是输入向量,维度为 n n n,表示特征数据;
  • b \mathbf{b} b:是偏置向量,维度为 m m m,用来调整线性变换的输出;
  • y \mathbf{y} y:是输出向量,维度为 m m m,是变换后的结果。

例子:

如果输入向量 x \mathbf{x} x 有 3 个特征,输出向量 y \mathbf{y} y 有 2 个特征,则权重矩阵 W \mathbf{W} W 的形状为 2 × 3 2 \times 3 2×3。假设:
W = [ 1 2 3 4 5 6 ] , x = [ 1 2 3 ] , b = [ 0 1 ] \mathbf{W} = \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix}, \quad \mathbf{x} = \begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix}, \quad \mathbf{b} = \begin{bmatrix} 0 \\ 1 \end{bmatrix} W=[142536],x= 123 ,b=[01]
线性变换计算为:
y = W x + b = [ 1 2 3 4 5 6 ] [ 1 2 3 ] + [ 0 1 ] = [ 14 32 ] + [ 0 1 ] = [ 14 33 ] \mathbf{y} = \mathbf{W} \mathbf{x} + \mathbf{b} = \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} \begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix} + \begin{bmatrix} 0 \\ 1 \end{bmatrix} = \begin{bmatrix} 14 \\ 32 \end{bmatrix} + \begin{bmatrix} 0 \\ 1 \end{bmatrix} = \begin{bmatrix} 14 \\ 33 \end{bmatrix} y=Wx+b=[142536] 123 +[01]=[1432]+[01]=[1433]
矩阵运算过程:
[ 1 2 3 4 5 6 ] [ 1 2 3 ] = [ ( 1 × 1 ) + ( 2 × 2 ) + ( 3 × 3 ) ( 4 × 1 ) + ( 5 × 2 ) + ( 6 × 3 ) ] = [ 14 32 ] \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} \begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix} = \begin{bmatrix} (1 \times 1) + (2 \times 2) + (3 \times 3) \\ (4 \times 1) + (5 \times 2) + (6 \times 3) \end{bmatrix} = \begin{bmatrix} 14 \\ 32 \end{bmatrix} [142536] 123 =[(1×1)+(2×2)+(3×3)(4×1)+(5×2)+(6×3)]=[1432]

nn.Linear()

nn.Linear() 会自动创建一个权重矩阵(Weight)和偏置项(Bias),并将它们应用到输入上。

代码示例:

import torch
import torch.nn as nn

# 定义一个输入为3,输出为2的线性层
linear_layer = nn.Linear(3, 2)

# 打印权重矩阵和偏置项
print("权重矩阵 W:")
print(linear_layer.weight)

print("偏置项 b:")
print(linear_layer.bias)

# 模拟输入向量
input_vector = torch.tensor([1.0, 2.0, 3.0])
output_vector = linear_layer(input_vector)
print("输出向量 y:")
print(output_vector)

image-20240912221728559

在这里,nn.Linear(3, 2) 创建了一个 2×3 的权重矩阵和一个 2 维的偏置向量。通过 linear_layer(input_vector),可以直接获得输入向量经过线性变换后的输出。

nn.Parameter()

在 PyTorch 中,nn.Linear() 自动处理了权重和偏置项的初始化和更新,但有时你可能希望对这些参数自定义一些操作,比如 LoRA。这时,我们可以使用 nn.Parameter() 来自定义权重和偏置,其实 nn.Linear() 本身就是使用的nn.Parameter(),感兴趣的话可以看官方源码。

以自定义线性层为例:

class CustomLinearLayer(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(CustomLinearLayer, self).__init__()
        # 使用 nn.Parameter 手动定义权重和偏置
        self.weight = nn.Parameter(torch.randn(output_dim, input_dim))
        self.bias = nn.Parameter(torch.randn(output_dim))

    def forward(self, x):
        # 手动实现线性变换 y = Wx + b
        return torch.matmul(x, self.weight.T) + self.bias

# 使用自定义的线性层
custom_layer = CustomLinearLayer(3, 2)
output = custom_layer(input_vector)
print(output)

image-20240912222625609

在看完代码后,你可能会产生两个疑惑:

Q

1. 为什么 self.weight 的权重矩阵 shape 使用 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)而不是 ( in_features , out_features ) (\text{in\_features}, \text{out\_features}) (in_features,out_features)?

实际这正是我写这篇博客分享的原因,现在进入正题,因为这个疑惑完全不应该产生。

让我们重新使用 in_features \text{in\_features} in_features out_features \text{out\_features} out_features来重现之前的数学定义:

对于输入向量 x ∈ R in_features \mathbf{x} \in \mathbb{R}^{\text{in\_features}} xRin_features,全连接层的输出为:

y = W x + b \mathbf{y} = W\mathbf{x} + \mathbf{b} y=Wx+b

其中:

  • W ∈ R out_features × in_features W \in \mathbb{R}^{\text{out\_features} \times \text{in\_features}} WRout_features×in_features 是权重矩阵,
  • b ∈ R out_features \mathbf{b} \in \mathbb{R}^{\text{out\_features}} bRout_features 是偏置项。

在线性变换中,输入向量 x \mathbf{x} x 的维度是 in_features \text{in\_features} in_features,而输出向量 y \mathbf{y} y 的维度是 out_features \text{out\_features} out_features。根据矩阵乘法的规则,要将输入 x \mathbf{x} x 映射到输出 y \mathbf{y} y,权重矩阵 W W W 的形状应该是 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features),因为矩阵乘法中 W x W\mathbf{x} Wx的维度要求是:

( out_features × in_features ) × ( in_features × 1 ) = ( out_features × 1 ) (\text{out\_features} \times \text{in\_features}) \times (\text{in\_features} \times 1) = (\text{out\_features} \times 1) (out_features×in_features)×(in_features×1)=(out_features×1)

这保证了输出 y \mathbf{y} y 的维度是 out_features \text{out\_features} out_features

如果权重矩阵的形状是 ( in_features , out_features ) (\text{in\_features}, \text{out\_features}) (in_features,out_features),矩阵乘法的维度将不匹配,无法实现线性变换。

现在是不是感觉清晰了?不要 nn.Linear(in_feature, out_feature) 用多了就将权重矩阵当作是 ( in_features , out_features ) (\text{in\_features}, \text{out\_features}) (in_features,out_features)遗忘了线性代数的概念,数学才是这一切的基石。

2. 为什么是torch.matmul(x, self.weight.T) + self.bias 而不是torch.matmul(self.weight, x) + self.bias?

主要原因还是在于 输入张量 x 的形状矩阵乘法规则

一般来说,模型的输入 x 实际上并不是 ( in_features , 1 ) (\text{in\_features}, 1) (in_features,1),而是 ( batch_size , in_features ) (\text{batch\_size}, \text{in\_features}) (batch_size,in_features),而权重矩阵 self.weight 的形状是 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)​,我们需要实现的线性变换是:
y = W x + b y = W x + b y=Wx+b
根据矩阵乘法规则,第一个矩阵的列数必须等于第二个矩阵的行数,这意味着我们不能直接计算 torch.matmul(self.weight, x),因为这样会导致维度不匹配:

  • self.weight 形状为 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)x 形状为 ( batch_size , in_features ) (\text{batch\_size}, \text{in\_features}) (batch_size,in_features)
  • torch.matmul(self.weight, x) 的维度计算规则将要求 x 的形状为 ( in_features , batch_size ) (\text{in\_features}, \text{batch\_size}) (in_features,batch_size),但这与模型的输入不匹配。

因此,正确的矩阵乘法应该是 torch.matmul(x, self.weight.T),其中 self.weight.T 表示 self.weight 的转置矩阵,此时的形状为 ( in_features , out_features ) (\text{in\_features}, \text{out\_features}) (in_features,out_features)

这样,torch.matmul(x, self.weight.T) 的维度计算为:

( batch_size , in_features ) × ( in_features , out_features ) = ( batch_size , out_features ) (\text{batch\_size}, \text{in\_features}) \times (\text{in\_features}, \text{out\_features}) = (\text{batch\_size}, \text{out\_features}) (batch_size,in_features)×(in_features,out_features)=(batch_size,out_features)

这就得到了正确的输出形状 ( batch_size , in_features ) (\text{batch\_size}, \text{in\_features}) (batch_size,in_features)

3. 为什么不直接设置self.weight = nn.Parameter(torch.randn(input_dim, output_dim))

这样不就可以不转置直接使用torch.matmul(x, self.weight)了吗?的确如此,或许是因为 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features) 对于矩阵运算 W x W\mathbf{x} Wx 来讲更符合直觉吧。

计算过程的细分:torch.matmul() vs @ 运算符

在 PyTorch 中,torch.matmul() 用于实现矩阵乘法,而 @ 是其简洁的符号形式,是 Python 的语法糖,二者在功能上是等价的。

示例代码:

import torch

# 定义权重矩阵 W 和输入向量 input_vector
W = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
input_vector = torch.tensor([1.0, 2.0, 3.0])

# 使用 torch.matmul 实现矩阵乘法
result1 = torch.matmul(W, input_vector)

# 使用 @ 运算符
result2 = W @ input_vector

print("使用 torch.matmul 计算的结果:")
print(result1)

print("使用 @ 运算符计算的结果:")
print(result2)

结果:

image-20240912233355773

使用 F.linear()

PyTorch 提供了 F.linear() 作为函数式接口,它与 nn.Linear() 类似,但不需要创建一个线性层对象。F.linear() 可以接受线性层的权重和偏置作为输入,在下一篇有关 LoRA 的文章中,你将看到使用范例。

示例代码:

import torch.nn.functional as F

# 使用 F.linear 进行线性变换
output = F.linear(input_vector, linear_layer.weight, linear_layer.bias)
print(output)

image-20240912233501651

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2130266.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

复现OpenVLA:开源的视觉-语言-动作模型及原理详解

复现OpenVLA:开源的视觉-语言-动作模型及原理详解 1. 摘要2. 引言3. 相关工作4. 模型结构4.1 模视觉-语言模型VLM4.2 训练流程4.3 图像分辨率4.4 微调视觉编码器4.5 训练轮数4.6 学习率4.7 训练细节4.8 参数高效微调 5. 复现5.1 下拉代码5.2 安装环境依赖5.2.1 创建…

什么是科技与艺术相结合的异形创意圆形(饼/盘)LED显示屏

在当今数字化与创意并重的时代,科技与艺术的融合已成为推动社会进步与文化创新的重要力量。其中,晶锐创显异形创意圆形LED显示屏作为这一趋势下的杰出代表,不仅打破了传统显示设备的形态束缚,更以其独特的造型、卓越的显示效果和广…

浏览器百科:网页存储篇-IndexedDB介绍(十)

1.引言 在现代网页开发中,数据存储需求日益增多和复杂,传统的客户端存储技术如localStorage和sessionStorage已难以满足大型数据的存储和管理需求。为了解决这一问题,HTML5 引入了 IndexedDB,在本篇《浏览器百科:网页…

SiC,GaN驱动优选驱动方案SiLM5350系列SiLM5350MDDCM-DG 带米勒钳位Clamp保护功能 单通道隔离栅极驱动器

SiLM5350MDDCM-DG是一款适用于IGBT、MOSFET的单通道 隔离门极驱动器,具有10A拉电流和10A灌电流驱动能 力。提供内部钳位功能,可单独控制 上升时间和下降时间。 在 SOP8 封 装 中 具 有 3000VRMS 隔 离 耐 压 ( 符 合 UL1577)。 与…

Vue-Route4 ts

小满学习视频 Vue-Route 官网 项目的目录结构: 1. Vue-Router的使用 安装Vue-route pnpm add vue-router4创建router文件 /route/index.vue import { createRouter } from "vue-router"; import {createMemoryHistory,createWebHashHistory,create…

C语言 | Leetcode C语言题解之第395题至少有K个重复字符的最长子串

题目&#xff1a; 题解&#xff1a; int longestSubstring(char* s, int k) {int ret 0;int n strlen(s);for (int t 1; t < 26; t) {int l 0, r 0;int cnt[26];memset(cnt, 0, sizeof(cnt));int tot 0;int less 0;while (r < n) {cnt[s[r] - a];if (cnt[s[r] - …

Framework | 在Android中运行时获取顶层Activity并处理业务逻辑

Framework | 在Android中运行时获取顶层Activity并处理业务逻辑 在Android应用的开发中,有时需要获取当前正在运行的顶层Activity,尤其是当应用需要监控特定的页面或执行特殊的业务处理时,例如在截图界面进行操作或在特定的活动页面展示特定的功能。本文将详细介绍如何通过…

中电金信:金融级数字底座“源启”:打造新型数字基础设施 筑牢千行百业数字化转型发展基石

近期&#xff0c;金融级数字底座“源启”登录中国电子《最轻大国重器》融媒体报道。从数字底座到数智底座&#xff0c;从金融行业到千行百业&#xff0c;“源启”用数智化转型的中国电子解决方案&#xff0c;为全球企业转型及安全发展提供强大动能。 立足中国电子科技创新成果&…

英飞凌motor电机方案

电机主流方案的应用场景,一个是家电,一个是汽车,尤其是新能源汽车。 多家MCU厂商也都推出自己电机解决方案,比如华大电子: HC32M140 系列产品为华大半导体研制的 32bit 基于 ARM-Cortex M0+ 的 MCU,与传统的 CPU 内核相比,效率更高,功耗更低。更宽的工作电压范围,可同…

C语言的结构体类型

在我们使用C语言进行编写代码时&#xff0c;常常会使用已经给定的类型来创建变量&#xff0c;比如int型&#xff0c;char型&#xff0c;double型等&#xff0c;而当我们想创建一些较为复杂的东西时&#xff0c;单单用一个类型变量是没办法做到的&#xff0c;比如我们想创建一个…

图文讲解HarmonyOS应用发布流程

HarmonyOS应用的开发和发布过程可以分为以下几个步骤&#xff1a;证书生成、应用开发、应用签名和发布。 1. 证书生成&#xff1a; 在开始开发HarmonyOS应用之前&#xff0c;首先需要生成一个开发者证书。开发者证书用于标识应用的开发者身份并确保应用的安全性。可以通过Har…

R语言统计分析——功效分析3(相关、线性模型)

参考资料&#xff1a;R语言实战【第2版】 1、相关性 pwr.r.test()函数可以对相关性分析进行功效分析。格式如下&#xff1a; pwr.r.test(n, r, sig.level, power, alternative) 其中&#xff0c;n是观测数目&#xff0c;r是效应值&#xff08;通过线性相关系数衡量&#xff0…

LoRA: Low-Rank Adaptation Abstract

LoRA: Low-Rank Adaptation Abstract LoRA 论文的摘要介绍了一种用于减少大规模预训练模型微调过程中可训练参数数量和内存需求的方法&#xff0c;例如拥有1750亿参数的GPT-3。LoRA 通过冻结模型权重并引入可训练的低秩分解矩阵&#xff0c;减少了10,000倍的可训练参数&#xf…

HashMap常用方法及底层原理

目录 一、什么是HashMap二、HashMap的链表与红黑树1、数据结构2、链表转为红黑树3、红黑树退化为链表 三、存储&#xff08;put&#xff09;操作四、读取&#xff08;get&#xff09;操作五、扩容&#xff08;resize&#xff09;操作六、HashMap的线程安全与顺序1、线程安全2、…

云计算实训49——k8s环镜搭建(续2)

一、Metrics 部署 在新版的 Kubernetes 中系统资源的采集均使⽤ Metrics-server&#xff0c;可 以通过 Metrics 采集节点和 Pod 的内存、磁盘、CPU和⽹络的使⽤ 率。 &#xff08;1&#xff09;复制证书到所有 node 节点 将 master 节点的 front-proxy-ca.crt 复制到所有 No…

Linux进阶命令-top

作者介绍&#xff1a;简历上没有一个精通的运维工程师。希望大家多多关注作者&#xff0c;下面的思维导图也是预计更新的内容和当前进度(不定时更新)。 经过上一章Linux日志的讲解&#xff0c;我们对Linux系统自带的日志服务已经有了一些了解。我们接下来将讲解一些进阶命令&am…

【计算机网络】初识网络

初识网络 初识网络网络的发展局域网广域网 网络基础IP地址端口号协议五元组协议分层OSI 七层模型TCP/IP五层模型封装和分用"客户段-服务器"结构 初识网络 网络的发展 在过去网络还没有出现的时候, 我们的计算机大部分都是独自运行的, 比如以前那些老游戏, 都是只能…

Chainlit集成Langchain并使用通义千问实现文生图网页应用

前言 本文教程如何使用通义千问的大模型服务平台的接口&#xff0c;实现图片生成的网页应用&#xff0c;主要用到的技术服务有&#xff0c;chainlit 、 langchain、 flux。合利用了大模型的工具选择调用能力。实现聊天对话生成图片的网页应用。 阿里云 大模型服务平台百炼 API…

1.SpringCloud与SpringCloud Alibaba

SpringCloud与SpringCloud Alibaba主要讲解的内容&#xff1a; 备注&#xff1a;黑色部分是springcloud社区原版&#xff0c;红色的是SpringCloud Alibaba 服务注册与发现 Consul Alibaba Nacos 服务调用和负载均衡 LoadBalancer OpenFeign 分布式事务 Alibaba Seata 服务熔…

批量插入insert到SQLServer数据库,BigDecimal精度丢失解决办法,不动代码,从驱动层面解决

概述 相信很多人都遇到过&#xff0c;使用sql server数据库&#xff0c;批量插入数据时&#xff0c;BigDecimal类型出现丢失精度的问题&#xff0c;网上也有很多人给出过解决方案&#xff0c;但一般都要修改应用代码&#xff0c;不推荐。 丢失精度的本质是官方的驱动有BUG造成…