逐行讲解Transformer的代码实现和原理讲解:nn.Linear线性层原理

news2025/1/11 6:12:27

视频详细讲解:LLM模型:代码讲解Transformer运行原理(1)_哔哩哔哩_bilibili

1 概述

经过Transformer的12个块处理完之后,4批文本数据得到了一个矩阵[4, 8, 16],也就是每批数据都训练出了一个结果,在训练阶段,这个结果的作用是跟目标标签计算损失值,然后通过反向传播更新各个权重向量;在推理阶段就是输出每个字的向量表,目的是拿着这个向量表计算一个概率值,最大概率值就是输出结果了。

【训练数据】

【线性变换数据】

2 线性变换原理

2.1线性变换的作用

nn.Linear 层的作用就是:

  1. 通过乘法(权重)来调整输入数据。
  2. 通过加法(偏置)来添加一个基础值。

这样,输入数据经过简单的数学运算后,变成一个更有意义的输出数据。这个过程在神经网络中反复进行,最终帮助我们做出准确的预测。

2.2 线性变换实现

self.model_out_linear_layer = nn.Linear(model_dimension, max_token_value+1)
linear_predictions = self.model_out_linear_layer(transformer_train_data)

这段代码定义了一个线性层,并使用这个线性层对一些训练数据进行前向传播。具体来说,nn.Linear是PyTorch中的一个模块(类),用于创建一个线性变换的层,通常在神经网络中作为全连接层来使用。

  1. self.model_out_linear_layer = nn.Linear(model_dimension, max_token_value+1):

    • 这一行创建了一个nn.Linear实例,并将其赋值给self.model_out_linear_layer
    • nn.Linear接受两个参数:输入特征的数量(在这里是model_dimension)和输出特征的数量(在这里是max_token_value + 1)。
    • model_dimension通常是来自模型前一层的输出维度,例如,它可能是Transformer编码器或解码器的最后一层的隐藏状态维度。
    • max_token_value + 1表示输出层的大小,可能对应于词汇表的大小或者你希望预测的最大值加一(因为索引是从0开始的)。
  2. linear_predictions = self.model_out_linear_layer(transformer_train_data):

    • 这一行使用之前定义的线性层对transformer_train_data进行前向传播。
    • transformer_train_data应该是具有与model_dimension相同特征数量的数据,这通常是Transformer模型的输出。
    • self.model_out_linear_layer将这些输入映射到一个新的空间,其维度为max_token_value + 1
    • linear_predictions是线性层的输出,可以被看作是对下一个词或其他目标的未归一化的概率(即原始分数或logits)。

通常情况下,在训练阶段之后,这些线性预测会被传递给一个损失函数(比如交叉熵损失函数)以计算模型的误差,并且在推理阶段,可能会应用softmax函数将logits转换成概率分布,以便进行采样或选择最有可能的输出。

2.3 线性变换原理

2.3.1 线性变换计算公式

矩阵乘法加上偏置项。有一个输入向量 x,经过一个线性层 L 后,得到的输出向量 y 可以表示为:

y=Wx+b

其中:

  • W是一个权重矩阵,它的尺寸是输出特征数乘以输入特征数。
  • x是输入向量,它的尺寸是输入特征数。
  • b是一个偏置向量,它的尺寸是输出特征数。
  • y是输出向量,它的尺寸是输出特征数。

权重矩阵 W会在训练过程中不断更新。权重矩阵 W和偏置向量 b都是神经网络中的可学习参数。它们会在训练过程中通过优化算法(如梯度下降或其变种)进行更新,以最小化损失函数。

以下是权重矩阵如何更新的一个简要流程:

  1. 初始化

    • 在训练开始时,权重矩阵 W 和偏置向量 b通常会被随机初始化(如使用正态分布或均匀分布)。
  2. 前向传播

    • 在每个训练步骤中,输入数据通过神经网络进行前向传播,生成预测输出。
  3. 计算损失

    • 使用损失函数(如均方误差、交叉熵等)来衡量预测输出与真实标签之间的差距。
  4. 反向传播

    • 通过反向传播算法计算损失函数相对于每个权重的梯度。这意味着计算每个权重对损失的影响程度。
    • 反向传播本质上是应用链式法则来计算梯度,从输出层一直回传到输入层。
  5. 权重更新

    • 根据计算出的梯度和选择的优化算法(如随机梯度下降SGD、Adam等),这里使用的是Adam,更新权重矩阵 W 和偏置向量 b

