【Andrej Karpathy 神经网络从Zero到Hero】--2.语言模型的两种实现方式 (Bigram 和 神经网络)

news2025/3/10 14:32:13

目录

  • 统计 Bigram 语言模型
    • 质量评价方法
  • 神经网络语言模型

【系列笔记】
【Andrej Karpathy 神经网络从Zero到Hero】–1. 自动微分autograd实践要点

本文主要参考 大神Andrej Karpathy 大模型讲座 | 构建makemore 系列之一:讲解语言建模的明确入门,演示

  1. 如何利用统计数值构建一个简单的 Bigram 语言模型
  2. 如何用一个神经网络来复现前面 Bigram 语言模型的结果,以此来展示神经网络相对于传统 n-gram 模型的拓展性。

统计 Bigram 语言模型

首先给定一批数据,每个数据是一个英文名字,例如:

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

Bigram语言模型的做法很简单,首先将数据中的英文名字都做成一个个bigram的数据

其中每个格子中是对应的二元组,eg: “rh” ,在所有数据中出现的次数。那么一个自然的想法是对于给定的字母,取其对应的行,将次数归一化转成概率值,然后根据概率分布抽取下一个可能的字母:

g = torch.Generator().manual_seed(2147483647)
P = N.float() # N 即为上述 counts 矩阵
P = P / P.sum(1, keepdims=True) # P是每行归一化后的概率值

for i in range(5):
  
  out = []
  ix = 0  ## start符和end符都用 id=0 表示,这里是start
  while True:
    p = P[ix] # 当前字符为 ix 时,预测下一个字符的概率分布,实质是一个多项分布(即可能抽到的值有多个,eg: 掷色子是六项分布)
    ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
    out.append(itos[ix])
    if ix == 0: ## 当运行到end符,停止生成
      break
  print(''.join(out))

输出类似于:

mor.
axx.
minaymoryles.
kondlaisah.
anchshizarie.

质量评价方法

我们还需要方法来评估语言模型的质量,一个直观的想法是:
P ( s 1 s 2 . . . s n ) = P ( s 1 ) P ( s 2 ∣ s 1 ) ⋯ P ( s n ∣ s n − 1 ) P(s_1s_2...s_n) = P(s_1)P(s_2|s_1)\cdots P(s_n|s_{n-1}) P(s1s2...sn)=P(s1)P(s2s1)P(snsn1)
但上述计算方式有一个问题,概率值都是小于1的,当序列的长度比较长时,上述数值会趋于0,计算时容易下溢。因此实践中往往使用 l o g ( P ) log(P) log(P)来代替,为了可以对比不同长度的序列的预测效果,再进一步使用 l o g ( P ) / n log(P)/n log(P)/n 表示一个序列平均的质量

上述统计 Bigram 模型在训练数据上的平均质量为:

log_likelihood = 0.0
n = 0

for w in words: # 所有word里的二元组概率叠加
  chs = ['.'] + list(w) + ['.']
  for ch1, ch2 in zip(chs, chs[1:]):
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    prob = P[ix1, ix2]
    logprob = torch.log(prob)
    log_likelihood += logprob
    n += 1 # 所有word里的二元组数量之和

nll = -log_likelihood
print(f'{nll/n}') ## 值为 2.4764,表示前面做的bigram模型,对现有训练数据的置信度
                  ## 这个值越低表示当前模型越认可训练数据的质量,而由于训练数据是我们认为“好”的数据,因此反过来就说明这个模型好

但这里有一个问题是,例如:

log_likelihood = 0.0
n = 0

#for w in words:
for w in ["andrejz"]:
  chs = ['.'] + list(w) + ['.']
  for ch1, ch2 in zip(chs, chs[1:]):
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    prob = P[ix1, ix2]
    logprob = torch.log(prob)
    log_likelihood += logprob
    n += 1
    print(f'{ch1}{ch2}: {prob:.4f} {logprob:.4f}')

print(f'{log_likelihood=}')
nll = -log_likelihood
print(f'{nll=}')
print(f'{nll/n}')

输出是

.a: 0.1377 -1.9829
an: 0.1605 -1.8296
nd: 0.0384 -3.2594
dr: 0.0771 -2.5620
re: 0.1336 -2.0127
ej: 0.0027 -5.9171
jz: 0.0000 -inf
z.: 0.0667 -2.7072
log_likelihood=tensor(-inf)
nll=tensor(inf)
inf

