超级ai 必须有个,超级大的词表,必须是个向量库 faiss is all you need

news2024/11/25 6:39:34

  • 说明
  • 优点
  • 图像表示流程
  • 代码实现如下
  • 全部代码

说明

使用极其庞大的词表在模型压缩和图像token化方面带来了显著优势。由于词表巨大,我们不得不利用向量数据库对词表进行搜索,以找到最匹配的token。预测出的token会再次通过嵌入矩阵(em)转换为向量形式,然后从大规模的向量化词表中检索出来。
根据计算,使用16位来表示large_token_id,可以表达每三个像素、三个通道作为一个token的图像。生成的尺寸大小取决于序列的长度。而嵌入矩阵(em)的维度为(1000,h),即40位可以表示每九个三通道像素作为一个token。如果我们的图像数据能够覆盖所有可能性,那么词表将变得极其庞大,以至于即使是阿里云这样的存储巨头也无法容纳,甚至可能超过整个地球上所有存储设备之和。
然而,实际上,随着序列长度的增加,可能性会逐渐减少。因此,16位也有可能覆盖所有信息,将其转换为token。而且,由于推理序列通常不会太长,这种方法在处理实际问题时仍然具有可行性。
总的来说,使用超级大的词表在模型压缩和图像token化方面具有显著优势。通过向量数据库对词表进行搜索,以及将预测出的token再次通过嵌入矩阵转换为向量形式,可以有效地处理大规模的图像数据。尽管词表可能非常庞大,但随着序列长度的增加,可能性逐渐减少,使得这种方法在实际应用中仍然具有可行性。

优点

可以在推理的时候由于em小所以模型很小,推理只需要强大的cpu,和足够的内存磁盘

图像表示流程

在这里插入图片描述

代码实现如下

import paddle
import faiss
from new_model_13 import GPT as GPT13

import pandas as pd
from sklearn.preprocessing import normalize
import json
import math
from collections import Counter
from tqdm import tqdm
import numpy as np



def gen_small_voc():
    num = "0123456789" + 'qwertyuiopasdfghjklzxcvbnm' + "QWERTYUIOPASDFGHJKLZXCVBNM"
    num = list(num)
    small_em_voc = dict()

    voc_id = 0
    for i in range(16):
        for n in num:
            small_em_voc[voc_id] = "{}_{}".format(i, n)
            voc_id += 1
    return small_em_voc


def random_gen_voc():
    num = "0123456789" + 'qwertyuiopasdfghjklzxcvbnm' + "QWERTYUIOPASDFGHJKLZXCVBNM"
    num = list(num)
    p_list = ["{}_{}".format(i, np.random.choice(num)) for i in range(16)]
    return "#".join(p_list)


def gen_text_voc_to_token_id(text, large_em_voc, small_voc_em):
    text = list(text)
    text_list = []
    for ii in text:
        one = large_em_voc.get(ii, None)
        if one is None:
            while True:

                two = random_gen_voc()
                if large_em_voc.get(two, None) is None:
                    large_em_voc[two] = ii
                    large_em_voc[ii] = two
                    two = [small_voc_em.get(i) for i in two.split("#")]
                    text_list.append(two)
                    break
        else:
            two = [small_voc_em.get(i) for i in one.split("#")]
            text_list.append(two)

    return text_list, large_em_voc


