samout llm解码 幻觉更低更稳定

news2024/12/18 20:00:10

这段代码定义了一个简单的对话生成系统,包括模型加载、词汇表加载、以及基于给定提示生成文本的功能。下面是对代码的解析:

  1. load_model_and_voc(device="cpu"):

    • 该函数用于加载预训练的模型和词汇表(vocabulary)。它首先从文件 total_voc.pkl 中加载词汇表,并创建一个名为 SamOut 的神经网络实例。
    • 模型参数的数量被打印出来以供参考。
    • 然后尝试加载指定路径下的预训练权重到模型中,并将模型移动到指定的设备(CPU 或 GPU)上。
    • 最后设置模型为评估模式(.eval()),并返回模型和词汇表。
  2. gen_token(voc, model, prompt, max_len, rp=1.2, temp=0.13, top_k=16, device="cpu"):

    • 这个函数负责根据提供的提示(prompt)生成新的文本序列。
    • 它接受多个参数,包括词汇表、模型、初始提示、最大生成长度等。
    • 函数内部实现了重复抑制、温度调整和top-k采样等技术来控制生成文本的质量。
    • 使用softmax函数对模型输出进行处理,并通过多类别抽样选择下一个token。
    • 如果生成了特殊的开始标记 <|sos|>,则停止生成过程。
    • 生成的每个token会立即打印在屏幕上,形成即时响应的效果。
  3. t_infre():

    • 此函数是交互式推理循环,允许用户输入文本,然后调用 gen_token 函数来生成回应。
    • 它是一个无限循环,持续等待用户的输入直到程序被手动终止。
  4. if __name__ == '__main__':

    • 这部分代码确保当脚本作为主程序运行时,会执行某些特定的操作或测试。
    • 注释掉的代码可能是之前用于数据预处理、训练或其他实验的部分。
    • 最终调用了 t_infre() 函数来启动交互式推理。

需要注意的是,这里使用的 SamOut 类并没有在给出的代码片段中定义,因此你可能需要确保这个类已经被正确实现并在其他地方导入。此外,为了使代码能够正常工作,你需要确保所有依赖库(如 PyTorch 和 pandas)已经安装,并且所有提及的数据文件和模型权重文件都存在于正确的路径下。

def load_model_and_voc(device="cpu"):
    voc = pd.read_pickle("total_voc.pkl")

    net = SamOut(len(voc["voc"]), 1024 + 512, 64, 16)
    # net = SamOut(len(voc["voc"]), 512, 32, 8)
    print(sum([i.shape[0] * i.shape[1] for i in net.parameters() if len(i.shape) > 1]) + sum(
        [i.shape[0] for i in net.parameters() if len(i.shape) == 1]))

    # net.load_state_dict(torch.load("pretrain_768.pth", map_location=device))
    # net.load_state_dict(torch.load("pretrain_sft_single.pth", map_location=device))
    net.load_state_dict(torch.load("pretrain_sft_single_1024.pth", map_location=device))
    # net.load_state_dict(torch.load("pretrain.pth", map_location=device))
    net.to(device)
    net.eval()
    return net, voc


def gen_token(voc, model, prompt, max_len, rp=1.2, temp=0.13, top_k=16, device="cpu"):
    print("agent:", end="", flush=True)

    for _ in range(max_len):

        prompt_list = []
        for i in prompt:
            if i not in voc["voc"]:
                prompt_list += [voc["voc"].index(ii) for ii in voc["voc0"].get(i)]
            else:

                prompt_list.append(voc["voc"].index(i))
        out, _ = model(torch.Tensor([prompt_list]).to(device).long())
        out = out[:, -1:]
        # 重复抑制
        for token_id in enumerate(prompt_list):
            out[:, :, token_id] /= rp
        score = torch.softmax(out, -1)[0, 0]
        score, score_index = torch.sort(score,descending=True)
        score=score.detach().numpy()
        score_sum = np.cumsum(score)
        score_index = score_index.detach().numpy()
        score1=score[score_sum<0.8]
        if score1.size==0:
            score=score[:1]
        else:
            score=score1
        score_index=score_index[:score.size]



        out = score / temp

        v= out[:min(top_k, score.size)]



        idx_next = torch.multinomial(torch.Tensor(v), num_samples=1, generator=None)
        if voc["voc"][score_index[idx_next.item()]] == "<|sos|>":
            break
        prompt += [voc["voc"][score_index[idx_next.item()]]]
        print(prompt[-1], end="", flush=True)


