神经网络:Zero2Hero 2

news2025/1/20 14:57:43

Zero → \to Hero : 2

接上篇,Zero → \to Hero : 1,进一步的扩展模型:

  1. 增加输入字符序列的长度,通过多个字符预测下一个字符的概率分布
  2. 增加模型的深度,通过多层的MLP来学习和预测字符的生成概率
  3. 增加嵌入层,把字符转换为稠密向量
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
from matplotlib.font_manager import FontProperties
font = FontProperties(fname='../chinese_pop.ttf', size=10)

加载数据集

数据是一个中文名数据集

  • 名字最小长度为 2:
  • 名字最大长度为 3:
words = open('../Chinese_Names_Corpus.txt', 'r').read().splitlines()
# 数据包含100多万个姓名,过滤出一个姓氏用来测试
names = [name for name in words if name[0] == '王' and len(name) == 3]
len(names)
52127
# 构建词汇表到索引,索引到词汇表的映射,词汇表大小为:1561(加上开始和结束填充字符):
chars = sorted(list(set(''.join(names))))
char2i = {s:i+1 for i,s in enumerate(chars)}
char2i['.'] = 0               # 填充字符
i2char = {i:s for s,i in char2i.items()}
len(chars)
1650
block_size = 2   # 用两个字符预测下一个字符
X, Y = [], []
for w in names[:1]:
    context = [0] * block_size
    for ch in w + '.':
        ix = char2i[ch]
        X.append(context)
        Y.append(ix)
        print(''.join(i2char[i] for i in context), '--->', i2char[ix])
        context = context[1:] + [ix] # crop and append
  
X = torch.tensor(X)
Y = torch.tensor(Y)
.. ---> 王
.王 ---> 阿
王阿 ---> 宝
阿宝 ---> .

构建训练数据

block_size = 2  

def build_dataset(names):  
    X, Y = [], []
    for w in names:
        context = [0] * block_size
        for ch in w + '.':
            ix = char2i[ch]
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix] # crop and append

    X = torch.tensor(X)
    Y = torch.tensor(Y)
    print(X.shape, Y.shape)
    return X, Y

划分数据集


import random
random.seed(42)
random.shuffle(names)
n1 = int(0.8*len(names))

Xtr, Ytr = build_dataset(names[:n1])
Xte, Yte = build_dataset(names[n1:])

torch.Size([166804, 2]) torch.Size([166804])
torch.Size([41704, 2]) torch.Size([41704])

构建MLP

模型结构:输入层 → \to 嵌入层 → \to 隐藏层 → \to 输出层。

  • 在Zero2Hero:1中,是直接对输入进行了one-hot编码,然后直接将编码后的稀疏向量传递给了输出层。

  • 在Zero2Hero:2中,增加嵌入层,通过查询嵌入层,把输入字符(token)序列转换低维的稠密向量,同时随着模型的训练,字符的表示也会发生变化。

  • 在Zero2Hero:2中,增加隐藏层,提取更高级的特征表示。

嵌入层:

  1. 随机初始化一张向量表
  2. 根据字符的索引查询对应的向量表示
  3. 向量表示向后传播,最后通过梯度下降,修正向量表
# 随机初始化和词汇表大小相同的向量表
len(char2i)
C = torch.randn((1651, 2))
1651
# 根据索引查询字符的向量表示
emb = C[X]
emb.shape
emb
tensor([[[ 0.7928, -0.2331],
         [ 0.7928, -0.2331]],

        [[ 0.7928, -0.2331],
         [-2.0187, -1.1116]],

        [[-2.0187, -1.1116],
         [ 0.9795,  1.4175]],

        [[ 0.9795,  1.4175],
         [ 0.2740,  1.5466]]])

初始化模型参数:

g = torch.Generator().manual_seed(2147483647)  
C = torch.randn((len(char2i), 2), generator=g)
W1 = torch.randn((4, 200), generator=g)
b1 = torch.randn(200, generator=g)
W2 = torch.randn((200, len(char2i)), generator=g)
b2 = torch.randn(len(char2i), generator=g)
parameters = [C, W1, b1, W2, b2]
print("参数统计:",sum(p.nelement() for p in parameters)) # 模型参数统计
参数统计: 336153
for p in parameters:
    p.requires_grad = True

训练模型:

lri = []
lossi = []
stepi = []
for i in range(20000):
    # minibatch construct
    ix = torch.randint(0, Xtr.shape[0], (32,))
    # forward pass
    emb = C[Xtr[ix]]                            # (32, 2, 2)
    h = torch.tanh(emb.view(-1, 4) @ W1 + b1)  # (32, 200)
    logits = h @ W2 + b2                        # (32, 1651)
    loss = F.cross_entropy(logits, Ytr[ix])
    #print(loss.item())
  
    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()
  
    # update
    
    lr = 0.1 if i < 10000 else 0.01
    for p in parameters:
        p.data += -lr * p.grad

    # track stats
    #lri.append(lre[i])
    stepi.append(i)
    lossi.append(loss.log10().item())
