karpathy build make more --- 2

news2024/11/24 19:54:34

1 Introduction

用多层神经网络实现更复杂一点名字预测器。

2 方案

采用两层全连接层,中间采用tanh作为激活函数,最后一层用softmax,loss用cross-entropy.

2.1 实施

step1: 生成输入的字符,输入三个字符,输出一个字符.
采用了队列的方式,好处是能完整覆盖收尾;

import torch
def build_datasets(lines):
    xs, ys = [], []
    block_size = 3
    for line in lines:
        context = [0] * block_size
        for ch in line + '.':
            ix = stoi[ch]
            xs.append(context)
            ys.append(ix)
            context = context[1:] + [ix]

    xs = torch.tensor(xs)
    ys = torch.tensor(ys)
    return xs, ys

step2: 对数据划分训练集,测试集和验证集,并进行迭代训练和验证

from sklearn.model_selection import train_test_split
train_set, temp_set = train_test_split(lines, test_size=0.2, random_state=42)
test_set, val_set = train_test_split(temp_set, test_size=0.5, random_state=42)
train_xs, train_ys = build_datasets(train_set)
test_xs, test_ys = build_datasets(test_set)
val_xs, val_ys = build_datasets(val_set)

step3: 对输入的字符进行token化;
在这里插入图片描述
从我的理解来说,这里隐式的采用one-hot encoding,然后通过矩阵C进行进行压缩,采用相同的矩阵,保证对于每个词有相同的映射关系,也被称为归纳偏置(inductive bias)。

g = torch.Generator().manual_seed(2147483647)
C = torch.randn((27, 10), generator=g)
batch_size = 32
ix = torch.randint(0, train_xs.shape[0], (batch_size,))
C[train_xs[ix]]

step4: 定义网络结构:
输入向量的维度[batch,27]->[27,10]-(合并)->(30, 200)->(200, 27)
pipeline: 输入向量通过Matrix C进行token化,然后再通过一个全连接层,得到隐藏层h,并对隐藏层进行tanh的激活函数;再经过一个全连接层,得到输出层,并对输出层进行softmax处理

import torch
import torch.nn.functional as F

# Assuming the embedding matrix C, W1, B1, W2, B2 have been defined and initialized as before

batch_size = 32
learning_rate = 0.1
print_interval = 100
for i in range(10000):
    ix = torch.randint(0, train_xs.shape[0], (batch_size,))
    emb = C[train_xs[ix]]
    h = torch.tanh(emb.view(-1, 30) @ W1 + B1)
    logits = h @ W2 + B2
    loss = F.cross_entropy(logits, train_ys[ix])
    loss.backward()
    
    with torch.no_grad():  # Update parameters
        C.data -= learning_rate * C.grad
        W1.data -= learning_rate * W1.grad
        B1.data -= learning_rate * B1.grad
        W2.data -= learning_rate * W2.grad
        B2.data -= learning_rate * B2.grad
        
        C.grad.zero_()
        W1.grad.zero_()
        B1.grad.zero_()
        W2.grad.zero_()
        B2.grad.zero_()
    if (i + 1) % print_interval == 0:
        print(f"Iteration {i+1}: Loss = {loss.item()}")

来看一下这个代码,存在几个问题
1)所有的参数都要搞一遍zero_()太麻烦了

  • 可以通过优化器来实现
# 使用 torch.optim 包中的优化器,比如 SGD
optimizer = optim.SGD([C, W1, B1, W2, B2], lr=learning_rate)
# 使用优化器更新参数
optimizer.step()
# 清除所有参数的梯度
optimizer.zero_grad()
  • 可以将所有参数放在一个list中
parameters = [C, W1, b1, W2, b2]
for p in parameters:
  p.requires_grad = True
for p in parameters:
  p.grad = None
for p in parameters:
  p.data += -lr * p.grad
  1. 学习率这个参数不好设置
  • 随着迭代的进行,学习率逐渐衰减
    lre = torch.linspace(-3, 0, 1000)
    lrs = 10**lre

  • 更简单的二分学习率
    lr = 0.1 if i < 100000 else 0.01
    3)没有统计loss的变化情况
    lossi = []
    stepi = []
    stepi.append(i)
    lossi.append(loss.log10().item())

将上面这些修正添加进去

import torch
import torch.nn.functional as F

