手搓 自然语言模型 各种对比数据

news2024/10/5 21:15:38

基础模型和设计思想
最优网络结构

import paddle
import numpy as np
from tqdm import  tqdm
class EmMask(paddle.nn.Layer):
    def __init__(self, voc_size=19, hidden_size=256, max_len=48):
        super(EmMask, self).__init__()
        # 定义输入序列和标签序列

        self.embedding_layer = paddle.nn.Embedding(voc_size, hidden_size)
        self.pos_em_layer = paddle.nn.Embedding(max_len, hidden_size)
        self.pos_to_down = paddle.nn.Linear(hidden_size, 1)
        self.sample_buffer_data=paddle.zeros([1])


    # 定义模型计算过程
    def forward(self, x):
        # 将输入序列嵌入为向量表示
        embedded_x = self.embedding_layer(x)  # bs--->bsh
        # embedded_x  += paddle.fft.fft(embedded_x, axis=1).real()
        # embedded_p 有权重 后期预测的时候就要参与 这样会造成计算量增加 如果使用 1 代替 减少多样性
        # 但是使用pos 是 对于任何输入是固定的可以事先弄好的可以事先计算,一个固定的w 而已
        # 而当前的attention 这个参数是动态的,要通过其他方法来实现动态的 比如scale 多头等
        # 当前这种方式全靠 开头和结尾 中间固定参数哦 如果使用多个 加上softmax 那么就能完成多头scale 的操作了
        embedded_p = self.pos_em_layer(paddle.arange(1, x.shape[1] + 1).astype("int64"))
        embedded_p = self.pos_to_down(embedded_p)

        xp = embedded_x.transpose([0, 2, 1]).unsqueeze(3) @ embedded_p.transpose([1, 0])
        # mask
        mask = paddle.triu(paddle.ones([xp.shape[-1], xp.shape[-1]]))
        x = xp * mask
        return x



class JustMaskEm(paddle.nn.Layer):
    def __init__(self, voc_size=19, hidden_size=512, max_len=1024):
        super(JustMaskEm, self).__init__()
        # 定义输入序列和标签序列

        self.em_mask_one = paddle.nn.Embedding(voc_size, hidden_size)
        self.em_mask_two = EmMask(voc_size, hidden_size, max_len)
        self.head_layer = paddle.nn.Linear(hidden_size, voc_size,bias_attr=False)
        self.layer_nor = paddle.nn.LayerNorm(hidden_size)



    # 定义模型计算过程
    def forward(self, x):
        one = self.em_mask_one(x)
        two = self.em_mask_two(x)
        x = one* paddle.sum(two, -2).transpose([0,2,1])

        # x = paddle.sum(x, -2)
        # x=x.transpose([0, 2, 1])
        # x = self.head_layer(self.layer_nor(x))

        x = self.head_layer(self.layer_nor(x))

        return x





# 进行模型训练和预测
# if __name__ == '__main__':
#     net = JustMaskEm()
#     X = paddle.to_tensor([
#         [1, 2, 3, 4],
#         [5, 6, 7, 8]
#     ], dtype='int64')
#     print(net(X).shape)
#     print(net.sample_buffer(X).shape)

#
def train_data():
    net = JustMaskEm(voc_size=len(voc_id))
    net.load_dict(paddle.load("long_attention_model"))
    print("加载成功")
    opt = paddle.optimizer.Adam(parameters=net.parameters(), learning_rate=0.0003)
    loss_f = paddle.nn.CrossEntropyLoss()
    loss_avg = []
    acc_avg = []
    batch_size = 1000*3
    bar=tqdm(range(1, 3 * 600))
    for epoch in bar:
        np.random.shuffle(data_set)
        for i, j in [[i, i + batch_size] for i in range(0, len(data_set), batch_size)]:
            one_data = data_set[i:j]
            if (len(acc_avg) + 1) % 1000 == 0:
                # print(np.mean(loss_avg), "____", np.mean(acc_avg))
                paddle.save(net.state_dict(), "long_attention_model")
                paddle.save({"data": loss_avg}, "loss_avg")
                paddle.save({"data": acc_avg}, "acc_avg")

            one_data = paddle.to_tensor(one_data)
            in_put = one_data[:, :-1]
            label = one_data[:, 1:]
            # label = one_data[:, 1:]

            out = net(in_put)
            loss = loss_f(out.reshape([-1, out.shape[-1]]), label.reshape([-1]).astype("int64"))
            acc = np.mean((paddle.argmax(out, -1)[:, :].reshape([-1]) == label[:, :].reshape([-1])).numpy())
            # loss = loss_f(out, label.reshape([-1]).astype("int64"))
            # acc = np.mean((paddle.argmax(out, -1) == label.reshape([-1])).numpy())
            loss_data = loss.numpy()[0]
            acc_avg.append(acc)
            loss_avg.append(loss_data)
            bar.set_description(desc="{}{}{}{}{}".format(epoch, "____", np.mean(loss_avg), "____", np.mean(acc_avg)))
            opt.clear_grad()
            loss.backward()
            opt.step()
            if np.mean(acc_avg) > 0.80:
                opt.set_lr(opt.get_lr() / (np.mean(acc_avg) * 100 + 1))
    print(np.mean(loss_avg), "____", np.mean(acc_avg))
    paddle.save(net.state_dict(), "long_attention_model")
    paddle.save({"data": loss_avg}, "loss_avg")
    paddle.save({"data": acc_avg}, "acc_avg")

