开源模型应用落地-chatglm3-6b-gradio-入门篇(七)

news2024/9/20 16:27:32

一、前言

    早前的文章,我们都是通过输入命令的方式来使用Chatglm3-6b模型。现在,我们可以通过使用gradio,通过一个界面与模型进行交互。这样做可以减少重复加载模型和修改代码的麻烦,
让我们更方便地体验模型的效果。


二、术语

2.1、Gradio

    是一个用于构建交互式界面的Python库。它使得在Python中创建快速原型、构建和共享机器学习模型变得更加容易。

    Gradio的主要功能是为机器学习模型提供一个即时的Web界面,使用户能够与模型进行交互,输入数据并查看结果,而无需编写复杂的前端代码。它提供了一个简单的API,可以将输入和输出绑定到模型的函数或方法,并自动生成用户界面。


三、前置条件

3.1. windows or linux操作系统均可

3.2. 下载chatglm3-6b模型

从huggingface下载:https://huggingface.co/THUDM/chatglm3-6b/tree/main

从魔搭下载:魔搭社区汇聚各领域最先进的机器学习模型,提供模型探索体验、推理、训练、部署和应用的一站式服务。https://www.modelscope.cn/models/ZhipuAI/chatglm3-6b/fileshttps://www.modelscope.cn/models/ZhipuAI/chatglm3-6b/files

 3.3. 创建虚拟环境&安装依赖

conda create --name chatglm3 python=3.10
conda activate chatglm3
pip install protobuf transformers==4.39.3 cpm_kernels torch>=2.0 sentencepiece accelerate
pip install gradio

四、技术实现

# -*-  coding = utf-8 -*-
import gradio as gr
import torch
from threading import Thread

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer
)

modelPath = "/model/chatglm3-6b"

def loadTokenizer():
    tokenizer = AutoTokenizer.from_pretrained(modelPath, use_fast=False, trust_remote_code=True)
    return tokenizer

def loadModel():
    model = AutoModelForCausalLM.from_pretrained(modelPath, device_map="auto",  trust_remote_code=True).cuda()
    model = model.eval()
    return model

class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [0, 2]
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False

def parse_text(text):
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = f'<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", "\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text

def predict(history, max_length, top_p, temperature):
    stop = StopOnTokens()
    messages = []
    for idx, (user_msg, model_msg) in enumerate(history):
        if idx == len(history) - 1 and not model_msg:
            messages.append({"role": "user", "content": user_msg})
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})

    model_inputs = tokenizer.apply_chat_template(messages,
                                                 add_generation_prompt=True,
                                                 tokenize=True,
                                                 return_tensors="pt").to(next(model.parameters()).device)
    streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": model_inputs,
        "streamer": streamer,
        "max_new_tokens": max_length,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "stopping_criteria": StoppingCriteriaList([stop]),
        "repetition_penalty": 1.2,
    }
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    for new_token in streamer:
        if new_token != '':
            history[-1][1] += new_token
            yield history


with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">ChatGLM3-6B Gradio Simple Demo</h1>""")
    chatbot = gr.Chatbot()

    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)


    def user(query, history):
        return "", history + [[parse_text(query), ""]]


    submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
        predict, [chatbot, max_length, top_p, temperature], chatbot
    )
    emptyBtn.click(lambda: None, None, chatbot, queue=False)

if __name__ == '__main__':
    model = loadModel()
    tokenizer = loadTokenizer()

    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=8989, inbrowser=True, share=False)

调用结果:

启动成功:

GPU使用情况:

浏览器访问:

推理:


五、附带说明

5.1. 问题:AttributeError: 'ChatGLMTokenizer' object has no attribute 'apply_chat_template'

1. transformers的版本太低,需要升级

pip install --upgrade transformers==4.39.3

5.2. 界面无法打开

1. 服务监听地址不能是127.0.0.1

2. 检查服务器的安全策略或防火墙配置

 服务端:lsof -i:8989 查看端口是否正常监听

 客户端:telnet ip 8989 查看是否可以正常连接

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1597436.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用 Fetch API 执行 GraphQL 查询和变体