2.3.2 具体步骤:

  1. 输入数据:假设你有一批数据,每条数据是一个长度为 input_features 的向量。
  2. 权重矩阵nn.Linear 模块内部维护一个权重矩阵 W,尺寸为 [output_features, input_features]
  3. 偏置向量:同样,nn.Linear 模块内部还有一个偏置向量 b,尺寸为 [output_features]
  4. 矩阵乘法:每个输入向量 x与权重矩阵 W进行矩阵乘法。
  5. 加上偏置:将上述结果与偏置向量 b相加,得到最终的输出向量 y。

2.3.3 示例

假设你有以下参数:

  • 输入特征数 input_features = 10
  • 输出特征数 output_features = 5

那么,nn.Linear 层将接收一个形状为 [batch_size, 10] 的输入张量,并输出一个形状为 [batch_size, 5] 的张量。

2.3.4 代码示例

1import torch
2import torch.nn as nn
3
4# 假设输入特征数为10,输出特征数为5
5model_dimension = 10
6max_token_value = 4  # 词汇表大小为5
7
8# 创建线性层
9linear_layer = nn.Linear(model_dimension, max_token_value + 1)
10
11# 假设有一个批次的训练数据,batch_size为3
12transformer_train_data = torch.randn(3, model_dimension)
13
14# 前向传播
15output = linear_layer(transformer_train_data)
16
17print(output.shape)  # 应该输出 (3, 5)

在这个例子中:

  • transformer_train_data 是一个形状为 [3, 10] 的张量,表示有3个样本,每个样本有10个特征。
  • 经过 nn.Linear 层后,输出的张量形状为 [3, 5],表示每个样本现在有5个特征(或说5个输出值)。

这个输出可以被看作是对每个样本的一组预测值或原始分数(logits),通常用于后续的处理,如应用激活函数(如softmax)来获得概率分布。

2.4 线性变换通俗原理

假设你要调制饮料

你正在调制一杯饮料,你需要根据不同的配料来调整饮料的味道。你有两样主要的配料:糖和柠檬汁。

输入

  • 你有一个杯子,里面已经有了一定量的水。
  • 你准备往里面加入糖和柠檬汁。

线性变换

你想要通过加糖和柠檬汁来调整饮料的味道。具体来说:

  1. 糖的比例:每加一勺糖,会让甜度增加 2 分。
  2. 柠檬汁的比例:每加一勺柠檬汁,会让酸度增加 1 分。

此外,你还希望饮料本身有一定的基础甜度和酸度。

具体步骤

  1. 糖的比例(权重):假设你加了 3 勺糖。
  2. 柠檬汁的比例(权重):假设你加了 2 勺柠檬汁。
  3. 基础甜度和酸度(偏置):假设饮料本身就有 1 分甜度和 1 分酸度。

计算

  • 甜度 = (糖的量 × 糖的比例)+ 基础甜度
  • 酸度 = (柠檬汁的量 × 柠檬汁的比例)+ 基础酸度

具体来说:

  • 甜度 = (3 × 2) + 1 = 7 分
  • 酸度 = (2 × 1) + 1 = 3 分

结果

最终,你的饮料有 7 分甜度和 3 分酸度。

类比 nn.Linear

nn.Linear 层中:

  • 糖的比例和柠檬汁的比例 相当于权重矩阵 W。
  • 基础甜度和酸度 相当于偏置向量 b。
  • 加糖和柠檬汁 相当于输入向量 x。
  • 最终的甜度和酸度 相当于输出向量 y。

数学表达

用数学公式表示就是:

y=Wx+b

在这个例子中:

通过简单的乘法和加法,你得到了最终的甜度和酸度。

3 什么是前向传播