def train():
    with open("唐诗.json", "r", encoding="utf-8") as f:
        data = f.read()
    data = json.loads(data)
    data = [i[4].split() for i in data if len(i[4].split()) > 3]
    data = np.hstack(data)
    data = [i for i in data if len("".join(i.split())) == 24 and "a" not in i]
    data = [i for i in data if len("".join(i.split())) == 24 and "f" not in i]
    data = [i for i in data if len("".join(i.split())) == 24 and "e" not in i]
    data = [i for i in data if len("".join(i.split())) == 24 and "h" not in i]
    data = [i for i in data if len("".join(i.split())) == 24 and "X" not in i]
    data = [i for i in data if len("".join(i.split())) == 24 and "“" not in i]
    data = [i for i in data if len("".join(i.split())) == 24 and '□' not in i]
    data = [i for i in data if len("".join(i.split())) == 24 and '《' not in i]
    data = [i for i in data if len("".join(i.split())) == 24 and '》' not in i]

    small_em_voc = gen_small_voc()
    small_voc_em = {k: v for v, k in small_em_voc.items()}
    large_em_voc = dict()

    model = GPT13(len(small_em_voc), 512, 32, 8)
    # model.load_dict(paddle.load("gpt.pdparams"))
    print("参数量:",
          sum([i.shape[0] * i.shape[-1] if len(i.shape) > 1 else i.shape[-1] for i in model.parameters()]) / 1000000000,
          "B")
    loss_func = paddle.nn.CrossEntropyLoss()
    opt = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=0.0003)

    for epoch in range(190):
        bar = tqdm(range(0, len(data), 1000))
        for i in bar:
            j = i + 1000

            large_data = []
            for one in data[i:j]:
                two, large_em_voc = gen_text_voc_to_token_id(one, large_em_voc, small_voc_em)

                large_data.append(two)

            out, _ = model(paddle.to_tensor(large_data)[:, :-1])
            loss = loss_func(out, paddle.to_tensor(large_data)[:, 1:].reshape([out.shape[0], -1]))
            bar.set_description("epoch___{}__loss__{}".format(epoch, loss.item()))
            opt.clear_grad()
            loss.backward()
            opt.step()
        paddle.save(model.state_dict(), "duo_yang_xing.pkl")
        pd.to_pickle(large_em_voc, "large_em_voc.pkl")
        pd.to_pickle(small_em_voc, "small_em_voc.pkl")


def val():
    with open("唐诗.json", "r", encoding="utf-8") as f:
        data = f.read()
    data = json.loads(data)
    data = [i[4].split() for i in data if len(i[4].split()) > 3]
    data = np.hstack(data)
    data = [i for i in data if len("".join(i.split())) == 24 and "a" not in i]
    data = [i for i in data if len("".join(i.split())) == 24 and "f" not in i]
    data = [i for i in data if len("".join(i.split())) == 24 and "e" not in i]
    data = [i for i in data if len("".join(i.split())) == 24 and "h" not in i]
    data = [i for i in data if len("".join(i.split())) == 24 and "X" not in i]
    data = [i for i in data if len("".join(i.split())) == 24 and "“" not in i]
    data = [i for i in data if len("".join(i.split())) == 24 and '□' not in i]
    data = [i for i in data if len("".join(i.split())) == 24 and '《' not in i]
    data = [i for i in data if len("".join(i.split())) == 24 and '》' not in i]

    small_em_voc = pd.read_pickle("small_em_voc.pkl")
    small_voc_em = {k: v for v, k in small_em_voc.items()}
    large_em_voc = pd.read_pickle("large_em_voc.pkl")

    model = GPT13(len(small_em_voc), 512, 32, 8)
    model.load_dict(paddle.load("duo_yang_xing.pkl"))
    model.eval()

    print("参数量:",
          sum([i.shape[0] * i.shape[-1] if len(i.shape) > 1 else i.shape[-1] for i in model.parameters()]) / 1000000000,
          "B")

    k_list = []

    faiss_index = faiss.IndexFlatIP(8192)

    for k, v in large_em_voc.items():
        if len(k) <= 1:
            # one = paddle.max(
            #     model.embedding(paddle.to_tensor([small_voc_em.get(i) for i in v.split("#")]).reshape([1, -1])), 1)
            one = model.embedding(paddle.to_tensor([small_voc_em.get(i) for i in v.split("#")]).reshape([1, -1]))
            one = one.reshape([1, -1])
            one /= np.linalg.norm(one, axis=-1, keepdims=True)
            faiss_index.add(one)
            k_list.append(k)

   

    word = data[0][:10]
    for _ in range(17):
        two, large_em_voc = gen_text_voc_to_token_id(word, large_em_voc, small_voc_em)
        out, _ = model(paddle.to_tensor(two).unsqueeze(0))
        out = paddle.argmax(out, -1)[:, -16:]
        out_num = [small_em_voc.get(i.item()) for i in out[0]]
        out_voc = large_em_voc.get("#".join(out_num))
        if out_voc is None:
            # out_em = paddle.max(model.embedding(out), 1)
            out_em = model.embedding(out)
            out_em = out_em.reshape([1,-1])
            out_em /= np.linalg.norm(out_em, axis=-1, keepdims=True)

            di,ii=faiss_index.search(out_em,k=10)

            word += k_list[ii[0][0]]
        else:
            word += out_voc
        print(word)