GraphQL 简介 GraphQL 是从远程服务器查询数据的强大工具&#xff0c;也是我构建 API 的首选方式。对一些人来说&#xff0c;学习它可能有一定难度&#xff0c;因为教程通常使用 Apollo 或 Relay 等工具进行编写。 这些工具很不错&#xff0c;但通常更适用于复杂项目。在某些…

Python基于Django的微博热搜、微博舆论可视化系统

博主介绍&#xff1a;✌IT徐师兄、7年大厂程序员经历。全网粉丝15W、csdn博客专家、掘金/华为云//InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;&#x1f3…

【TEE论文】IceClave: A Trusted Execution Environment for In-Storage Computing

摘要 使用现代固态硬盘&#xff08;SSD&#xff09;的存储中计算使开发人员能够将程序从主机转移到SSD上。这被证明是缓解I/O瓶颈的有效方法。为了促进存储中计算&#xff0c;已经提出了许多框架。然而&#xff0c;其中很少有框架将存储中的安全性作为首要任务。具体而言&…

WPS二次开发系列:WPS SDk功能就概览

作者持续关注WPS二次开发专题系列&#xff0c;持续为大家带来更多有价值的WPS开发技术细节&#xff0c;如果能够帮助到您&#xff0c;请帮忙来个一键三连&#xff0c;更多问题请联系我&#xff08;QQ:250325397&#xff09; 作者通过深度测试使用了WPS SDK提供的Demo&#xff0…

比特币突然暴跌

作者&#xff1a;秦晋 周末愉快。 今天给大家分享两则比特币新闻&#xff0c;也是两个数据。一则是因为中东地缘政治升温&#xff0c;传统资本市场的风险情绪蔓延至加密市场&#xff0c;引发加密市场暴跌。比特币跌至66000美元下方。杠杆清算金额高达8.5亿美元。 二则是&#x…

代码随想录Day41:动态规划Part3

Leetcode 343. 整数拆分 讲解前&#xff1a; 毫无头绪 讲解后&#xff1a; 这道题的动态思路一开始很不容易想出来&#xff0c;虽然dp数组的定义如果知道是动态规划的话估摸着可以想出来那就是很straight forward dp定义&#xff1a;一维数组dp[i], i 代表整数的值&#xf…

蓝桥杯 — — 纯质数

纯质数 题目&#xff1a; 思路&#xff1a; 一个最简单的思路就是枚举出所有的质数&#xff0c;然后再判断这个质数是否是一个纯质数。 枚举出所有的质数&#xff1a; 可以使用常规的暴力求解法&#xff0c;其时间复杂度为&#xff08; O ( N N ) O(N\sqrt{N}) O(NN ​)&…

(十)C++自制植物大战僵尸游戏设置功能实现

植物大战僵尸游戏开发教程专栏地址http://t.csdnimg.cn/m0EtD 游戏设置 游戏设置功能是一个允许玩家根据个人喜好和设备性能来调整游戏各项参数的重要工具。游戏设置功能是为了让玩家能够根据自己的需求和设备性能来调整游戏&#xff0c;以获得最佳的游戏体验。不同的游戏和平…

相机系列——透视投影:针孔相机模型

作者&#xff1a;木一 引言 上文我们提到&#xff0c;三维相机是对真实世界成像的模拟&#xff0c;为了让三维物体在计算机屏幕上呈现出来的图像符合人眼观察效果&#xff0c;通常采用透视投影方式模拟相机成像&#xff0c;为了简化计算&#xff0c;可以用针孔相机模型来描述…

设计模式——观察者模式17

观察者模式指多个对象间存在一对多的依赖关系&#xff0c;当一个对象的状态发生改变时&#xff0c;所有依赖于它的对象都得到通知并被自动更新。这种模式有时又称作发布-订阅模式。 中介者模式是N对N的双向关系。观察者模式是1对N的单向关系。 设计模式&#xff0c;一定要敲代码…

STC89C52学习笔记(十三)

