神经网络:Zero2Hero 2 - MLP、Embedding

news2024/11/18 17:52:43

Zero → \to Hero : 2

接上篇,Zero → \to Hero : 1,进一步的扩展模型:

  1. 增加输入字符序列的长度,通过多个字符预测下一个字符的概率分布
  2. 增加模型的深度,通过多层的MLP来学习和预测字符的生成概率
  3. 增加嵌入层,把字符转换为稠密向量
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
from matplotlib.font_manager import FontProperties
font = FontProperties(fname='../chinese_pop.ttf', size=10)

加载数据集

数据是一个中文名数据集

  • 名字最小长度为 2:
  • 名字最大长度为 3:
words = open('../Chinese_Names_Corpus.txt', 'r').read().splitlines()
# 数据包含100多万个姓名,过滤出一个姓氏用来测试
names = [name for name in words if name[0] == '王' and len(name) == 3]
len(names)
52127
# 构建词汇表到索引,索引到词汇表的映射,词汇表大小为:1561(加上开始和结束填充字符):
chars = sorted(list(set(''.join(names))))
char2i = {s:i+1 for i,s in enumerate(chars)}
char2i['.'] = 0               # 填充字符
i2char = {i:s for s,i in char2i.items()}
len(chars)
1650
block_size = 2   # 用两个字符预测下一个字符
X, Y = [], []
for w in names[:1]:
    context = [0] * block_size
    for ch in w + '.':
        ix = char2i[ch]
        X.append(context)
        Y.append(ix)
        print(''.join(i2char[i] for i in context), '--->', i2char[ix])
        context = context[1:] + [ix] # crop and append
  
X = torch.tensor(X)
Y = torch.tensor(Y)
.. ---> 王
.王 ---> 阿
王阿 ---> 宝
阿宝 ---> .

构建训练数据

block_size = 2  

def build_dataset(names):  
    X, Y = [], []
    for w in names:
        context = [0] * block_size
        for ch in w + '.':
            ix = char2i[ch]
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix] # crop and append

    X = torch.tensor(X)
    Y = torch.tensor(Y)
    print(X.shape, Y.shape)
    return X, Y

划分数据集


import random
random.seed(42)
random.shuffle(names)
n1 = int(0.8*len(names))

Xtr, Ytr = build_dataset(names[:n1])
Xte, Yte = build_dataset(names[n1:])

torch.Size([166804, 2]) torch.Size([166804])
torch.Size([41704, 2]) torch.Size([41704])

构建MLP

模型结构:输入层 → \to 嵌入层 → \to 隐藏层 → \to 输出层。

  • 在Zero2Hero:1中,是直接对输入进行了one-hot编码,然后直接将编码后的稀疏向量传递给了输出层。

  • 在Zero2Hero:2中,增加嵌入层,通过查询嵌入层,把输入字符(token)序列转换低维的稠密向量,同时随着模型的训练,字符的表示也会发生变化。

  • 在Zero2Hero:2中,增加隐藏层,提取更高级的特征表示。

嵌入层:

  1. 随机初始化一张向量表
  2. 根据字符的索引查询对应的向量表示
  3. 向量表示向后传播,最后通过梯度下降,修正向量表
# 随机初始化和词汇表大小相同的向量表
len(char2i)
C = torch.randn((1651, 2))
1651
# 根据索引查询字符的向量表示
emb = C[X]
emb.shape
emb
tensor([[[ 0.7928, -0.2331],
         [ 0.7928, -0.2331]],

        [[ 0.7928, -0.2331],
         [-2.0187, -1.1116]],

        [[-2.0187, -1.1116],
         [ 0.9795,  1.4175]],

        [[ 0.9795,  1.4175],
         [ 0.2740,  1.5466]]])

初始化模型参数:

g = torch.Generator().manual_seed(2147483647)  
C = torch.randn((len(char2i), 2), generator=g)
W1 = torch.randn((4, 200), generator=g)
b1 = torch.randn(200, generator=g)
W2 = torch.randn((200, len(char2i)), generator=g)
b2 = torch.randn(len(char2i), generator=g)
parameters = [C, W1, b1, W2, b2]
print("参数统计:",sum(p.nelement() for p in parameters)) # 模型参数统计
参数统计: 336153
for p in parameters:
    p.requires_grad = True

训练模型:

lri = []
lossi = []
stepi = []
for i in range(20000):
    # minibatch construct
    ix = torch.randint(0, Xtr.shape[0], (32,))
    # forward pass
    emb = C[Xtr[ix]]                            # (32, 2, 2)
    h = torch.tanh(emb.view(-1, 4) @ W1 + b1)  # (32, 200)
    logits = h @ W2 + b2                        # (32, 1651)
    loss = F.cross_entropy(logits, Ytr[ix])
    #print(loss.item())
  
    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()
  
    # update
    
    lr = 0.1 if i < 10000 else 0.01
    for p in parameters:
        p.data += -lr * p.grad

    # track stats
    #lri.append(lre[i])
    stepi.append(i)
    lossi.append(loss.log10().item())