if __name__ == '__main__':
    train()
    val()

全部代码

超级

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1846027.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

短剧片源授权,类目丰富优惠多,抢先一步更新你的短剧系统片库!

前言 如今的短剧作为一种新兴的视听艺术形式&#xff0c;正以其独特的魅力迅速占领市场高地。为了满足广大短剧爱好者和从业者的需求&#xff0c;我们提供短剧片源授权服务&#xff0c;凭借剧场独家提供的丰富片源&#xff0c;助力您轻松更新短剧系统片库&#xff0c;抢占市场…

不见五陵高管墓,无花无酒锄做田

不见五陵高管墓&#xff0c;无花无酒锄做田 Golang 通用代码生成器仙童 2.4.0 电音仙女尝鲜版七已发布&#xff0c;此版本测试修复了 PostgreSQL 数据库自动反射功能。此版本更新修复了前端代码生成器&#xff0c;并修复了前端多对多界面的缺陷。PostgreSQL 的数据库反射功能刚…

安装TensorFlow报错问题ERROR: Failed building wheel for h5py解决

安装TensorFlow报错问题&#xff1a; 安装命令: pip install tensorflow2.12.0 -i https://pypi.tuna.tsinghua.edu.cn/simple Building wheel for h5py (PEP 517) ... error ERROR: Command errored out with exit status 1: command: /usr/bin/python3 /tmp/tmpz0y9yg…

代码生成器技术乱弹五十三,人工智能和通用代码生成器的共同点:Token

代码生成器技术乱弹五十三&#xff0c;人工智能和通用代码生成器的共同点&#xff1a;Token 现在&#xff0c;随着人工智能的快速发展&#xff0c;特别是生成式人工智能的爆火&#xff0c;大家逐渐熟悉了一个概念&#xff0c;Token。我称之为字牌。在生成式人工智能的语境下&a…

【每日刷题】Day72

【每日刷题】Day72 &#x1f955;个人主页&#xff1a;开敲&#x1f349; &#x1f525;所属专栏&#xff1a;每日刷题&#x1f34d; &#x1f33c;文章目录&#x1f33c; 1. 1287. 有序数组中出现次数超过25%的元素 - 力扣&#xff08;LeetCode&#xff09; 2. 993. 二叉树的…

视创云展为企业虚拟展厅搭建,提供哪些功能?

在当下数字化浪潮中&#xff0c;如何为用户创造更富生动性和真实感的展示体验&#xff0c;已成为企业营销策略的核心。借助视创云展的线上虚拟3D企业展厅搭建服务&#xff0c;利用3D空间漫游和VR技术的融合&#xff0c;可以为用户呈现出一个既真实又充满想象力的全景图或三维模…

中央空调水系统安装

冷热水管&#xff1a; 空调冷热水管道的材质应由业主或使用方明确&#xff1a; 1、普通焊接钢管&#xff1b; 2、无缝钢管&#xff1b; 3、镀锌钢管&#xff1b; 4、PP-R管&#xff1b; 5、紫铜管&#xff1b; 6、水管内外表面应光洁、无疵孔、裂缝、结疤、层裂或气泡。…

Python12 列表推导式

1.什么是列表推导式 Python的列表推导式&#xff08;list comprehension&#xff09;是一种简洁的构建列表&#xff08;list&#xff09;的方法&#xff0c;它可以从一个现有的列表中根据某种指定的规则快速创建一个新列表。这种方法不仅代码更加简洁&#xff0c;执行效率也很…

【总线】AXI4第四课时:信号描述

大家好,欢迎来到今天的总线学习时间!如果你对电子设计、特别是FPGA和SoC设计感兴趣&#xff0c;那你绝对不能错过我们今天的主角——AXI4总线。作为ARM公司AMBA总线家族中的佼佼者&#xff0c;AXI4以其高性能和高度可扩展性&#xff0c;成为了现代电子系统中不可或缺的通信桥梁…

05 Pytorch 数据读取 + 二分类模型

05 Pytorch 数据读取 二分类模型05 Pytorch 数据读取 二分类模型05 Pytorch 数据读取 二分类模型 01 数据读取 DataLoader&#xff08;set作为参数&#xff09; 02 Dataset 从哪读&#xff0c;怎么读&#xff1f; 功能&#xff1a;数据从哪里读取&#xff1f; 如何读取…