可以发现由于,jz 在计数矩阵 N 中为0,即数据中没有出现过,导致 log(loss) 变成了负无穷,这里为了避免这样的情况,需要做 平滑处理,即 P = N.float() 改成 P = (N+1).float(),这样上述代码输出变成:

.a: 0.1376 -1.9835
an: 0.1604 -1.8302
nd: 0.0384 -3.2594
dr: 0.0770 -2.5646
re: 0.1334 -2.0143
ej: 0.0027 -5.9004
jz: 0.0003 -7.9817
z.: 0.0664 -2.7122
log_likelihood=tensor(-28.2463)
nll=tensor(28.2463)
3.5307815074920654

避免了出现 inf 这种数据溢出问题。


神经网络语言模型

接下来尝试用神经网络的方式构建上述bigram语言模型:

# 构建训练数据
xs, ys = [], [] # 分别是前一个字符和要预测的下一个字符的id
for w in words[:5]:
  chs = ['.'] + list(w) + ['.']
  for ch1, ch2 in zip(chs, chs[1:]):
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    print(ch1, ch2)
    xs.append(ix1)
    ys.append(ix2)    
    
xs = torch.tensor(xs)
ys = torch.tensor(ys)
# 输出示例:. e
#          e m
#          m m
#          m a
#          a .
#       xs: tensor([ 0,  5, 13, 13,  1])
#       ys: tensor([ 5, 13, 13,  1,  0])

# 随机初始化一个 27*27 的参数矩阵
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27, 27), generator=g, requires_grad=True) # 基于正态分布随机初始化
# 前向传播
import torch.nn.functional as F
xenc = F.one_hot(xs, num_classes=27).float() # 将输入数据xs做成one-hot embedding
logits = xenc @ W # 用于模拟统计模型中的统计数值矩阵,由于 W 是基于正态分布采样,logits 并非直接是计数值,可以认为是 log(counts)
## tensor([[-0.5288, -0.5967, -0.7431,  ...,  0.5990, -1.5881,  1.1731],
##        [-0.3065, -0.1569, -0.8672,  ...,  0.0821,  0.0672, -0.3943],
##        [ 0.4942,  1.5439, -0.2300,  ..., -2.0636, -0.8923, -1.6962],
##        ...,
##        [-0.1936, -0.2342,  0.5450,  ..., -0.0578,  0.7762,  1.9665],
##        [-0.4965, -1.5579,  2.6435,  ...,  0.9274,  0.3591, -0.3198],
##        [ 1.5803, -1.1465, -1.2724,  ...,  0.8207,  0.0131,  0.4530]])
counts = logits.exp() # 将 log(counts) 还原成可以看作是 counts 的矩阵
## tensor([[ 0.5893,  0.5507,  0.4756,  ...,  1.8203,  0.2043,  3.2321],
##        [ 0.7360,  0.8548,  0.4201,  ...,  1.0856,  1.0695,  0.6741],
##        [ 1.6391,  4.6828,  0.7945,  ...,  0.1270,  0.4097,  0.1834],
##        ...,
##        [ 0.8240,  0.7912,  1.7245,  ...,  0.9438,  2.1732,  7.1459],
##        [ 0.6086,  0.2106, 14.0621,  ...,  2.5279,  1.4320,  0.7263],
##        [ 4.8566,  0.3177,  0.2802,  ...,  2.2722,  1.0132,  1.5730]])
probs = counts / counts.sum(1, keepdims=True) # 用于模拟统计模型中的概率矩阵,这其实即是 softmax 的实现
loss = -probs[torch.arange(5), ys].log().mean() # loss = log(P)/n, 这其实即是 cross-entropy 的实现

接下来可以通过loss.backward()来更新参数 W:

for k in range(100):
  
  # forward pass
  xenc = F.one_hot(xs, num_classes=27).float() 
  logits = xenc @ W # predict log-counts
  counts = logits.exp()
  probs = counts / counts.sum(1, keepdims=True) 
  loss = -probs[torch.arange(num), ys].log().mean() + 0.01*(W**2).mean() ## 这里加上了L2正则,防止过拟合
  print(loss.item())
  
  # backward pass
  W.grad = None # 每次反向传播前置为None
  loss.backward()
  
  # update
  W.data += -50 * W.grad  