print(loss.item())
3.3062288761138916
plt.plot(stepi, lossi)

在这里插入图片描述

测试误差:

emb = C[Xte]  
h = torch.tanh(emb.view(-1, 4) @ W1 + b1) # (32, 100)
logits = h @ W2 + b2                       # (32, 1651)
loss = F.cross_entropy(logits, Yte)
loss
tensor(3.2770, grad_fn=<NllLossBackward0>)

可视化

可视化字符的嵌入:

# visualize dimensions 0 and 1 of the embedding matrix C for all characters
plt.figure(figsize=(12,6))
plt.scatter(C[:,0].data, C[:,1].data, s=200)
for i in range(C.shape[0]):
    plt.text(C[i,0].item(), C[i,1].item(), i2char[i], ha="center", va="center", color='white',fontproperties=font)
plt.grid('minor')

在这里插入图片描述

测试生成

g = torch.Generator().manual_seed(20230516 + 10)
for _ in range(10):
    out = []
    context = [0] * block_size           # initialize with all ...
    while True:
        emb = C[torch.tensor([context])] # (1,block_size,d)
        h = torch.tanh(emb.view(1, -1) @ W1 + b1)
        logits = h @ W2 + b2
        probs = F.softmax(logits, dim=1)
        ix = torch.multinomial(probs, num_samples=1, generator=g).item()
        context = context[1:] + [ix]
        out.append(ix)
        if ix == 0:
            break
    print(''.join(i2char[i] for i in out))
王胜兵.
王紫琴.
王坡健.
王家菲.
王青金.
王碧财.
王华士.
王海维.
王旭荣.
王玉树.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/545898.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LeetCode高频算法刷题记录5

文章目录 1. 最长递增子序列【中等】1.1 题目描述1.2 解题思路1.3 代码实现 2. 接雨水【困难】2.1 题目描述2.2 解题思路2.3 代码实现 3. 二叉树中的最大路径和【困难】3.1 题目描述3.2 解题思路3.3 代码实现 4. 二叉树的中序遍历【简单】4.1 题目描述4.2 解题思路4.3 代码实现…

HTB靶机014-Sunday-WP

Sunday 靶机IP&#xff1a;10.10.10.76 PortScan Nmap 快速扫描&#xff1a; ┌──(xavier㉿kali)-[~] └─$ sudo nmap -sSV -T4 -F 10.10.10.76 Starting Nmap 7.93 ( https://nmap.org ) at 2023-05-06 00:10 CST Nmap scan report for 10.10.10.76 Host is …

2023年广东省中职网络安全Web渗透测试解析(超详细)

一、竞赛时间 180分钟 共计3小时 二、竞赛阶段 1.访问地址http://靶机IP/task1,分析页面内容,获取flag值,Flag格式为flag{xxx}; 2.访问地址http://靶机IP/task2,访问登录页面。用户user01的密码为1-1000以内的数,获取用户user01的密码,将密码作为Flag进行提交,Flag格式…

客户端读取响应头+后端读取请求头的那些事

在一些特殊场景中&#xff0c;我们在客户端想要去获取服务端接口设置的一些自定义响应头&#xff0c;服务端该如何处理&#xff0c;客户端才能取到这些自定义响应头的值呢&#xff1f; 特殊场景&#xff0c;我这里也举例一下&#xff0c;原生页面webView嵌入web页面。这个时候…

Shell脚本编程入门--Day2

文章目录 几个简单内置shell命令shell字串的语法计算变量长度的各种玩法批量修改文件名特殊shell扩展变量实际应用父子shell创建进程列表&#xff08;创建子shell&#xff09; 几个简单内置shell命令 echo -n 不换行输出 -e 解析字符串中的特殊符号 &#xff08;\n, \r, \t, \…

浏览器的进程和线程

浏览器是多进程多线程的应用程序 浏览器进程 主要负责界面显示、用户交互、子进程管理等。浏览器进程内部会启动多个线程处理不同的任务。 网络进程 负责加载网络资源。网络进程内部会启动多个线程来处理不同的网络任务。 渲染进程 渲染进程启动后&#xff0c;会开启一个染主线…

macOS 13.4正式版(22F66)With OpenCore 0.9.2开发版 and winPE双引导分区原版镜像

镜像特点 原文地址&#xff1a;http://www.imacosx.cn/113625.html 完全由黑果魏叔官方制作&#xff0c;针对各种机型进行默认配置&#xff0c;让黑苹果安装不再困难。系统镜像设置为双引导分区&#xff0c;全面去除clover引导分区&#xff08;如有需要&#xff0c;可以自行直…

java boot项目基础配置之设置启动端口

因为 springboot 项目是一个内嵌的tomcat 那么 我们就来研究一下 怎么改它的启动端口 其实 它的配置 还是非常多的 我们基础部分讲一下 后面 到实用部分 再一边用 一边再看一些 首先 我们如果不设置 他就会占用 我们的 8080端口 那么 我们最好就直接用 80端口 就不用输入端口…

