Just Mask and Sum 手搓 自然语言模型

news2025/1/19 8:17:52

背景

在这个每天都能看到,各种新LLM论文,出现的今天,大家讨论的都是如何将transformer ,或者说是将attention 进行线性化。
很少有人讨论,注意力机制是必要的吗(attention is must)? 但是证明attention 的必要性,可能超出了个人算力。
如果我们拿的出一个小的反例,是不是能证明,attention(只针对当前的LLM的实现方式的attention),是非必要的。

正文

首先明确主要任务是拿出一个反例,证明实现LLM,attention 不是唯一的途径(这里不谈rnn,lstm,或者rwkv)。
那么,开始推演:

1,要预测一个字符串【ABC】后面接的是哪个字符,且这个字符的唯一确定性,或者说是某个字符的概率大小,绝对不会是N-Gram那样的假设,只和前一个字符有关。而是和前n个字符有关,n越长是某个字符的概率越大(这里指的是前n个字符都是利好是某个字符的)。
2,实现1 的假设并不难,但是太长绝对是算力爆炸,而N-Gram为的就是省算力,而我们暂时假设,这个能够确定是某个字符的概率的序列是有限长的。
3,基于上面的两个假设,那么假设ABC 如何计算后面的字符出现的概率:实现方法

上面的实现方法已经简单的证明了假设有一定的能力,预测下一个字符,但是,基于统计的模型,无论是从模型大小还是算力上都是更大的问题。
故而要转为一个函数的形式

可以任务转化为函数的形式就是信息压缩或者,不想不断地记录,就如同已知y=ax+b 就能知道 x是任意值的时候y
的值。只要记住,少量信息,就能知道y的所有值。可是说变为函数,是一种信息压缩。而建模,则是实现信息压缩(函数化)的主要方法。

而传统的建模方法是人为的根据信息,设定几个函数方程,不断地去测试实现。
而今天人们用神经网络替代任何的函数,不断地去调整参数得到函数。
可以将神经网络看做是一个万能函数。

接下来要实现,将上面统计模型,变为万能函数,神经网络(LLM自然语言模型)。

当然我们不使用当前的任何自然语言模型,那么有以下几种方法
1,直接输入voct 预测voc 而后根据概率模型推理方法进行预测。

经过验证或者说根本没有达到验证的地步,基本就凉了,或者说,本人不太喜欢,或者是选择语言模型是当前的任何语言模型,或者说是后期推理更消耗算力,但是有一个最大的问题是这样做,同样面临的问题是大的数据量,导致算力需求爆炸。没错被发现了,个人算力的确实现不了。

2,采取mask and sum 的方法 实现并行,具体大致如图。
在这里插入图片描述

import paddle
class EmMask(paddle.nn.Layer):
    def __init__(self, voc_size=19, hidden_size=256, max_len=48):
        super(EmMask, self).__init__()
        # 定义输入序列和标签序列

        self.embedding_layer = paddle.nn.Embedding(voc_size, hidden_size)
        self.pos_em_layer = paddle.nn.Embedding(max_len, hidden_size)
        self.pos_to_down = paddle.nn.Linear(hidden_size, 1)
        self.sample_buffer_data=paddle.zeros([1])


    # 定义模型计算过程
    def forward(self, x):
        # 将输入序列嵌入为向量表示
        embedded_x = self.embedding_layer(x)  # bs--->bsh
        # embedded_x  += paddle.fft.fft(embedded_x, axis=1).real()
        # embedded_p 有权重 后期预测的时候就要参与 这样会造成计算量增加 如果使用 1 代替 减少多样性
        # 但是使用pos 是 对于任何输入是固定的可以事先弄好的可以事先计算,一个固定的w 而已
        # 而当前的attention 这个参数是动态的,要通过其他方法来实现动态的 比如scale 多头等
        # 当前这种方式全靠 开头和结尾 中间固定参数哦 如果使用多个 加上softmax 那么就能完成多头scale 的操作了
        embedded_p = self.pos_em_layer(paddle.arange(1, x.shape[1] + 1).astype("int64"))
        embedded_p = self.pos_to_down(embedded_p)

        xp = embedded_x.transpose([0, 2, 1]).unsqueeze(3) @ embedded_p.transpose([1, 0])
        # mask
        mask = paddle.triu(paddle.ones([xp.shape[-1], xp.shape[-1]]))
        x = xp * mask
        return x