# Assuming the embedding matrix C, W1, B1, W2, B2 have been defined and initialized as before
parameters = [C, W1, B1, W2, B2]
batch_size = 32
print_interval = 100
stepi = []
lossi = []
for i in range(200000):
    ix = torch.randint(0, train_xs.shape[0], (batch_size,))
    emb = C[train_xs[ix]]
    h = torch.tanh(emb.view(-1, 30) @ W1 + B1)
    logits = h @ W2 + B2
    loss = F.cross_entropy(logits, train_ys[ix])
    for p in parameters:
        p.grad = None
    loss.backward()
    lr = 0.1 if i < 100000 else 0.01
    for p in parameters:
        p.data -= lr * p.grad
    stepi.append(i)
    lossi.append(loss.log10().item())
    if (i + 1) % print_interval == 0:
        print(f"Iteration {i+1}: Loss = {loss.item()}")

在这里插入图片描述
step5: 进行检验

# 训练误差
emb = C[train_xs]
h = torch.tanh(emb.view(-1, 30) @ W1 + B1)
logits = h @ W2 + B2
loss = F.cross_entropy(logits, train_ys)
print("training loss:", float(loss.item()))
# training loss: 2.1326515674591064
# test误差
emb = C[test_xs]
h = torch.tanh(emb.view(-1, 30) @ W1 + B1)
logits = h @ W2 + B2
loss = F.cross_entropy(logits, test_ys)
print("test loss:", float(loss.item()))
# test loss: 2.1820669174194336
# valid误差
emb = C[val_xs]
h = torch.tanh(emb.view(-1, 30) @ W1 + B1)
logits = h @ W2 + B2
loss = F.cross_entropy(logits, val_ys)
print("validation loss:", float(loss.item()))
# validation loss: 2.1854469776153564

三个数据集上的误差并不太大

g = torch.Generator().manual_seed(2147483647)
for i in range(10):
    out = []
    ix = torch.tensor([0, 0, 0])  # 假设0是起始字符的索引
    while True:
        emb = C[ix]  # 只获取最后一个索引的embedding
        h = torch.tanh(emb.view(-1, 30) @ W1 + B1)
        logits = h @ W2 + B2
        y_prob = torch.softmax(logits, dim=1)
        next_ix = torch.multinomial(y_prob, num_samples=1, replacement=True, generator=g).item()
        next_char = itos[next_ix]  # 假设 itos 是已经定义好的
        out.append(next_char)
        ix = torch.cat((ix[1:], torch.tensor([next_ix])), dim=0)  # 更新ix
        if next_ix == 0:  # 假设0是终止字符的索引
            break
    print(''.join(out))

输出结果:

min.
axuanie.
xiviyah.
anquen.
hida.
ariiseny.
ril.
paitheke.
dakshaldineah.
kareedusar.

最后作者还介绍了一种字符和高维空间的映射关系图,来说明encoding以后神经网络学到的字符关系。

# 创建散点图
plt.figure(figsize=(8,8))
plt.scatter(C[:, 0].data, C[:, 1].data, s=200, edgecolors='green', facecolors='green')  # 画圈

# 在每个点旁边添加文字
for i in range(C.shape[0]):
    plt.text(C[i, 0].data, C[i, 1].data, itos[i],  color='white', fontsize=12,
             ha='center', va='center')  # 添加水平和垂直居中对齐


# 设置图表标题和坐标轴标签
plt.title('One-hot Vectors Transformation Visualization')
plt.xlabel('Dimension 1')
plt.ylabel('Dimension 2')

plt.grid('minor')

在这里插入图片描述

在深度学习模型中,字符嵌入(如矩阵C)的作用是将每个字符映射到一个连续的向量空间,这样的向量表示可以捕获和编码字符间的某些关系和语义特征。每个字符的嵌入向量的维度(在您的例子中是10维)通常是通过模型学习得到的,目的是为了最佳地支持模型的任务,比如语言模型的下一个字符预测。
当我们从嵌入矩阵C中选择前两个维度来可视化时,我们试图在一个二维平面上捕捉和理解这些10维向量的结构和关系。C[:,0]C[:,1]分别代表嵌入向量的第一和第二维度。通过将它们可视化,我们可以:

  1. 观察字符嵌入的相对位置:字符向量在这个二维空间中的距离可以暗示字符之间的关系。例如,如果两个字符的向量在图中很接近,这可能意味着在模型学习的任务上它们有类似的作用或出现在相似的上下文中。
  2. 了解模型学到的表示:通过查看字符在两个维度上的分布,我们可以得到一些关于模型如何表示数据的直观了解。