【滤波】卡尔曼滤波数学

本文主要翻译自rlabbe/Kalman-and-Bayesian-Filters-in-Python的第7章节07-Kalman-Filter-Math&#xff08;卡尔曼滤波数学&#xff09;。 %matplotlib inline#format the book import book_format book_format.set_style()简介 如果你已经学习到了这一步&#xff0c;我希望你…

vue基础知识一:说说你对vue的理解?

一、从历史说起 Web是World Wide Web的简称&#xff0c;中文译为万维网 我们可以将它规划成如下的几个时代来进行理解 石器时代 文明时代 工业革命时代 百花齐放时代 石器时代 石器时代指的就是我们的静态网页&#xff0c;可以欣赏一下1997的Apple官网 最早的网页是没有数据库…

day1 IO 模型

目录 基本概念 同步和异步 阻塞和非阻塞 线程在运行过程中&#xff0c;可能由于以下几种原因进入阻塞状态&#xff1a; 可能阻塞套接字的Linux Sockets API调用分为以下四种 五种 I/O 模型 阻塞I/O 非阻塞I/O ​编辑 I/O多路复用模型 信号驱动式I/O模型 异步I/O 模型 …

网络模块封装

网络模块封装 library-network模块配置依赖一.自定义LiveDataCallAdapterFactory1.定义ApiResponse返回的数据类型2.LiveDataCallAdapter.kt3.LiveDataCallAdapter.kt 二.自定义CustomGsonConverterFactory三.拦截器1.HeaderInterceptor请求头拦截器2.BasicParamsInterceptor参…

Android - 内容提供者(Content Provider) 使用

Android - 内容提供者(Content Provider) 内容提供者组件通过请求从一个应用程序向其他的应用程序提供数据。这些请求由类 ContentResolver 的方法来处理。内容提供者可以使用不同的方式来存储数据。数据可以被存放在数据库&#xff0c;文件&#xff0c;甚至是网络。 有时候需…

Python如何进行性能测试?(Locust对接口进行压测)

python如何进行性能测试呢&#xff1f;其实原理就是对于接口进行加线程&#xff0c;打个比方就是当你有一个电梯&#xff0c;你同时可以搭载多少个人坐电梯那这个人数就是这部电梯的其中一个性能指标&#xff0c;那么对于接口来说每秒钟能有多少人成功发起请求后得到成功的响应…

QT 学习笔记2 信号与槽

上次做的界面&#xff0c;并没有逻辑。你点击按钮&#xff0c;并不会执行什么 要想其能作出反映&#xff0c;就不得不提到一个很重要的机制---信号与槽 当我们点击确定的时候&#xff0c;按钮会发出一个信号 点击确定的时候&#xff0c;会执行一段代码&#xff0c;这段程序就…

Ada 语言学习(3)复合类型数据——Array

文章目录 Array数据类型声明数组索引数组范围数组复制数组初始化直接赋值通过拷贝赋值不同索引范围但长度相等非指定类型边界收缩 多维数组数组遍历数组切片访问和动态检查直接访问动态检查 数组字面量 Array literal数组拼接两个数组拼接数组和单个值拼接 Array Equality&…

机器学习平台 PAI 支持抢占型实例,模型服务最高降本 90%

助力模型推理服务降本增效&#xff0c;适用于推理成本敏感场景&#xff0c;如&#xff1a;AIGC 内容生成异步推理、批量图像处理、批量音视频处理等。 在 AI 开发及服务不断追求效率的背景下&#xff0c;阿里云机器学习平台 PAI 宣布支持抢占型实例&#xff08;Spot Instance&a…

2023逆向分析代码渗透测试flag0072解析(超详细)

一、竞赛时间 180分钟 共计3小时 1.从靶机服务器的FTP上下载flag0072,分析该文件,请提交代码保护技术的类型。提交格式:XXXX。 2.提交被保护的代码所在地址。提交格式: 0xXXXX。 3.提交代码解密的密钥。提交格式: 0xXX。 4.请提交输入正确flag时的输出。提交格式: XXXX。…

Python入门(十二)while循环(二)

while循环&#xff08;二&#xff09; 1.使用while循环处理列表和字典2.在列表之间移动元素3.删除为特定值的所有列表元素4.使用用户输入来填充字典 作者&#xff1a;xiou 1.使用while循环处理列表和字典 到目前为止&#xff0c;我们每次都只处理了一项用户信息&#xff1a;获…

建站教程:腾讯云轻量服务器安装宝塔面板搭建网站流程

腾讯云轻量应用服务器镜像选择宝塔Linux面板&#xff0c;然后在宝塔面板上安装LNMP网站所需的Web环境&#xff0c;在宝塔面板上新建站点&#xff0c;上床网站程序安装包到根目录&#xff0c;并安装网站全流程。腾讯云百科来详细说下腾讯云轻量应用服务器搭建网站全流程&#xf…