STC89C52学习笔记&#xff08;十三&#xff09; 综述&#xff1a;本文讲述了红外调控的原理和通信以及外部中断相关知识。 一、红外调控 1、定义 红外遥控是利用红光进行通信的设备。 2、特点 ①由红LED将调制后的信号发出&#xff0c;由专门的红外接收头进行解调输出。 …

​​​​网络编程探索系列之——广播原理剖析

hello &#xff01;大家好呀&#xff01; 欢迎大家来到我的网络编程系列之广播原理剖析&#xff0c;在这篇文章中&#xff0c; 你将会学习到如何在网络编程中利用广播来与局域网内加入某个特定广播组的主机&#xff01; 希望这篇文章能对你有所帮助&#xff0c;大家要是觉得我写…

jupyter使用虚拟环境里的依赖配置

进入虚拟环境fourier-features-pytorchconda activate fourier-features-pytorch 安装ipykernel pip install ipykernel -i https://pypi.tuna.tsinghua.edu.cn/simple将核与虚拟环境匹配 python -m ipykernel install --user --namefourier-features-pytorch打开jupyter j…

Kafka 架构深入介绍 及搭建Filebeat+Kafka+ELK

目录 一 架构深入介绍 &#xff08;一&#xff09;Kafka 工作流程及文件存储机制 &#xff08;二&#xff09;数据可靠性保证 &#xff08;三&#xff09;数据一致性问题 &#xff08;四&#xff09;故障问题 &#xff08;五&#xff09;ack 应答机制 二 实…

【max材质addtive叠加模式特效渲染不出通道的解决办法】

max材质addtive叠加模式特效渲染不出通道的解决办法 2021-12-22 18:15 max的scanline扫描线&#xff0c;vray渲染可以&#xff0c;红移不行(只支持它自己的材质&#xff0c;它自己的材质没有additive模式)。据说mr是可以的。 右侧的球体使用附加不透明度。 附加不透明度通过将…

关于机器学习/深度学习的一些事-答知乎问(六)

如何使用频率域变换对序列数据进行增强&#xff1f; 时频变换是常见的信号分析思路&#xff0c;同样可用于数据增强。在频率域添加噪声是方法之一。比如可以对传感器信号应用短时傅里叶变换STFT得到具有时序关系的谱特征&#xff0c;再在谱特征上应用两种数据增强方法。一是对…

Amazon MemoryDB for Redis的探索和实践

一、背景 由于当下项目的日益增长的数据量&#xff0c;单机Redis已经远远不能满足我们的要求。考虑转成集群&#xff0c;但是直接在服务器中搭建Redis集群的话&#xff0c;EC2挂掉则Redis也会随之挂掉&#xff0c;耦合性太强。所以将Redis相关的服务全部抽取在单独的服务器上或…

论文笔记:Time Travel in LLMs: Tracing Data Contamination in Large Language Models

iclr 2024 spotlight reviewer评分 688 1 intro 论文认为许多下游任务&#xff08;例如&#xff0c;总结、自然语言推理、文本分类&#xff09;上观察到的LLMs印象深刻的表现可能因数据污染而被夸大 所谓数据污染&#xff0c;即这些下游任务的测试数据出现在LLMs的预训练数据…

ins视频批量下载,instagram批量爬取视频信息

简介 Instagram 是目前最热门的社交媒体平台之一,拥有大量优质的视频内容。但是要逐一下载这些视频往往非常耗时。在这篇文章中,我们将介绍如何使用 Python 编写一个脚本,来实现 Instagram 视频的批量下载和信息爬取。 我们使用selenium获取目标用户的 HTML 源代码,并将其保存…

MySQL模糊查询

一、MySQL通配符模糊查询(%&#xff0c;_) 1.1.通配符的分类 1.“%”百分号通配符&#xff1a;表示任何字符出现任意次数&#xff08;可以是0次&#xff09; 2.“_”下划线通配符&#xff1a;表示只能匹配单个字符&#xff0c;不能多也不能少&#xff0c;就是一个字符。当然…