def t_infre():
    model, voc = load_model_and_voc()
    while True:
        text = input("user:")
        gen_token(voc, model, ["<|user|>"] + list("{}".format(text)) + ["<|agent|>"], 64)
        print()


if __name__ == '__main__':
    # print(pd.read_pickle("loss916"))
    # gen_one_voc()
    # gen_voc()
    # for i in range(17,18):
    #     gen_pre_data_align(i, 16)

    # train()
    # gen_sft_single_data_align()
    # train_single()
    # sft 推理  一本正经的胡说八道已练成

    t_infre()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2261761.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

游戏引擎学习第43天

仓库 https://gitee.com/mrxiao_com/2d_game 介绍运动方程 今天我们将更进一步&#xff0c;探索运动方程&#xff0c;了解真实世界中的物理&#xff0c;并调整它们&#xff0c;以创建一种让玩家感觉愉悦的控制体验。这并不是在做一个完美的物理模拟&#xff0c;而是找到最有趣…

【已解决】启动此实时调试器时未使用必需的安全权限。要调试该进程,必须以管理员身份运行此实时调试器。是否调试该进程?

【已解决】启动此实时调试器时未使用必需的安全权限。要调试该进程&#xff0c;必须以管理员身份运行此实时调试器。是否调试该进程? 目录一、前言二、具体原因三、解决方法 目录 报错截图 一、前言 进行应用程序开发时&#xff0c;需要对w3wp进行附加调试等场景&#xff…

idea无法识别文件,如何把floder文件恢复成model

前景&#xff1a; 昨天&#xff0c;我在之前的A1214模块包下新增了一个demo类&#xff0c;然后又新建了一个A1216模块&#xff0c;写了算法题&#xff0c;后面打算用git提交&#xff0c;发现之前的A1214模块下的demo类和新建的模块源文件都已经被追踪了&#xff0c;都是绿色的&…

2024三掌柜赠书活动第三十六期:深度学习高手笔记系列

目录 前言 理解深度学习基础 数据预处理技巧 关于《深度学习高手笔记》 编辑推荐 内容简介 作者简介 图书目录 媒体评论 《深度学习高手笔记》全书速览 结束语 前言 不用多讲&#xff0c;近两年的技术圈关于AI相关的技术讨论层出不穷&#xff0c;而深度学习作为人工…

【技术干货】移动SDK安全风险及应对策略

移动SDK&#xff08;软件开发工具包&#xff09;已经成为应用开发中不可或缺的一部分。通过SDK&#xff0c;开发者能够快速集成分析、广告调度、音视频处理、社交功能和用户身份验证等常见功能&#xff0c;而无需从零开始构建。这不仅能节省时间和资源&#xff0c;还能提高开发…

【一文概述】常见的几种内外网数据交换方案介绍

一、内外网数据交换的核心需求 内外网数据交换的需求核心在于“安全、效率、合规”&#xff0c;而应用场景的多样性使得不同企业需要定制化的解决方案。通过结合业务特性和安全等级要求&#xff0c;企业能够选择适合的技术方案来实现高效、安全的内外网数据交换。 1、数据安全…

【Linux 篇】Docker 容器星河与镜像灯塔:Linux 系统下解锁应用部署奇幻征程

文章目录 【Linux 篇】Docker 容器星河与镜像灯塔&#xff1a;Linux 系统下解锁应用部署奇幻征程前言一 、docker上部署mysql1. 拉取mysql镜像2. 创建容器3. 远程登录mysql 二 、docker上部署nginx1. 拉取nginx镜像2. 在dockerTar目录下 上传nginx.tar rz命令3. 创建nginx容器4…

Pytorch | 从零构建Vgg对CIFAR10进行分类

Pytorch | 从零构建Vgg对CIFAR10进行分类 CIFAR10数据集Vgg网络结构特点性能应用影响 Vgg结构代码详解结构代码代码详解特征提取层 _make_layers前向传播 forward 训练和测试训练代码train.py测试代码test.py训练过程和测试结果 代码汇总vgg.pytrain.pytest.py 前面文章我们构建…

实战 | 某院校小程序记录

更多大厂面试经验的视频分享看主页和专栏 目录&#xff1a; 前言&#xff1a; 渗透思路 1.绕过前端 2.信息泄露 3.爆破用户账号密码 4.信息泄露2 结束 前言&#xff1a; 遇到一个学校小程序的站点&#xff0c;只在前端登录口做了校验&#xff0c;后端没有任何校验&#x…