想象一下你在玩一个很长的流水线游戏。在这个游戏中,你有一个球,你需要通过一系列的障碍物来让这个球到达终点。这些障碍物就像是游戏中的不同关卡,而你的目标就是通过每一关,最后让球顺利到达目的地。

在神经网络中,数据就像是那个球,而每一层神经网络就像游戏中的一个关卡。前向传播就是让数据从网络的开始一直传递到结束的过程,就像让球从游戏的第一个关卡一直滚到最后一个关卡一样。

具体来说:

  • 输入层:这是起点,你把球(即数据)放在这里。
  • 隐藏层:这些是中间的关卡,每一个关卡都会改变球的状态。比如,球的颜色可能会变,或者球的形状可能会变。在神经网络中,每一层都会对数据做一些数学运算,改变数据的样子,让它变得更符合我们需要的形式。
  • 输出层:这是终点,球经过所有关卡后到达的地方。在这里,球已经变成了我们想要的样子,比如它可能代表了一个预测结果。

所以,前向传播就是让数据通过神经网络的所有层,从输入层开始,一层接一层地传递,直到输出层,得到最终的结果。这个过程不需要人为干预,数据会自动按照每层设定好的规则流动。

在训练过程中,我们会比较最终的结果与实际需要的结果之间的差异,然后调整整个游戏(神经网络)的规则(权重),使得下一次前向传播时,球能更接近正确的目标位置。这就是为什么我们要进行前向传播的主要原因——为了得到预测结果,并且在训练过程中不断改进这些结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2112662.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Netty系列-3 ChannelFuture和ChannelPromise介绍

背景 Netty源码中大量使用了Future和Promise,学习ChannelFuture和ChannelFuture有助于理解Netty的设计思路。 本文的重点内容在于梳理清楚这些类的关系以及结合源码实现介绍这些类的作用,其中核心逻辑在于DefaultPromise和DefaultChannelPromise&#x…

GBase8sV8.8安装指南

目录 一、下载 Gbase 安装包二、安装预置条件1.确保安装包和平台适配2.安装依赖包:jdk(1.6版本以上)、unzip、libaio、libgcc、libstdc、ncurses、pam,如果缺失请提前安装 三、上传包并解压四、安装五、登录并创建数据库六、启动停止数据库七、常见问题八…

虚拟机ubuntu与主机共享文件夹

现在主机(windows)上新建一个共享文件夹 打开虚拟机 按下面操作打开共享文件夹 进入虚拟机的系统 cd /mnt/hgfs 如果报错 可以按下面的解决 挂载一下 sudo mount -t fuse.vmhgfs-fuse .host:/ /mnt/hgfs -o allow_other 如果显示不存在这个文…

session机制

场景:当众多用户访问网站,发出HTTP请求,那么网站是如何判断哪个HTTP请求对应的是哪个用户 ? 作用:用于服务端区分用户。 当用户使用客户端登录时,服务端会进行验证,验证通过后会为这次登录创建…

剖析Cookie的工作原理及其安全风险

Cookie的工作原理主要涉及到HTTP协议中的状态管理。HTTP协议本身是无状态的,这意味着每次请求都是独立的,服务器不会保留之前的请求信息。为了在无状态的HTTP协议上实现有状态的会话,引入了Cookie机制。 1. Cookie定义 Cookie,也…

EMC测试

传导干扰测试: 现场实录CE传导骚扰电压测试,硬件环境: R&S EPL1000 EMI测量接收机(支持时域测试) R&S ENV216人工电源网络 R&S ELEKTRA 测试软件 黑色底板,不写丝印,0402封装平行排…

Tomcat服务详解

一、部署Tomcat服务器 JDK安装官方网址:https://www.oracle.com/cn/java Tomcat安装官方网址:Apache Tomcat - Welcome! 安装JDK 1.获取安装包 wget https://download.oracle.com/otn/java/jdk/8u411-b09/43d62d619be4e416215729597d70b8ac/jdk-8u41…

【工程测试技术】第13章 流体参量测量

