nn.Embedding() 和 nn.Linear() 的区别

news2024/11/13 10:20:56

区别

  • nn.Embedding() 接收整数索引(如词汇表中的 Token ID),不要求固定输入维度,返回权重矩阵 W W W 中对应的行向量,类似查找表操作。
  • nn.Linear() 接收一个向量输入(要求固定的输入维度 input_dim),返回线性变换后的结果( W x + b Wx + b Wx+b),其中有偏置项。

运行代码

import torch
import torch.nn as nn

# 设置随机数种子,确保结果可重复
torch.manual_seed(0)

# 假设有 5 个词,每个词的嵌入维度为 3
num_embeddings = 5
embedding_dim = 3

# 定义 nn.Embedding 层
embedding = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)

# 定义 nn.Linear 层,输入维度为词汇表大小,输出维度为嵌入维度
linear = nn.Linear(in_features=num_embeddings, out_features=embedding_dim, bias=False)

# 手动将 nn.Linear 的权重设置为 nn.Embedding 的权重的转置
with torch.no_grad():
    linear.weight.copy_(embedding.weight.transpose(0, 1))

# 输入的词索引
indices = torch.tensor([0, 1, 2, 2, 4])

# 打印输入的词索引
print("输入的词索引 (indices, token IDs):")
print(indices)

# 使用 nn.Embedding 获取嵌入向量
embedding_output = embedding(indices)

# 打印嵌入层(Embedding)的权重矩阵
print("\nEmbedding 层的权重矩阵 (embedding.weight):")
print(embedding.weight)

# 打印 Embedding 的输出
print("\n使用 nn.Embedding 获取的嵌入向量 (embedding_output):")
print(embedding_output)

# 将词索引转换为 one-hot 向量
one_hot_input = nn.functional.one_hot(indices, num_classes=num_embeddings).float()

# 打印 one-hot 输入
print("\n转换后的 one-hot 向量 (one_hot_input):")
print(one_hot_input)

# 使用 nn.Linear 获取嵌入向量
linear_output = linear(one_hot_input)

# 打印线性层(Linear)的权重矩阵
print("\nLinear 层的权重矩阵 (linear.weight):")
print(linear.weight)

# 打印 Linear 的输出
print("\n使用 nn.Linear 获取的嵌入向量 (linear_output):")
print(linear_output)

# 将 indices 直接传入 nn.Linear
try:
    linear_indices_output = linear(indices.float())
    print("\n直接将 indices 传入 nn.Linear 的输出 (linear_indices_output):")
    print(linear_indices_output)
except Exception as e:
    print("\n直接将 indices 传入 nn.Linear 时发生错误:")
    print(e)

输出

输入的词索引 (indices, token IDs):
tensor([0, 1, 2, 2, 4])

Embedding 层的权重矩阵 (embedding.weight):
Parameter containing:
tensor([[ 1.5410, -0.2934, -2.1788],
        [ 0.5684, -1.0845, -1.3986],
        [ 0.4033,  0.8380, -0.7193],
        [-0.4033, -0.5966,  0.1820],
        [-0.8567,  1.1006, -1.0712]], requires_grad=True)

使用 nn.Embedding 获取的嵌入向量 (embedding_output):
tensor([[ 1.5410, -0.2934, -2.1788],
        [ 0.5684, -1.0845, -1.3986],
        [ 0.4033,  0.8380, -0.7193],
        [ 0.4033,  0.8380, -0.7193],
        [-0.8567,  1.1006, -1.0712]], grad_fn=<EmbeddingBackward0>)

转换后的 one-hot 向量 (one_hot_input):
tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 1.]])

Linear 层的权重矩阵 (linear.weight):
Parameter containing:
tensor([[ 1.5410,  0.5684,  0.4033, -0.4033, -0.8567],
        [-0.2934, -1.0845,  0.8380, -0.5966,  1.1006],
        [-2.1788, -1.3986, -0.7193,  0.1820, -1.0712]], requires_grad=True)

使用 nn.Linear 获取的嵌入向量 (linear_output):
tensor([[ 1.5410, -0.2934, -2.1788],
        [ 0.5684, -1.0845, -1.3986],
        [ 0.4033,  0.8380, -0.7193],
        [ 0.4033,  0.8380, -0.7193],
        [-0.8567,  1.1006, -1.0712]], grad_fn=<MmBackward0>)

直接将 indices 传入 nn.Linear 的输出 (linear_indices_output):
tensor([-2.8583,  3.8007, -6.7578], grad_fn=<SqueezeBackward4>)