print(loss.item())
3.3062288761138916
plt.plot(stepi, lossi)

在这里插入图片描述

测试误差:

emb = C[Xte]  
h = torch.tanh(emb.view(-1, 4) @ W1 + b1) # (32, 100)
logits = h @ W2 + b2                       # (32, 1651)
loss = F.cross_entropy(logits, Yte)
loss
tensor(3.2770, grad_fn=<NllLossBackward0>)

可视化

可视化字符的嵌入:

# visualize dimensions 0 and 1 of the embedding matrix C for all characters
plt.figure(figsize=(12,6))
plt.scatter(C[:,0].data, C[:,1].data, s=200)
for i in range(C.shape[0]):
    plt.text(C[i,0].item(), C[i,1].item(), i2char[i], ha="center", va="center", color='white',fontproperties=font)
plt.grid('minor')

在这里插入图片描述

测试生成

g = torch.Generator().manual_seed(20230516 + 10)
for _ in range(10):
    out = []
    context = [0] * block_size           # initialize with all ...
    while True:
        emb = C[torch.tensor([context])] # (1,block_size,d)
        h = torch.tanh(emb.view(1, -1) @ W1 + b1)
        logits = h @ W2 + b2
        probs = F.softmax(logits, dim=1)
        ix = torch.multinomial(probs, num_samples=1, generator=g).item()
        context = context[1:] + [ix]
        out.append(ix)
        if ix == 0:
            break
    print(''.join(i2char[i] for i in out))
王胜兵.
王紫琴.
王坡健.
王家菲.
王青金.
王碧财.
王华士.
王海维.
王旭荣.
王玉树.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/532172.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

深度学习04-CNN经典模型

简介 卷积神经网络&#xff08;CNN&#xff09;是深度学习中非常重要的一种网络结构&#xff0c;它可以处理图像、文本、语音等各种类型的数据。以下是CNN的前4个经典模型 LeNet-5 LeNet-5是由Yann LeCun等人于1998年提出的&#xff0c;是第一个成功应用于手写数字识别的卷积…

【数据结构】线性表之链表

目录 前言一、链表的定义二、链表的分类1. 单向和双向2. 带头和不带头3. 循环和不循环4. 常用&#xff08;无头单向非循环链表和带头双向循环链表&#xff09; 三、无头单向非循环链表的接口及实现1. 单链表的接口2. 接口的实现 四、带头双向循环链表接口的及实现1. 双向链表的…

磺酸基-Cy5 羧酸Sulfo-Cy5 COOH分子式C32H37N2KO8S2

Sulfo CY5 COOH是一种有机化合物&#xff0c;属于荧光染料。它具有荧光、稳定、水溶性等特点&#xff0c;因此被应用于分析化学、生物技术、药物研发等领域。Sulfo CY5 COOH的分子式为C32H37N2KO8S2&#xff0c;分子量为680.87。它的荧光波长为670nm&#xff0c;可以通过荧光显…

如何在AD中添加自定义材料单模板

AD默认的材料单格式和常用的格式有点区别&#xff0c;为了减少在材料单格式编辑的工作&#xff0c;决定添加自定义模板到AD的模板中。 1.查找AD模板的安装位置 在AD菜单Reports中&#xff0c;找到“Bill of materials”菜单&#xff0c; 点击后&#xff0c;弹出的窗口中包含了…

Kubernets1.20部署Redis7.0集群6节点三主三从(完整版)-2023.5.13

目录 一、产品选型二、草图三、部署1、安装NFS服务1&#xff09;NFS Server端安装NFS2&#xff09;创建NFS 共享点3&#xff09;启动rpcbind、nfs服务4&#xff09;验证服务配置 2、创建持久卷PVC1&#xff09;创建ServiceAccount账号2&#xff09;创建provisioner3&#xff09…

vite入坑之路:react+vite动态导入报错@vite-ignore的解决方法

正常的动态组件导入方式 webpack搭建的项目&#xff0c;不管是react还是vue通常引入动态组件基本这么写&#xff1a; const url import(../pages/${locale}) // vite不支持or const url import(../pages/${locale}/index.jsx) // vite不支持这在vite架构中&#xff0c;一般…

Vue3+vite环境变量配置

在项目开发中&#xff0c;通常来说&#xff0c;不同的环境会有不同的请求api接口&#xff0c;这就需要修改配置&#xff0c;才能满足对应的环境。所以这里就使用了环境变量。环境变量就是在不同的环境中使用不同的变量值。 # 环境变量文件(.env) 在项目根目录&#xff08;和sr…

TCP协议和相关特性

1.TCP协议的报文结构 TCP的全称为&#xff1a;Transmission Control Protocol。 特点: 有连接可靠传输面向字节流全双工 下面是TCP的报文结构&#xff1a; 源端口和目的端口&#xff1a; 源端口表示数据从哪个端口传输出来&#xff0c;目的端口表示数据传输到哪个端口去。…

FPGA_学习_03_第一个FPGA程序流水灯