class JustMaskEm(paddle.nn.Layer):
    def __init__(self, voc_size=19, hidden_size=256, max_len=1024):
        super(JustMaskEm, self).__init__()
        # 定义输入序列和标签序列

        self.em_mask_one = EmMask(voc_size, hidden_size, max_len)
        self.em_mask_two = EmMask(voc_size, hidden_size, max_len)
        self.head_layer = paddle.nn.Linear(hidden_size, voc_size,bias_attr=False)
        self.layer_nor = paddle.nn.LayerNorm(hidden_size)



    # 定义模型计算过程
    def forward(self, x):
        one = self.em_mask_one(x)
        two = self.em_mask_two(x)
        x = one* paddle.sum(two, -2).unsqueeze(3)

        x = paddle.sum(x, -2)
        x=x.transpose([0, 2, 1])

        x = self.head_layer(self.layer_nor(x))

        return x





# 进行模型训练和预测
if __name__ == '__main__':
    net = JustMaskEm()
    X = paddle.to_tensor([
        [1, 2, 3, 4],
        [5, 6, 7, 8]
    ], dtype='int64')
    print(net(X).shape)
    print(net.sample_buffer(X).shape)

#
# def train_data():
#     net = JustMaskEm(voc_size=len(voc_id))
#     net.load_dict(paddle.load("long_attention_model"))
#     print("加载成功")
#     opt = paddle.optimizer.Adam(parameters=net.parameters(), learning_rate=0.0003)
#     loss_f = paddle.nn.CrossEntropyLoss()
#     loss_avg = []
#     acc_avg = []
#     batch_size = 1000*3
#     for epoch in range(1, 3 * 600):
#         np.random.shuffle(data_set)
#         for i, j in [[i, i + batch_size] for i in range(0, len(data_set), batch_size)]:
#             one_data = data_set[i:j]
#             if (len(acc_avg) + 1) % 1000 == 0:
#                 print(np.mean(loss_avg), "____", np.mean(acc_avg))
#                 paddle.save(net.state_dict(), "long_attention_model")
#                 paddle.save({"data": loss_avg}, "loss_avg")
#                 paddle.save({"data": acc_avg}, "acc_avg")
#
#             one_data = paddle.to_tensor(one_data)
#             in_put = one_data[:, :-1]
#             label = one_data[:, 1:]
#             # label = one_data[:, 1:]
#
#             out = net(in_put)
#             loss = loss_f(out.reshape([-1, out.shape[-1]]), label.reshape([-1]).astype("int64"))
#             acc = np.mean((paddle.argmax(out, -1)[:, :].reshape([-1]) == label[:, :].reshape([-1])).numpy())
#             # loss = loss_f(out, label.reshape([-1]).astype("int64"))
#             # acc = np.mean((paddle.argmax(out, -1) == label.reshape([-1])).numpy())
#             loss_data = loss.numpy()[0]
#             acc_avg.append(acc)
#             loss_avg.append(loss_data)
#             print(epoch, "____", np.mean(loss_avg), "____", np.mean(acc_avg))
#             opt.clear_grad()
#             loss.backward()
#             opt.step()
#             if np.mean(acc_avg) > 0.80:
#                 opt.set_lr(opt.get_lr() / (np.mean(acc_avg) * 100 + 1))
#     print(np.mean(loss_avg), "____", np.mean(acc_avg))
#     paddle.save(net.state_dict(), "long_attention_model")
#     paddle.save({"data": loss_avg}, "loss_avg")
#     paddle.save({"data": acc_avg}, "acc_avg")
#
# if __name__ == "__main__":
#     with open("poetrySong.txt", "r", encoding="utf-8") as f:
#         data1 = f.readlines()
#     data1 = [i.strip().split("::")[-1] for i in data1 if len(i.strip().split("::")[-1]) == 32]
#     voc_id = ["sos"] + sorted(set(np.hstack([list(set(list("".join(i.split())))) for i in data1]))) + ["pad"]
#     data_set = [[voc_id.index(j) for j in i] for i in data1]
#     train_data()