尝试

  • indices 改为 torch.tensor([0, 2, 4]),观察输出。
  • indices 改为 torch.tensor([0, 1, 2, 2, 4, 0]),观察输出。
  • indices 改为 torch.tensor([0, 1, 2, 2, 5]),观察输出。

代码差异

权重矩阵的形状

  • Embedding 的权重矩阵形状为 (num_embeddings, embedding_dim),即 (输入维度, 输出维度),没有偏置项。
  • Linear 的权重矩阵形状为 (output_dim, input_dim),即 (输出维度, 输入维度),且包含一个偏置向量。

初始化示例:

# Embedding
def __init__(self, num_embeddings, embedding_dim):
    self.weight = torch.nn.Parameter(torch.randn(num_embeddings, embedding_dim))

# Linear
def __init__(self, input_dim, output_dim):
    self.weight = nn.Parameter(torch.randn(output_dim, input_dim))
    self.bias = nn.Parameter(torch.randn(output_dim))

输入处理方式

  • Embedding 将离散的输入(如 Token ID)映射到连续的嵌入向量空间,更通俗一点就是查找表,把输入当成索引,返回权重矩阵对应的行
  • Linear 则对输入向量进行矩阵乘法,再加上偏置,实现线性变换。

示例代码:

# Embedding
def forward(self, input):
    return self.weight[input]  # 返回权重矩阵的对应行

# Linear
def forward(self, input):
    return torch.matmul(input, self.weight.T) + self.bias  # 线性变换

数学表达

Embedding

假设词汇表大小为 V V V,嵌入维度为 D D D,嵌入层表示为矩阵 E ∈ R V × D E \in \mathbb{R}^{V \times D} ERV×D。对于输入 token ID 序列 x 1 , x 2 , … , x n x_1, x_2, \dots, x_n x1,x2,,xn,嵌入层输出对应的嵌入向量 E x 1 , E x 2 , … , E x n E_{x_1}, E_{x_2}, \dots, E_{x_n} Ex1,Ex2,,Exn,其中每个 E x i ∈ R D E_{x_i} \in \mathbb{R}^{D} ExiRD。可以表示为:
E ( x i ) = E x i E(x_i) = E_{x_i} E(xi)=Exi

  • E E E:嵌入矩阵,从代码上看就是权重矩阵 W W W
  • x i x_i xi:是输入的 Token ID,就是索引。
  • E ( x i ) E(x_i) E(xi):嵌入向量,就是索引对应的行。

Linear

给定输入向量 x ∈ R n x \in \mathbb{R}^{n} xRn 和输出向量 y ∈ R m y \in \mathbb{R}^{m} yRm,线性层的变换表示为:
y = W x + b y = Wx + b y=Wx+b

  • W W W:权重矩阵,维度为 m × n m \times n m×n,决定输入到输出空间的映射。
  • x x x:输入向量,维度为 n n n
  • b b b:偏置项,维度为 m m m
  • y y y:输出向量,维度为 m m m

示例

如果输入向量 x \mathbf{x} x 有 3 个特征,输出向量 y \mathbf{y} y 有 2 个特征,则权重矩阵 W \mathbf{W} W 的形状为 2 × 3 2 \times 3 2×3。假设:

W = [ 1 2 3 4 5 6 ] , x = [ 1 2 3 ] , b = [ 0 1 ] W = \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix}, \quad x = \begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix}, \quad b = \begin{bmatrix} 0 \\ 1 \end{bmatrix} W=[142536],x= 123 ,b=[01]

线性变换计算为:

y = W ⋅ x + b = [ 1 2 3 4 5 6 ] [ 1 2 3 ] + [ 0 1 ] = [ 14 32 ] + [ 0 1 ] = [ 14 33 ] y = W \cdot x + b = \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} \begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix} + \begin{bmatrix} 0 \\ 1 \end{bmatrix} = \begin{bmatrix} 14 \\ 32 \end{bmatrix} + \begin{bmatrix} 0 \\ 1 \end{bmatrix} = \begin{bmatrix} 14 \\ 33 \end{bmatrix} y=Wx+b=[142536] 123 +[01]=[1432]+[01]=[1433]

矩阵运算过程:
[ 1 2 3 4 5 6 ] [ 1 2 3 ] = [ ( 1 × 1 ) + ( 2 × 2 ) + ( 3 × 3 ) ( 4 × 1 ) + ( 5 × 2 ) + ( 6 × 3 ) ] = [ 14 32 ] \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} \begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix} = \begin{bmatrix} (1 \times 1) + (2 \times 2) + (3 \times 3) \\ (4 \times 1) + (5 \times 2) + (6 \times 3) \end{bmatrix} = \begin{bmatrix} 14 \\ 32 \end{bmatrix} [142536] 123 =[(1×1)+(2×2)+(3×3)(4×1)+(5×2)+(6×3)]=[1432]
拓展阅读