if __name__ == "__main__":
    with open("poetrySong.txt", "r", encoding="utf-8") as f:
        data1 = f.readlines()
    data1 = [i.strip().split("::")[-1] for i in data1 if len(i.strip().split("::")[-1]) == 32]
    voc_id = ["sos"] + sorted(set(np.hstack([list(set(list("".join(i.split())))) for i in data1]))) + ["pad"]
    data_set = [[voc_id.index(j) for j in i] for i in data1]
    train_data()

实验对比数据
在这里插入图片描述
两种基本网络结构设计
在这里插入图片描述
在这里插入图片描述

总结

从上面实验数据可知 在使用方案 二的时候 ,如代码写 不断的扩大维度方可提高收敛时候的acc 上限且最高

且该网络模型可以在推理的时候如最后一幅图所示可以,进行单独解码 从而节约算力。

注意:
后面两幅图中 带框的两个是两个不同的方案,不带框的是公共部分
经过测试抛弃了蓝色框的方案。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/817375.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Unity3d C#快速打开萤石云监控视频流(ezopen)支持WebGL平台,替代UMP播放视频流的方案(含源码)

前言 Universal Media Player算是视频流播放功能常用的插件了,用到现在已经不知道躺了多少坑了,这个插件虽然是白嫖的,不过被甲方和领导吐槽的就是播放视频流的速度特别慢,可能需要几十秒来打开监控画面,等待的时间较…

我的第一个前端(VS code ,Node , lite-server简易服务器,npm 运行)

第一种方式:使用Visual Studio Code创建并运行 第一个前端项目的步骤,如下: 1. 下载和安装Visual Studio Code: 访问Visual Studio Code官方网站(Visual Studio Code - Code Editing. Redefined)并根据你…

二十三种设计模式第二十篇--备忘录模式

备忘录模式,备忘录模式属于行为型模式。它允许在不破坏封装的情况下捕获和恢复对象的内部状态。保存一个对象的某个状态,以便在适当的时候恢复对象,该模式通过创建一个备忘录对象来保存原始对象的状态,并将其存储在一个负责管理备…

133. 克隆图