后期
完善网络结构
模型参数量与数据参数量的讨论
使用just mask and sum 验证模型参数量和数据量的 关系
几种just mask and sum 网络结构的讨论
应用在其任务上。。。。。。。。。。。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/813775.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CentOS 7.6使用yum安装stress,源码安装stree-ng 0.15.06,源码安装sysstat 12.7.2

cat /etc/redhat-release看到操作系统的版本是CentOS Linux release 7.6.1810 (Core),uname -r可以看到内核版本是3.10.0-957.21.3.el7.x86_64 yum install stress sysstat -y安装stress和sysstat。 使用pidstat -u 5 1没有%wait项: 原因是CentOS 7仓…

数电模电基础知识学习笔记汇总

文章目录: 数电和模电的关系 一:模电学习笔记 二:数电学习笔记 三:福利 1.NI Multisim14.0电路仿真软件的下载安装 2.进制转换 3.电路常用公式 4.好的参考笔记 4.1 笔记 3.1.1 模电 3.1.1 数电 4.2 网站 5.八股文 …

LeetCode130.Surrounded-Regions<被围绕的区域>

题目:被围绕的区域 思路: 好吧,这题不会。 bfs递归 dfs非递归 dfs并查集 - 被围绕的区域 - 力扣(LeetCode) 将问题转化为与边界O相连的O。有点像岛屿问题了。 代码是: //codeclass Solution { public:vo…

【ARM】内核驱动之设备树的学习-长文

❤️作者主页:凉开水白菜 ❤️作者简介:共同学习,互相监督,热于分享,多加讨论,一起进步! ❤️点赞 👍 收藏 ⭐再看,养成习惯 订阅的粉丝可通过PC端文末加我微信,可对文章的内容进行一对一答疑! 文章目录 一、什么是设备树,为什么叫设备树?二、如何编译设备树?三、…

妙用指针实现qsort

妙用指针实现qsort qsort是什么qsort代码使用例子冒泡排序引言冒泡排序模拟qsort函数 qsort是什么 是一个可以对任意类型进行排序的函数 函数为: void qsort(void *base,size_t nmemb,size_t size,int (*compar)(const void *, const void *));参数解释 参数base …

【数据结构与算法】基数排序

基数排序 基数排序(Radix Sort)属于“分配式排序”,又称“桶子法”或 bin sort,顾名思义,它是通过键值的各个位的值,将要排序的元素分配至某些“桶”中,达到排序的作用。基数排序法是属于稳定性…

学C的第三十一天【通讯录的实现】

相关代码gitee自取:C语言学习日记: 加油努力 (gitee.com) 接上期: 学C的第三十天【自定义类型:结构体、枚举、联合】_高高的胖子的博客-CSDN博客 通讯录需求: 实现一个通讯录, 通讯录中存放保存人的信息&#xff1…

【WebGL】初探WebGL,我了解到这些

WebGL(Web图形库)是一种强大的技术,允许您在Web浏览器中直接创建交互式的3D图形和动画。它利用现代图形硬件的能力来呈现令人惊叹的视觉效果,使其成为Web开发人员和计算机图形爱好者必备的技能。 WebGL基础知识 WebGL基于OpenGL …

1.3 eureka+ribbon,完成服务注册与调用,负载均衡源码追踪