深入理解全连接层:从线性代数到 PyTorch 中的 nn.Linear 和 nn.Parameter

PyTorch nn.Embedding() 嵌入层详解和要点提醒

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2237419.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ODOO学习笔记(4):Odoo与SAP的主要区别是什么?

Odoo 和 SAP 都是知名的企业资源规划&#xff08;ERP&#xff09;软件&#xff0c;它们之间存在以下一些主要区别&#xff1a; Odoo与SAP的区别 一、功能特点 功能广度 Odoo&#xff1a;提供了一整套全面的业务应用程序&#xff0c;涵盖了销售、采购、库存管理、生产、会计、…

python之正则表达式总结

正则表达式 对于正则表达式的学习&#xff0c;我整理了网上的一些资料&#xff0c;希望可以帮助到各位&#xff01;&#xff01;&#xff01; 我们可以使用正则表达式来定义字符串的匹配模式&#xff0c;即如何检查一个字符串是否有跟某种模式匹配的部分或者从一个字符串中将与…

【日志框架整合】Slf4j、Log4j、Log4j2、Logback配置模板

文章目录 一、日志框架介绍1、浅谈与slfj4、log4j、logback的关系2、性能方面3、Slf4j使用方法 二、log4j配置三、log4j2配置1、SpringBoot整合Log4j22、非SpringBoot项目引入的依赖3、log4j2-spring.xml文件&#xff08;Spring项目&#xff09;或log4j2.xml&#xff08;非Spri…

StarUML建模工具安装学习与汉化最新零基础详细教程【一键式下载】(适用于Windows、MacOS系统、Linux系统)

StarUML破解安装下载教程 前言&#xff1a; StarUML破解与汉化安装下载教程&#xff0c;仅供学习研究和交流使用&#xff0c;禁止作为商业用途或其他非法用途&#xff01; 仓库作者&#xff1a;X1a0He&#xff0c;经仓库作者授权使用。 目录 StarUML破解安装下载教程1. 下载…

【网络安全】2.3 安全的网络设计_2.防御深度原则

文章目录 一、网络架构二、网络设备三、网络策略四、处理网络安全事件五、实例学习&#xff1a;安全的网络设计结论 网络设计是网络安全的基础&#xff0c;一个好的网络设计可以有效的防止攻击者的入侵。在本篇文章中&#xff0c;我们将详细介绍如何设计一个安全的网络&#…

IoTDB 与 HBase 对比详解:架构、功能与性能

五大方向&#xff0c;洞悉 IoTDB 与 HBase 的详尽对比&#xff01; 在物联网&#xff08;IoT&#xff09;领域&#xff0c;数据的采集、存储和分析是确保系统高效运行和决策准确的重要环节。随着物联网设备数量的增加和数据量的爆炸式增长&#xff0c;开发者和决策者们需要选择…

如何找到系统中bert-base-uncased默认安装位置

问题&#xff1a; 服务器中无法连接huggingface&#xff0c;故需要自己将模型文件上传 ubuntu 可以按照这个链接下载 Bert下载和使用&#xff08;以bert-base-uncased为例&#xff09; - 会自愈的哈士奇 - 博客园 里面提供了giehub里面的链接 GitHub - google-research/be…

Qt 学习第十六天:文件和事件

一、创建widget对象&#xff08;文件&#xff09; 二、设计ui界面 放一个label标签上去&#xff0c;设置成box就可以显示边框了 三、新建Mylabel类 四、提升ui界面的label标签为Mylabel 五、修改mylabel.h&#xff0c;mylabel.cpp #ifndef MYLABEL_H #define MYLABEL_H#incl…

华为ensp配置bgp(避坑版)

文章目录 前言一、BGP是什么&#xff1f;二、拓扑三、基础配置四、测试五、拓展总结 前言 BGP&#xff08;Border Gateway Protocol&#xff0c;边界网关协议&#xff09;是一种在互联网中使用的路径矢量协议。它主要用于在不同的自治系统&#xff08;AS&#xff09;之间交换路…

QT最新版6.8在线社区版安装教程

访问QT的官网&#xff1a; Qt | Tools for Each Stage of Software Development Lifecycle 点击 Download Try&#xff1a; 点击社区版最新在线安装&#xff1a; 往下翻网页&#xff0c; 点击下载&#xff1a; 开始安装&#xff1a; 使用--mirror进行启动安装程序&#xff1…