至于嵌入向量数值的大小,我们不能从一个或两个维度直接得出深刻的结论,因为每个维度通常都是在高维空间中与其他维度一起工作的。在优化过程中,模型试图找到一个高维空间,其中向量之间的距离或方向能够支持模型进行准确的预测。
在二维空间中:

  • 数值大的结果可能表示该维度在该特定字符嵌入向量中具有较高的数值。这可能意味着对于模型区分字符或其上下文非常重要的特征。
  • 数值小的结果可能表示在该维度上特征值不突出,这可能是一个对于当前模型不太重要的特征。

总的来说,这个二维可视化是高维特征的一个简化视图,虽然不能完全捕捉所有的细节,但却提供了一个关于字符向量如何在模型中组织的有用的直观印象。在实际情况中,每个维度的具体物理意义很难解释,因为它们通常是通过模型的学习过程自动发现的,并不直接对应于直观可解释的属性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1607495.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java的spring循环依赖、Bean作用域等深入理解

前言 通过之前的几篇文章将Spring基于XML配置的IOC原理分析完成&#xff0c;但其中还有一些比较重要的细节没有分析总结&#xff0c;比如循环依赖的解决、作用域的实现原理、BeanPostProcessor的执行时机以及SpringBoot零配置实现原理&#xff08;ComponentScan、Import、Impo…

推荐一款websocket接口测试工具

网址&#xff1a;Websocket在线测试-Websocket接口测试-Websocket模拟请求工具 http://www.jsons.cn/websocket/ 很简单输入以ws开后的网址就可以了 这个网址是你后台设置的 如果连接成功会砸提示框内显示相关字样&#xff0c;反之则不行

python爬虫之爬取携程景点评价(5)

一、景点部分评价爬取 【携程攻略】携程旅游攻略,自助游,自驾游,出游,自由行攻略指南 (ctrip.com) import requests from bs4 import BeautifulSoupif __name__ __main__:url https://m.ctrip.com/webapp/you/commentWeb/commentList?seo0&businessId22176&busines…

U.2 NVMe全闪存储阵列在高性能计算环境中的表现

用户利用高性能计算 (HPC) 先进的计算技术来执行大规模的复杂计算任务。这有助于短时间内解决复杂问题&#xff0c;与传统计算方法相比遥遥领先。Infortrend 存储解决方案专门针对密集型 HPC 工作负载进行了优化。新推出的U.2 NVMe全闪存储阵列GS 5024UE在0.3毫秒的延迟下提供1…

镭速助力企业集成OIDC实现安全高效的大文件数据传输

在当今数字化时代&#xff0c;企业尤其是科研机构、研究所和实验室等&#xff0c;对于大量敏感数据的传输安全和效率有着日益增长的需求。面对这一挑战&#xff0c;企业需要一种既能保障数据传输安全&#xff0c;又能提高传输效率的解决方案。镭速&#xff0c;作为一款面向企业…

【C++学习】C++4种类型转换详解

这里写目录标题 &#x1f680;C语言中的类型转换&#x1f680;为什么C需要四种类型转换&#x1f680;C强制类型转换&#x1f680;static_cast&#x1f680;**reinterpret_cast**&#x1f680;const_cast与volatile&#x1f680;dynamic_cast &#x1f680;C语言中的类型转换 在…

buuctf——[ZJCTF 2019]NiZhuanSiWei

buuctf——[ZJCTF 2019]NiZhuanSiWei 1.绕过file_get_contents()函数 file_get_contents函数介绍 定义和用法 file_get_contents() 把整个文件读入一个字符串中。 该函数是用于把文件的内容读入到一个字符串中的首选方法。如果服务器操作系统支持&#xff0c;还会使用内存映射…

【opencv】示例-videocapture_starter.cpp 从视频文件、图像序列或连接到计算机的摄像头中捕获帧...