本篇继先前发布的1.2 eureka注册中心,完成服务注册的内容。 目录 环境搭建 采用eurekaribbon的方式,对多个user服务发送请求,并实现负载均衡 负载均衡原理 负载均衡源码追踪 负载均衡策略 如何选择负载均衡策略? 饥饿加载…

数据结构07:查找[C++][线性查找]

图源:文心一言 考研笔记整理~🥝🥝 在数据结构和算法中,查找是一种常见的操作,它的目的是在一个数据集合中找到一个满足条件的元素。本文将介绍三种常用的查找方法,分别是顺序查找、折半查找和分块查找~&a…

61 # http 数据处理

node 中的核心模块 http 可以快速的创建一个 web 服务 const http require("http"); const url require("url");// req > request 客户端的所有信息 // res > respone 可以给客户端写入数据 const server http.createServer();server.on("r…

使用Spring Boot AOP实现日志记录

目录 介绍 1.1 什么是AOP 1.2 AOP体系与概念 AOP简单实现 2.1 新建一个SpringBoot项目,无需选择依赖 2.2 设置好本地Maven配置后,在pom.xml文件里添加添加maven依赖 2.3 创建一个业务类接口 2.4 在实体类实现接口业务 2.5 在单元测试运行结果 …

机器学习--课后作业--hw1

机器学习(课后作业–hw1) 本篇文章全文参考这篇blog 网上找了很多教程,这个是相对来说清楚的,代码可能是一模一样,只是进行了一些微调,但是一定要理解这个模型具体的处理方法,这个模型我认为最巧妙的它对于数据的处理…

【1.4】Java微服务:服务注册和调用(Eureka和Ribbon实现)

✅作者简介:大家好,我是 Meteors., 向往着更加简洁高效的代码写法与编程方式,持续分享Java技术内容。 🍎个人主页:Meteors.的博客 💞当前专栏: 微服务 ✨特色专栏: 知识分享 &#x…

小研究 - JVM GC 对 IMS HSS 延迟分析(一)

用户归属服务器(IMS HSS)是下一代通信网(NGN)核心网络 IP 多媒体子系统(IMS)中的主要用户数据库。IMS HSS 中存储用户的配置文件,可执行用户的身份验证和授权,并提供对呼叫控制服务器…

ARTS Activity -- Using Java

About ARTS - Complete one ARTS per week: ● Algorithm: Do at least one LeetCode algorithm per week Review: Read and comment on at least one technical article in English ● Tips: Learn at least one technical trick ● Share: Share a technical article with op…

1.2 eureka注册中心,完成服务注册

目录 环境搭建 搭建eureka服务 导入eureka服务端依赖 编写启动类,添加EnableEurekaServer注解 编写eureka配置文件 启动服务,访问eureka Euraka服务注册 创建了两个子模块 在模块里导入rureka客户端依赖 编写eureka配置文件 添加Services 环境搭建 创建父…

08-向量的范数_范数与正则项的关系

⛳向量的范数 范数的公式是向量每个分量 绝对值 P 次方 再用幂函数计算 P 分之一,这里 P 肯定是整数 1,2,3…到正无穷都是可以的 向量的范数就是把向量变成一个标量,范数的表示就是两个竖线来表示,然后右下角写上 P&a…

LeetCode36.Valid-Sudoku<有效的数独>

题目: 思路: 这题并不难,它类似于N皇后问题。在N皇后问题中,行,列,对角线,写对角线,都不能出现连续的皇后。 本题类似,不过他是行,列,还有一个B…

【数据结构篇C++实现】- 图

友情链接:C/C系列系统学习目录 文章目录 🚀一、图的基本概念和术语1、有向图和无向图3、基本图和多重图4、完全图5、子图6、连通、连通图和连通分量7、强连通图、强连通分量8、生成树、生成森林9、顶点的度、入度和出度10、边的权和网11、稠密图、稀疏图…