鸿蒙多线程开发——Worker多线程

1、概 述 1.1、基本介绍 Worker主要作用是为应用程序提供一个多线程的运行环境&#xff0c;可满足应用程序在执行过程中与主线程分离&#xff0c;在后台线程中运行一个脚本进行耗时操作&#xff0c;极大避免类似于计算密集型或高延迟的任务阻塞主线程的运行。 创建Worker的线…

海量数据迁移:Elasticsearch到OpenSearch的无缝迁移策略与实践

文章目录 一&#xff0e;迁移背景二&#xff0e;迁移分析三&#xff0e;方案制定3.1 使用工具迁移3.2 脚本迁移 四&#xff0e;方案建议 一&#xff0e;迁移背景 目前有两个es集群&#xff0c;版本为5.2.2和7.16.0&#xff0c;总数据量为700T。迁移过程需要不停服务迁移&#…

在配置环境变量之后使用Maven报错 : mvn : 无法将“mvn”项识别为 cmdlet、函数、脚本文件或可运行程序的名称。

最近&#xff0c;我在 Windows 系统上安装和配置 Apache Maven 时遇到了一些问题&#xff0c;想在此记录下我的解决历程&#xff0c;希望对遇到类似问题的朋友有所帮助。 问题描述 我下载了 Maven 并按照常规步骤配置了相关的环境变量。然而&#xff0c;在 PowerShell 中输入…

大模型,智能家居的春秋战国之交

智能家居&#xff0c;大家都不陌生。尽管苹果、谷歌、亚马逊等AI科技巨头&#xff0c;以及传统家电厂商都在积极进入这一领域&#xff0c;但发展了十多年之后&#xff0c;智能家居依然长期呈现出一种技术上人工智障、市场上四分五裂的局面。 究其原因&#xff0c;是此前传统家电…

【设计模式】结构型模式(四):组合模式、享元模式

《设计模式之结构型模式》系列&#xff0c;共包含以下文章&#xff1a; 结构型模式&#xff08;一&#xff09;&#xff1a;适配器模式、装饰器模式结构型模式&#xff08;二&#xff09;&#xff1a;代理模式结构型模式&#xff08;三&#xff09;&#xff1a;桥接模式、外观…

众测遇到的一些案列漏洞

文章中涉及的敏感信息均已做打码处理,文章仅做经验分享用途,切勿当真,未授权的攻击属于非法行为!文章中敏感信息均已做多层打码处理。传播、利用本文章所提供的信息而造成的任何直接或者间接的后果及损失,均由使用者本人负责,作者不为此承担任何责任,一旦造成后果请自行…

算法求解(C#)-- 寻找包含目标字符串的最短子串算法

1. 引言 在字符串处理中&#xff0c;我们经常需要从一个较长的字符串中找到包含特定目标字符串的最短子串。这个问题在文本搜索、基因序列分析等领域有着广泛的应用。本文将介绍一种高效的算法来解决这个问题。 2. 问题描述 给定一个源字符串 source 和一个目标字符串 targe…

ThingsBoard规则链节点:RPC Call Reply节点详解

引言 1. RPC Call Reply 节点简介 2. 节点配置 2.1 基本配置示例 3. 使用场景 3.1 设备控制 3.2 状态查询 3.3 命令执行 4. 实际项目中的应用 4.1 项目背景 4.2 项目需求 4.3 实现步骤 5. 总结 引言 ThingsBoard 是一个开源的物联网平台&#xff0c;提供了设备管理…

动态规划(简单多状态 dp 问题 1.按摩师 2.打家劫舍 II 3. 删除并获得点数 4.粉刷房子 5.买卖股票的最佳时机(全系列))

面试题 17.16. 按摩师213. 打家劫舍 II740. 删除并获得点数LCR 091. 粉刷房子 &#xff08;原&#xff1a;剑指 Offer II 091. 粉刷房子&#xff09;309. 买卖股票的最佳时机含冷冻期714. 买卖股票的最佳时机含手续费123. 买卖股票的最佳时机 III188. 买卖股票的最佳时机 IV 1.…

【VBA实战】用Excel制作排序算法动画续

为什么会产生用excel来制作排序算法动画的念头&#xff0c;参见【VBA实战】用Excel制作排序算法动画一文。这篇文章贴出我所制作的所有排序算法动画效果和源码&#xff0c;供大家参考。 冒泡排序&#xff1a; 插入排序&#xff1a; 选择排序&#xff1a; 快速排序&#xff1a;…