/** * file videocapture_starter.cpp * brief 一个使用OpenCV的VideoCapture与捕获设备&#xff0c;视频文件或图像序列的入门示例 * 就像CV_PI一样简单&#xff0c;对吧&#xff1f; * * 创建于: 2010年11月23日 * 作者: Ethan Rublee * * 修改于: 2013年4月17日 * …

【笔记】十分钟学会正确的github工作流,和开源作者们使用同一套流程

对视频十分钟学会正确的github工作流&#xff0c;和开源作者们使用同一套流程的记录&#xff0c;方便自己回顾和使用。 注1&#xff1a;一个分支 只有一个人在进行 注2&#xff1a; main和master是不同时期对主分支的命名&#xff0c;两者是同一个东西。如果项目已经有了&#…

将MySQL数据库导入到EA模型的教程

将MySQL数据库导入到EA 1.下载安装mysql-connector-odbc2.在管理工具中新增ODBC数据源3.在EA中新建项目4.链接MYSQL数据源4.1 安装64位的ODBC驱动可能出现”在连接ODBC 时发生错误&#xff0c;请相关检查设置“的提示&#xff0c;卸载后重新安装32位ODBC驱动后可以正常执行 5.导…

【1577】java网吧收费管理系统Myeclipse开发mysql数据库web结构java编程计算机网页项目

一、源码特点 java 网吧收费管理系统是一套完善的java web信息管理系统&#xff0c;对理解JSP java编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。开发环境为 TOMCAT7.0,Myeclipse8.5开发&#xff0c;数据库为Mysql5.0…

唠一唠,氮化镓和普通快充的不同

充电器,对于手机、电脑、平板等电子产品来说,就是“生命之源”。没有它,这些设备就像鱼儿离开了水。 过去数十年,充电器市场一直在不断变革。从最早的座插式,到后来的一体式、万能充、分离式直充、快充、无线充,再到现在的氮化镓充电器,技术的更新换代快得让人眼花缭乱。 那么…

SQL Server 2022 安装及使用

SQL Server 2022 前言一、安装SQL Server 2022下载SQL Server 2022安装SQL Server 2022配置SQL Server 2022 二、安装SQL Server Management Studio下载SQL Server Management Studio安装SSMS-Setup-CHS 三、使用SQL Server 2022四、解决连接到服务器报错问题 前言 SQL Serve…

Jmeter 性能-内存溢出问题定位分析

1、堆内存溢出 ①稳定性压测一段时间后&#xff0c;Jmeter报错&#xff0c;日志报&#xff1a; java.lang.OutOfMemoryError.Java heap space ②用jmap -histo pid命令dump堆内存使用情况&#xff0c;查看堆内存排名前20个对象。 看是否有自己应用程序的方法&#xff0c;从…

C++Primer3.2 标准类型string

文章目录 初始化string对象读写string对象string的empty和size操作不同string对象的比较string的加法处理string对象的字符遍历string中的每个字符 初始化string对象 //string的初始化 void test01() { string s1; // 默认初始化&am…

YOLOv8改进 | Conv篇 | CVPR2024最新DynamicConv替换下采样(包含C2f创新改进,解决低FLOPs陷阱)

一、本文介绍 本文给大家带来的改进机制是CVPR2024的最新改进机制DynamicConv其是CVPR2024的最新改进机制,这个论文中介绍了一个名为ParameterNet的新型设计原则,它旨在在大规模视觉预训练模型中增加参数数量,同时尽量不增加浮点运算(FLOPs),所以本文的DynamicConv被提出…

AIoT人工智能物联网之deepstream

1.deepstream介绍安装 deepstream是一个很强大的工具集,能够执行数据收集、数据预处理、视频追踪、编码等功能 (1)deepstream docker 版本查询 网页查询 https://catalog.ngc.nvidia.com/containers (2)下载 deepstream docker 对应 版本 https://catalog.ngc.nvidia.c…

【微信公众平台】扫码登陆

文章目录 前置准备测试号接口配置 带参数二维码登陆获取access token获取Ticket拼装二维码Url编写接口返回二维码接收扫描带参数二维码事件编写登陆轮训接口测试页面 网页授权二维码登陆生成ticket生成授权地址获取QR码静态文件支持编写获取QR码的接口 接收重定向参数轮训登陆接…

正确解决:关于Lattic Diamond和Radiant License冲突问题(无法破解问题)

一、问题 今天工作&#xff0c;搞16nm Avant E系列FPGA&#xff0c;需要用到莱迪思的Radiant 2023.2软件&#xff08;按这个博主的安装流程Lattice Radiant 2023.1 软件安装教程&#xff09;。 安装好之后&#xff0c;设置环境变量&#xff0c;导入License.dat就是破解不了&…

从零开始学习Linux(3)----权限

1.Linux权限的概念 Linux用户&#xff1a;1.root&#xff0c;超级管理员 2.非root&#xff0c;XXX&#xff0c;普通用户 命令&#xff1a;su[用户名] 功能&#xff1a;切换用户。 su -&#xff1a;是指以root的身份重新登录一次。 普通用户切换root需要输入密码&#xff0c;…