k8s kubernetes

文章目录 CGroupk8s运行时k8s组件k8s组件安装kubeadm命令kubectl命令k8s官网代码 CGroup 在 Linux 上&#xff0c;控制组&#xff08;CGroup&#xff09;用于限制分配给进程的资源。kubelet 和底层容器运行时都需要对接控制组来强制执行 为 Pod 和容器管理资源 并为诸如 CPU、…

uniapp中vuex(全局共享)的应用

一、Vuex概述 1.1 官方解释 Vuex 是一个专为 Vue.js 应用程序开发的状态管理模式。 它采用集中式存储管理 应用的所有组件的状态&#xff0c;并以相应的规则保证状态以一种可预测的方式发生变化 - Vuex 也集成到 Vue 的官方调试工具 devtools extension&#xff0c;提供了诸…

React简单入门 - [Next.js项目] - 页面跳转、AntD组件、二级目录等

须知 1Next.js 官网(英文)Next.js by Vercel - The React Framework2Next.js 文档(中文)简介 | Next.js 中文文档3React官网(中文)https://react.docschina.org/learn4Ant Design组件总览组件总览 - Ant Design5tailwindcss类名大全 官网英Justify Content - TailwindCS…

【十进制整数转换为其他进制数——短除形式的贪心算法】

之前写过一篇用贪心算法计算十进制转换二进制的方法&#xff0c;详见&#xff1a;用贪心算法计算十进制数转二进制数&#xff08;整数部分&#xff09;_短除法求二进制-CSDN博客 经过一段时间的研究&#xff0c;本人又发现两个规律&#xff1a; 1、不仅仅十进制整数转二进制可…

企业内训|阅读行业产品运营实战训练营-某运营商数字娱乐公司

近日&#xff0c;TsingtaoAI公司为某运营商旗下数字娱乐公司组织的“阅读行业产品运营实战训练营”在杭州落下帷幕。此次训练营由TsingtaoAI资深互联网产品专家程靖主持。该公司的业务骨干——来自内容、市场、业务、产品与技术等跨部门核心岗位、拥有8-10年实战经验的中坚力量…

pinctrl子系统学习笔记

一、背景 cpu的gpio引脚可以复用成多个功能&#xff0c;如可以配置成I2C或者普通GPIO模式。配置方式一般是通过写引脚复用的配置寄存器&#xff0c;但是不同芯片厂商配置寄存器格式内容各不相同&#xff0c;设置引脚复用无法做到通用且自由的配置&#xff0c;只能在启动初始化…

免费开源了一个图床工具 github-spring-boot-starter

文章目录 第一步&#xff0c;新建一个SpringBoot项目第二步&#xff0c;在pom文件里面引入jar包第三步&#xff0c;配置你的github信息github.authorization1、进入github官网&#xff0c;登录账号&#xff0c;点击头像&#xff0c;选择setting2、选择[Developer Settings](htt…

JVM系列之内存区域

每日禅语 有一位年轻和尚&#xff0c;一心求道&#xff0c;多年苦修参禅&#xff0c;但一直没有开悟。有一天&#xff0c;他打听到深山中有一古寺&#xff0c;住持和尚修炼圆通&#xff0c;是得道高僧。于是&#xff0c;年轻和尚打点行装&#xff0c;跋山涉水&#xff0c;千辛万…

自动驾驶AVM环视算法--python版本的俯视碗型投影模式

c语言版本和算法原理的可以查看本人的其他文档。《自动驾驶AVM环视算法--3D碗型投影模式的exe测试工具》本文档进用于展示部分代码的视线&#xff0c;获取方式网盘自行获取&#xff08;非免费介意勿下载&#xff09;&#xff1a;链接: https://pan.baidu.com/s/1STjUd87_5wDk_C…

【并发容器】源码级ConcurrentHashMap详解(java78)

1. ConcurrentHashMap 为什么要使用ConcurrentHashmap 在多线程的情况下&#xff0c;使用HashMap是线程不安全的。另外可以使用Hashtable&#xff0c;其是线程安全的&#xff0c;但是Hashtable的运行效率很低&#xff0c;之所以效率低下主要是因为其实现使用了synchronized关…

程序设计考题汇总(四:SQL练习)

文章目录 查询结果限制返回行数 查询结果限制返回行数 select device_id from user_profile LIMIT 2;