目录 第13章 流体参量测量 13.1压力的测量 13.1.1 弹性式压力敏感元件 1. 波登管 2. 膜片和膜盒 3. 波纹管 13.1.2 常用压力传感器 1. 应变式压力传感器 2. 压阻式压力传感器 3. 压电式压力传感器 4. 电容式压力传感器 5. 谐振式压力传感器 6. 位移式压力传感器 (1)…

整型数组按个位值排序

题目描述 给定一个非空数组(列表),其元素数据类型为整型,请按照数组元素十进制最低位从小到大进行排序,十进制最低位相同的元司 相对位置保持不变。 当数组元素为负值时,十进制最低位等同于去除符号位后对应十进制值最低位。 输…

吐血整理 ChatGPT 3.5/4.0 新手使用手册~ 【2024.09.04 更新】

以前我也是通过官网使用,但是经常被封号,就非常不方便,后来有朋友推荐国内工具,用了一阵之后,发现:稳定方便,用着也挺好的。 最新的 GPT-4o、4o mini,可搭配使用~ 1、 最新模型科普&…

VisualStudio环境搭建C++

Visual Studio环境搭建 说明 C程序编写中,经常需要链接头文件(.h/.hpp)和源文件(.c/.cpp)。这样的好处是:控制主文件的篇幅,让代码架构更加清晰。一般来说头文件里放的是类的申明,函数的申明,全局变量的定义等等。源…

Java面试题·解释题·框架部分

系列文章目录 Java面试题解释题总体概括 Java面试题解释题JavaSE部分 Java面试题解释题框架部分 文章目录 系列文章目录前言一、MyBatis1. 请你介绍MyBatis框架2. MyBatis框架的核心思想是什么?3. MyBatis的核心配置文件中常用的子标签有哪些?4. mapper…

饲料加工机器设备有哪些组成部分

在快速发展的畜牧业中,饲料加工作为支撑养殖业的重要环节,其效率与品质直接影响着养殖业的成本效益与动物健康。随着科技的进步,饲料加工机器设备也在不断升级,为养殖行业带来了变革。一、智能化粉碎机:细度可调&#…

Unity Adressables 使用说明(五)在运行时使用 Addressables(Use Addressables at Runtime)

一旦你将 Addressable assets 组织到 groups 并构建到 AssetBundles 中,就需要在运行时加载、实例化和释放它们。 Addressables 使用引用计数系统来确保 assets 只在需要时保留在内存中。 Addressables 初始化 Addressables 系统在运行时第一次加载 Addressable …

SimD:基于相似度距离的小目标检测标签分配

摘要 https://arxiv.org/pdf/2407.02394 由于物体尺寸有限且信息不足,小物体检测正成为计算机视觉领域最具挑战性的任务之一。标签分配策略是影响物体检测精度的关键因素。尽管已经存在一些针对小物体的有效标签分配策略,但大多数策略都集中在降低对边界…

怎么利用XML发送物流快递通知短信

现如今短信平台越来越普遍了,而短信通知也分很多种,例如服务通知、订单通知、交易短信通知、会议通知等。而短信平台在物流行业通知这一块作用也很大。在家时:我们平时快递到了,如果电话联系不到本人,就会放到代收点,然…

正负极层数更新器

文件名:dcs_tkinter.py import tkinter as tk from tkinter import messagebox import redis# 连接Redis r redis.Redis(hostlocalhost, port6379, db0)def update_redis_and_display():try:# 从输入框获取值positive_layers int(entry_positive.get())negative_…

2024国赛数学建模C题论文:基于优化模型的农作物的种植策略

大家可以查看一下35页,包含结构完整,数据完整的C题论文,完整论文见文末名片 添加图片注释,不超过 140 字(可选) 添加图片注释,不超过 140 字(可选) 添加图片注释&#xf…

Nexus配置npm私服

1,配置npm-hub 2,配置proxy-npm 3,配置group-npm 4,配置local-npm 5,配置淘宝

Java语言程序设计基础篇_编程练习题**17.20 (二进制编辑器)

目录 题目:**17.20 (二进制编辑器) 代码示例 结果展示 题目:**17.20 (二进制编辑器) 编写一个GUI应用程序,让用户在文本域输入一个文件名,然后单击回车键,在文本区域显示它的二进制表示形式。用户也可以修改这个二…