给你无向 连通 图中一个节点的引用,请你返回该图的 深拷贝(克隆)。 图中的每个节点都包含它的值 val(int) 和其邻居的列表(list[Node])。 class Node { public int val; public List&…

互联网医院系统开发:打造便捷高效的医疗服务平台

随着互联网技术的飞速发展,互联网医院系统的出现为医疗行业带来了许多新的机遇和优势。互联网医院系统是一种基于互联网技术的医疗服务平台,旨在提供便捷、高效、个性化的医疗服务。下面将介绍互联网医院系统开发的优势。   提供便捷的医疗服务&#x…

【模仿学习】:离线和在线模仿

一、说明 模仿学习(Imitation Learning )是机器学习的一种,代理通过观察和模仿专家的行为来学习。在这种方法中,为代理提供了一组所需行为的演示或示例,并通过尝试复制专家的行为来学习输入观察和输出操作之间的映射。…

【单机多卡】torch改造代码为DDP单机多卡分布式并行

torch分布式数据并行DDPtorch.nn.parallel.DistributedDataParallel代码修改记录。(要求pytorch_version>1.0) 目录 1.🍄🍄要修改的地方概览 2.✏️✏️初始化 3.✏️✏️设置当前进程GPU 4.✏️✏️设置sampler 5.✏️✏…

HTML笔记(1)

介绍 浏览器中内置了HTML的解析引擎,通过解析标记语言来展现网页;HTML标签都是预定义好的;Java工程师:后台代码的编写,和数据库打交道,把数据给网页前端的工程师;网页前端工程师:写H…

拯救者Y9000K无线Wi-Fi有时不稳定?该如何解决?

由于不同品牌路由器的性能差异,无法完美兼容最新的无线网卡技术,在连接网络时(特别是网络负载较大的情况下),可能会出现Wi-Fi信号断开、无法网络无法访问、延迟突然变大的情况;可尝试下面方法进行调整。 1…

go 如何知道一个对象是分配在栈上还是堆上?

如何判断变量是分配在栈(stack)上还是堆(heap)上? Go和C不同,Go局部变量会进行逃逸分析。如果变量离开作用域后没有被引用,则优先分配到栈上,否则分配到堆上。判断语句:…

Stable Doodle:Stability AI推出的一款零门槛AI绘画神器

Stable Doodle是由Stability AI推出的一款零门槛AI绘画神器,可以将简单的草图转化为精美的图像。它可以将随手的塗鴉草稿转化为高畫質的完成圖,让用户能够以更快的速度将想法转化为精美的艺术作品。Stable Doodle利用最新的Stable Diffusion模型&#xf…

智能车域控制器设计

摘要: 本文主要针对ADCU从硬件设计到软件设计的开发流程进行详细阐述,主要包含了需求场景、关键硬件电路、电路可靠性、AUTOSAR架构、CAN通信简介、CAN通信软件设计等。最后基于以上硬件技术和软件技术开发出一款产品级智能驾驶域控制器。 // 智能驾驶域控制器研究现状 //…

iOS开发-实现自定义Tabbar及tabbar按钮动画效果

iOS开发-实现自定义Tabbar及tabbar按钮动画效果 之前整理了一个继承UITabbarController的Tabbar效果 查看 https://blog.csdn.net/gloryFlow/article/details/132012628 这里是继承与UIViewController的INSysTabbarViewController实现及点击tabbar按钮动画效果。 一、INSysT…

学习记录——TransNormerLLM

Scaling TransNormer to 175 Billion Parametes 线性注意力的Transformer大模型 2023 Transformer 存在局限。首要的一点,它们有着对于序列长度的二次时间复杂度,这会限制它们的可扩展性并拖累训练和推理阶段的计算资源和时间效率。 TransNormerLLM 是首…

中小企业如何低成本实施MES管理系统

中小企业在市场竞争中需要有高效的管理体系来支持其运营和发展。中小企业MES管理系统是一种先进的管理系统,可以提升工厂智能化水平,提高生产效率,是中小企业必须采取的有效管理工具。然而,由于资金和技术的限制,中小企…

Java API指南:掌握常用工具类与字符串操作

文章目录 1. API简介2. Java API的使用2.1 创建和使用Java API工具类2.2 使用String类进行字符串操作 结语 导语: Java作为一门功能强大的编程语言,其成功之处不仅在于语法结构的简洁明了,更因为其丰富的API(Application Programm…

面向对象中的多态性

一、权限修饰符 public, 缺省, protected,private 二、this和super关键字 this:表示当前对象 super:表示父类声明的成员 原则:遵循就近原则和追根溯源原则。 三、Object类 java.lang.Object类是所有java类的超类,即所有的J…

微信小程序测试要点

一、什么是小程序? 可以将小程序理解为轻便的APP,不用安装就可以使用的应用。用户通过扫一扫或者搜索的方式,就可以打开应用。 小程序最主要的特点是内嵌于微信之中,而使用小程序的目的是为了能够方便用户不在受下载多个APP的烦…

更好搭建负载测试环境的六个技巧

如果你如我昨天谈到的客户一样,花费了24到48个小时用于每个负载测试环境的搭建,那你的测试及构建部署能力绝对是受限的。 搭建一个仿真测试环境对于做好负载测试非常重要,同时它也是一个非常具有挑战性的任务,需要考虑技术解决、…

2023 7-31

题目1 寻找不同二叉树两节点的公共祖先 递归解法 仔细看这个解法更加容易理解: l、r 非空时,说明 p、q 分居 root 的两侧,root 就是 LCAl、r 任一为空,说明 LCA 位于另一子树或其祖先中代码 /*** Definition for a binary tree node.* struct TreeNode {* int val;* …