BEV端到端视觉论文合集|从不同的视角解析BEV感知技术

随着自动驾驶技术的不断发展&#xff0c;基于摄像头的感知系统已成为关键&#xff0c;而Bird’s Eye View (BEV)大模型在其中发挥着重要作用。BEV大模型是一种将摄像头捕捉到的2D图像转换为自上而下视角的3D感知的技术&#xff0c;使得车辆能够更好地理解周围环境。 BEV大模型…

吴恩达机器学习 第三课 week1 无监督机器学习(下)

目录 01 学习目标 02 异常检测算法 2.1 异常检测算法的概念 2.2 基于高斯模型的异常检测 03 利用异常检测算法检测网络服务器的故障 3.1 问题描述 3.2 算法实现 3.3 问题升级 04 总结 01 学习目标 &#xff08;1&#xff09;理解异常检测算法&#xff08;Anomaly Det…

编程精粹—— Microsoft 编写优质无错 C 程序秘诀 06:危险的行业

这是一本老书&#xff0c;作者 Steve Maguire 在微软工作期间写了这本书&#xff0c;英文版于 1993 年发布。2013 年推出了 20 周年纪念第二版。我们看到的标题是中译版名字&#xff0c;英文版的名字是《Writing Clean Code ─── Microsoft’s Techniques for Developing》&a…

Mac安装多个jdk环境(jdk8+jdk17)保姆级

Mac安装多个jdk环境&#xff08;jdk8jdk17&#xff09;保姆级 背景&#xff1a;新机安装开发环境发现需要找很多文章&#xff0c;&#xff0c;&#xff0c;&#xff0c;这里一篇文章安装所有环境 文章目录 Mac安装多个jdk环境&#xff08;jdk8jdk17&#xff09;保姆级&#x1f…

基于springboot实现火车票订票系统项目【项目源码+论文说明】

基于springboot实现火车票订票系统演示 摘要 传统办法管理信息首先需要花费的时间比较多&#xff0c;其次数据出错率比较高&#xff0c;而且对错误的数据进行更改也比较困难&#xff0c;最后&#xff0c;检索数据费事费力。因此&#xff0c;在计算机上安装火车票订票系统软件来…

【SpringCloud】Eureka的简单使用

本文使用的是jdk17&#xff0c;mysql8。 以下用两个服务做演示&#xff1a; 订单服务&#xff1a;提供订单ID&#xff0c;获取订单详细信息。 商品服务&#xff1a;提供商品ID&#xff0c;获取商品详细信息。 对于上篇http://t.csdnimg.cn/vcWpo 订单服务调用商品服务的时候&a…

一文读懂 HTTP 和 RPC 的区别

随着互联网技术的发展&#xff0c;网络通信在各种应用中扮演着至关重要的角色。无论是构建 Web 应用还是进行服务之间的交互&#xff0c;选择合适的通讯协议成为开发者们需要深入思考的问题。在众多协议中&#xff0c;HTTP&#xff08;HyperText Transfer Protocol&#xff09;…

JavaSE 面向对象程序设计进阶 抽象类和接口 2024年详解

目录 抽象类 抽象方法 抽象类和抽象方法的注意事项 ​编辑 接口 如何定义接口 注意 代码实现 ​编辑 接口中的成员特点 接口和类之间的关系 1.类与类的关系 2.类与接口的关系 3.接口与接口的关系 ​编辑 拓展 接口中的默认方法 接口中的静态方法 ​编辑 接口…

全新升级微信分销商城小程序源码系统 前后端分离 带完整的安装代码包以及搭建部署教程

系统概述 微信分销商城小程序源码系统是基于先进的技术和理念开发而成的。它旨在为企业和商家打造一个功能齐全、用户体验良好的分销平台&#xff0c;帮助他们更好地管理商品、销售渠道和用户关系&#xff0c;实现业务的快速增长和持续发展。 代码示例 系统特色功能一览 1.多…

TikTok API接口——获取TikTok用户QRcode二维码

一、引言 在数字化时代&#xff0c;QRcode二维码已经成为连接线上线下的重要桥梁。在社交媒体领域&#xff0c;TikTok作为短视频领域的佼佼者&#xff0c;用户量庞大且活跃度高。为了满足用户之间更便捷的互动需求&#xff0c;我们特别开发了一款针对TikTok平台的接口&#xf…