注意这里 logits = xenc @ W 由于 xenc 是 one-hot 向量,因此这里 logits 相当于是抽出了 W 中的某一行,而结合 bigram 模型中,loss 实际上是在计算实际的 log(P[x_i, y_i]),那么可以认为这里 W 其实是在拟合 bigram 中的计数矩阵 N(不过实际是 logW 在拟合 N)

另外上述神经网络的 loss 最终也是达到差不多 2.47 的最低 loss。这是合理的,因为从上面的分析可知,这个神经网络是完全在拟合 bigram 计数矩阵的,没有使用更复杂的特征提取方法,因此效果最终也会差不多。

这里 loss 中还加了一个 L2 正则,主要目的是压缩 W,使得它向全 0 靠近,这里的效果非常类似于 bigram 中的平滑手段,想象给一个极大的平滑:P = (N+10000).float()`,那么 P 会趋于一个均匀分布,而 W 全为 0 会导致 counts = logits.exp() 全为 1,即也在拟合一个均匀分布。这里前面的参数 0.01 即是用来调整平滑强度的,如果这个给的太大,那么平滑太大了,就会学成一个均匀分布(当然实际不会希望这样,所以不会给很大)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2312739.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Android MVC、MVP、MVVM三种架构的介绍和使用。

写在前面:现在随便出去面试Android APP相关的工作,面试官基本上都会提问APP架构相关的问题,用Java、kotlin写APP的话,其实就三种架构MVC、MVP、MVVM,MVC和MVP高度相似,区别不大,MVVM则不同&…

python使用django搭建图书管理系统

大家好,你们喜欢的梦幻编织者回来了 随着计算机网络和信息技术的不断发展,人类信息交流的方式从根本上发生了改变,计算机技术、信息化技术在各个领域都得到了广泛的应用。图书馆的规模和数量都在迅速增长,馆内藏书也越来越多,管理…

JavaScript系列06-深入理解 JavaScript 事件系统:从原生事件到 React 合成事件

JavaScript 事件系统是构建交互式 Web 应用的核心。本文从原生 DOM 事件到 React 的合成事件,内容涵盖: JavaScript 事件基础:事件类型、事件注册、事件对象事件传播机制:捕获、目标和冒泡阶段高级事件技术:事件委托、…

大话机器学习三大门派:监督、无监督与强化学习

以武侠江湖为隐喻,系统阐述了机器学习的三大范式:​监督学习(少林派)​凭借标注数据精准建模,擅长图像分类等预测任务;无监督学习(逍遥派)​通过数据自组织发现隐藏规律,…

win11编译llama_cpp_python cuda128 RTX30/40/50版本

Geforce 50xx系显卡最低支持cuda128,llama_cpp_python官方源只有cpu版本,没有cuda版本,所以自己基于0.3.5版本源码编译一个RTX 30xx/40xx/50xx版本。 1. 前置条件 1. 访问https://developer.download.nvidia.cn/compute/cuda/12.8.0/local_…

FY-3D MWRI亮温绘制

1、FY-3D MWRI介绍 风云三号气象卫星(FY-3)是我国自行研制的第二代极轨气象卫星,其有效载荷覆 盖了紫外、可见光、红外、微波等频段,其目标是实现全球全天候、多光谱、三维定量 探测,为中期数值天气预报提供卫星观测数…

Codeforces1929F Sasha and the Wedding Binary Search Tree

目录 tags中文题面输入格式输出格式样例输入样例输出说明 思路代码 tags 组合数 二叉搜索树 中文题面 定义一棵二叉搜索树满足,点有点权,左儿子的点权 ≤ \leq ≤ 根节点的点权,右儿子的点权 ≥ \geq ≥ 根节点的点权。 现在给定一棵 …

HBuilder X 使用 TortoiseSVN 设置快捷键方法

HBuilder X 使用 TortoiseSVN 设置快捷键方法 单文件:(上锁,解锁,提交,更新) 安装好 TortoiseSVN ,或者 按图操作: 1,工具栏中 【自定义快捷键】 2,点击 默认的快捷键设置&…

Java jar包后台运行方式详解

目录 一、打包成 jar 文件二、后台运行 jar 文件三、示例四、总结在 Java 开发中,我们经常需要将应用程序打包成可执行的 jar 文件,并在后台运行。这种方式对于部署长时间运行的任务或需要持续监听事件的应用程序非常重要。本文将详细介绍如何实现 Java jar 包的后台运行,并…

Mysql5.7-yum安装和更改mysql数据存放路径-2020年记录

记录下官网里用yum rpm源安装mysql, 1 官网下载rpm https://dev.mysql.com/downloads/repo/yum/ https://dev.mysql.com/doc/refman/5.7/en/linux-installation-yum-repo.html(附官网操作手册) wget https://repo.mysql.com//mysql80-community-release…

[项目]基于FreeRTOS的STM32四轴飞行器: 七.遥控器按键

基于FreeRTOS的STM32四轴飞行器: 七.遥控器 一.遥控器按键摇杆功能说明二.摇杆和按键的配置三.按键扫描 一.遥控器按键摇杆功能说明 两个手柄四个ADC。 左侧手柄: 前后推为飞控油门,左右推为控制飞机偏航角。 右侧手柄: 控制飞机飞行方向&a…

Android15使用FFmpeg解码并播放MP4视频完整示例

效果: 1.编译FFmpeg库: 下载FFmpeg-kit的源码并编译生成安装平台库 2.复制生成的FFmpeg库so文件与包含目录到自己的Android下 如果没有prebuiltLibs目录,创建一个,然后复制 包含目录只复制arm64-v8a下

安装树莓派3B+环境(嵌入式开发)

一、环境配置 1、下载树莓派镜像工具 点击进入下载连接 进入网站,点击下载即可。 2、配置wifi及ssh 将SD卡插入读卡器,再接入电脑,随后打开Raspberry Pi Imager下载工具, 选择Raspberry Pi 3 选择64位的操作系统 选择SD卡 选择…

p5.js:sound(音乐)可视化,动画显示音频高低变化

本文通过4个案例介绍了使用 p5.js 进行音乐可视化的实践,包括将音频振幅转化为图形、生成波形图。 承上一篇:vite:初学 p5.js demo 画圆圈 cd p5-demo copy .\node_modules\p5\lib\p5.min.js . copy .\node_modules\p5\lib\addons\p5.soun…

Linux下安装elasticsearch(Elasticsearch 7.17.23)

Elasticsearch 是一个分布式的搜索和分析引擎,能够以近乎实时的速度存储、搜索和分析大量数据。它被广泛应用于日志分析、全文搜索、应用程序监控等场景。 本文将带你一步步在 Linux 系统上安装 Elasticsearch 7.17.23 版本,并完成基本的配置&#xff0…

【The Rap of China】2018

中国新说唱第一季,2018 2018年4月13日,该节目通过官方微博宣布,其第二季将更名为《中国新说唱》。 《中国新说唱2018》由张震岳、MC Hotdog、潘玮柏、邓紫棋、WYF 担任明星制作人; 艾热获得冠军、那吾克热玉素甫江获得亚军、ICE…

通义万相2.1开源版本地化部署攻略,生成视频再填利器

2025 年 2 月 25 日晚上 11:00 通义万相 2.1 开源发布,前两周太忙没空搞它,这个周末,也来本地化部署一个,体验生成效果如何,总的来说,它在国内文生视频、图生视频的行列处于领先位置&#xff0c…

好玩的谷歌浏览器插件-自定义谷歌浏览器光标皮肤插件-Chrome 的自定义光标

周末没有啥事 看到了一个非常有意思的插件 就是 在使用谷歌浏览器的时候,可以把鼠标的默认样式换一个皮肤。就像下面的这种样子。 实际谷歌浏览器插件开发对于有前端编程基础的小伙伴 还是比较容易的,实际也是写 html css js 。 所以这个插件使用的技术…

svn删除所有隐藏.svn文件,文件夹脱离svn控制

新建一个文件,取名remove-svn-folders.reg,输入如下内容: Windows Registry Editor Version 5.00 [HKEY_LOCAL_MACHINE\SOFTWARE\Classes\Folder\shell\DeleteSVN] "Delete SVN Folders" [HKEY_LOCAL_MACHINE\SOFTWARE\Class…

六十天前端强化训练之第十二天之闭包深度解析

欢迎来到编程星辰海的博客讲解 目录 第一章:闭包的底层运行机制 1.1 词法环境(Lexical Environment)的构成JavaScript 引擎通过三个关键组件管理作用域: 1.2 作用域链的创建过程当函数被定义时: 1.3 闭包变量的生命…