学习编程&#xff0c;最重要永远就是动手&#xff0c;本文将在开发板上实现FPGA的“Hello world”→流水灯。本文主要目的是熟悉在Vivado上从零到程序运行起来的基本开发流程。 1 硬件电路介绍 本人购买的开发板接在PL端的只有2个LED灯&#xff0c;刚好达到流水灯的最低要求。…

今年这情况,大家多一手准备吧......

大家好&#xff0c;最近有不少小伙伴在后台留言&#xff0c;又得准备面试了&#xff0c;不知道从何下手&#xff01; 不论是跳槽涨薪&#xff0c;还是学习提升&#xff01;先给自己定一个小目标&#xff0c;然后再朝着目标去努力就完事儿了&#xff01; 为了帮大家节约时间&a…

ASEMI代理MAX5048BAUT+T原装ADI车规级MAX5048BAUT+T

编辑&#xff1a;ll ASEMI代理MAX5048BAUTT原装ADI车规级MAX5048BAUTT 型号&#xff1a;MAX5048BAUTT 品牌&#xff1a;ADI /亚德诺 封装&#xff1a;SOT-23-6 批号&#xff1a;2023 安装类型&#xff1a;表面贴装型 引脚数量&#xff1a;6 工作温度:-40C~125C 类型&a…

npx下载构建nuxt3开发模板失败的解决方案

在搭建nuxt3项目开发的时候&#xff0c;安装nuxt3开发模板的时候&#xff0c;使用命令&#xff1a; npx nuxi init my-app 会出出现一下错误&#xff1a; This is related to npm not being able to find a file. 发生上述错误是因为您有一个未正确安装的依赖项。 以下是解决…

大央企的“中央厨房”,泰裤辣

本文来源&#xff1a;特大号 作者&#xff1a;特大妹 最近两年&#xff0c;大央企大国企在数字化转型中&#xff0c;特热衷成立“中央厨房”。 有的中央厨房&#xff0c;单独挂牌为“数科公司”&#xff0c;有的中央厨房&#xff0c;升级为集团数字化转型的一级部门。 把之前各…

“警”彩集结|北峰通信亮相11届警博会,多场景助力警务智能化

2023年5月11日-14日&#xff0c;第十一届中国国际警用装备博览会(警博会)在北京首钢会展中心隆重召开。“警博会”作为中国乃至亚太地区最具影响力、最权威的警用装备盛会&#xff0c;代表了中国警用装备行业的最高水平。北峰通信作为服务公共安全实战30余年的企业&#xff0c;…

软考A计划-真题-分类精讲汇总-第十二章(法律法规与标准化)

点击跳转专栏>Unity3D特效百例点击跳转专栏>案例项目实战源码点击跳转专栏>游戏脚本-辅助自动化点击跳转专栏>Android控件全解手册点击跳转专栏>Scratch编程案例 &#x1f449;关于作者 专注于Android/Unity和各种游戏开发技巧&#xff0c;以及各种资源分享&am…

Web渗透 不断更新

Web渗透 SQL注入一般注入步骤 文件上传漏洞过滤绕过空格绕过 针对Linux特定字符过滤绕过 针对Linux(例如&#xff1a;cat) 序列号unserialize SQL注入 一般注入步骤 注入点 --> 查询注入字段数 --> 查询注入回显位 --> 查询当前数据库信息 --> 查询数据库表 --&g…

MySQL基础(三十四)锁

1. 概述 在数据库中&#xff0c;除传统的计算资源&#xff08;如CPU、RAM、I/O等&#xff09;的争用以外&#xff0c;数据也是一种供许多用户共享的 资源。为保证数据的一致性&#xff0c;需要对 并发操作进行控制&#xff0c;因此产生了 锁 。同时 锁机制 也为实现MySQL 的各…

HIT数据结构lab2-树型结构的建立与遍历

title: 数据结构lab2-树型结构的建立与遍历 date: 2023-05-16 11:42:26 tags: 数据结构与算法 哈尔滨工业大学计算机科学与技术学院 实验报告 课程名称&#xff1a;数据结构与算法 课程类型&#xff1a;必修 实验项目&#xff1a;树型结构的建立与遍历 实验题目&#xff1…

【目标检测】模型信息解析/YOLOv5检测结果中文显示

前言 之前写过一篇博文【目标检测】YOLOv5&#xff1a;标签中文显示/自定义颜色&#xff0c;主要从显示端解决目标中文显示的问题。 本文着重从模型角度&#xff0c;从模型端解决目标中文显示问题。 模型信息解析 正常情况下&#xff0c;可以直接加载模型打印信息&#xff0…

GPT专业应用:英语作文修改与解释

正文共 868 字&#xff0c;阅读大约需要 3 分钟 英语学习者/老师必备技巧&#xff0c;您将在3分钟后获得以下超能力&#xff1a; 快速修改英语作文 Beezy评级 &#xff1a;B级 *经过简单的寻找&#xff0c; 大部分人能立刻掌握。主要节省时间。 推荐人 | Kim 编辑者 | Linda …