当科技遇上神器:用Streamlit定制AI可视化问答界面

news2024/11/27 9:42:27

Streamlit是一个开源的Python库,利用Streamlit可以快速构建机器学习应用的用户界面。

本文主要探讨如何使用Streamlit构建大模型+外部知识检索的AI问答可视化界面。

我们先构建了外部知识检索接口,然后让大模型根据检索返回的结果作为上下文来回答问题。

Streamlit-使用说明

下面简单介绍下Streamlit的安装和一些用到的组件。

  1. Streamlit安装
pip install streamlit
  1. Streamlit启动
streamlit run xxx.py --server.port 8888

说明:

  • 如果不指定端口,默认使用8501,如果启动多个streamlit,端口依次升序,8502,8503,…。
  • 设置server.port可指定端口。
  • streamlit启动后将会给出两个链接,Local URL和Network URL。
  1. 相关组件
import streamlit as st
  • st.header

streamlit.header(body)

body:字符串,要显示的文本。

  • st.markdown

st.markdown(body, unsafe_allow_html=False)

body:要显示的markdown文本,字符串。

unsafe_allow_html: 是否允许出现html标签,布尔值,默认:false,表示所有的html标签都将转义。 注意,这是一个临时特性,在将来可能取消。

  • st.write

st.write(*args, **kwargs)

*args:一个或多个要显示的对象参数。

unsafe_allow_html :是否允许不安全的HTML标签,布尔类型,默认值:false。

  • st.button

st.button(label, key=None)

label:按钮标题字符串。

key:按钮组件的键,可选。如果未设置的话,streamlit将自动生成一个唯一键。

  • st.radio

st.radio(label, options, index=0, format_func=<class 'str'>, key=None)

label:单选框文本,字符串。

options:选项列表,可以是以下类型:
list
tuple
numpy.ndarray
pandas.Series

index:选中项的序号,整数。

format_func:选项文本的显示格式化函数。

key:组件ID,当未设置时,streamlit会自动生成。

  • st.sidebar

st.slider(label, min_value=None, max_value=None, value=None, step=None, format=None, key=None)

label:说明文本,字符串。

min_value:允许的最小值,默认值:0或0.0。

max_value:允许的最大值,默认值:0或0.0。

value:当前值,默认值为min_value。

step:步长,默认值为1或0.01。

format:数字显示格式字符串

key:组件ID。

  • st.empty

st.empty()

填充占位符。

  • st.columns

插入并排排列的容器。

st.columns(spec, *, gap="small")

spec: 控制要插入的列数和宽度。

gap: 列之间的间隙大小。

AI问答可视化代码

这里只涉及到构建AI问答界面的代码,不涉及到外部知识检索。

  1. 导入packages
import streamlit as st
import requests
import json
import sys,os

import torch
import torch.nn as nn
from dataclasses import dataclass, asdict
from typing import List, Optional, Callable
import copy
import warnings
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
from peft import PeftModel
from chatglm.modeling_chatglm import ChatGLMForConditionalGeneration
  1. 外部知识检索
def get_reference(user_query,use_top_k=True,top_k=10,use_similar_score=True,threshold=0.7):
  """
  外部知识检索的方式,使用top_k或者similar_score控制检索返回值。
  """
    # 设置检索接口
    SERVICE_ADD = ''
    
    ref_list = []
    user_query = user_query.strip()
    input_data = {}
    if use_top_k:
        input_data['query'] = user_query
        input_data['topk'] = top_k
        result = requests.post(SERVICE_ADD, json=input_data)
        res_json = json.loads(result.text)
        for i in range(len(res_json['answer'])):
            ref = res_json['answer'][i]
            ref_list.append(ref)
    elif use_similar_score:
        input_data['query'] = user_query
        input_data['topk'] = top_k
        result = requests.post(SERVICE_ADD, json=input_data)
        res_json = json.loads(result.text)
        for i in range(len(res_json['answer'])):
            maxscore = res_json['answer'][i]['prob']
            if maxscore > threshold:  
                ref = res_json['answer'][i]
                ref_list.append(ref)
    return ref_list
  1. 参数设置
# 设置清除按钮
def on_btn_click():
    del st.session_state.messages

# 设置参数
def set_config():
    # 设置基本参数
    base_config = {"model_name":"","use_ref":"","use_topk":"","top_k":"","use_similar_score":"","max_similar_score":""}
    # 设置模型参数
    model_config = {'top_k':'','top_p':'','temperature':'','max_length':'','do_sample':""}
    
    # 左边栏设置
    with st.sidebar:
        model_name = st.radio(
            "模型选择:",
            ["baichuan2-13B-chat", "qwen-14B-chat","chatglm-6B","chatglm3-6B"],
            index="0",
        )
        base_config['model_name'] = model_name
        
        set_ref = st.radio(
            "是否使用外部知识库:",
            ["是","否"],
            index="0",
        )
        base_config['use_ref'] = set_ref
        
        if set_ref=="是":
            set_topk_score = st.radio(
                '设置选择参考文献的方式:',
                ['use_topk','use_similar_score'],
                index='0',
                )
            
            if set_topk_score=='use_topk':
                set_topk = st.slider(
                    'Top_K', 1, 10, 5,step=1
                )
                base_config['top_k'] = set_topk
                base_config['use_topk'] = True
                base_config['use_similar_score'] = False
                set_score = st.empty()
                
            elif set_topk_score=='use_similar_score':
                set_score = st.slider(
                    "Max_Similar_Score",0.00,1.00,0.70,step=0.01
                )
                base_config['max_similar_score'] = set_score
                base_config['use_similar_score'] = True
                base_config['use_topk'] = False
                set_topk = st.empty()
                
            else:
                set_topk_score = st.empty()
                set_topk = st.empty()
                set_score = st.empty()
                
        sample = st.radio("Do Sample", ('True', 'False'))
        max_length = st.slider("Max Length", min_value=64, max_value=2048, value=1024)
        top_p = st.slider(
            'Top P', 0.0, 1.0, 0.7, step=0.01
        )
        temperature = st.slider(
            'Temperature', 0.0, 2.0, 0.05, step=0.01
        )
        st.button("Clear Chat History", on_click=on_btn_click)
        
    # 设置模型参数
    model_config['top_p']=top_p
    model_config['do_sample']=sample
    model_config['max_length']=max_length
    model_config['temperature']=temperature
    return base_config,model_config
  1. 设置模型输入格式
# 设置不同模型的输入格式
def set_input_format(model_name):
    # ["baichuan2-13B-chat", "baichuan2-7B-chat", "qwen-14B-chat",'chatglm-6B','chatglm3-6B']
    if model_name=="baichuan2-13B-chat" or model_name=='baichuan2-7B-chat':
        input_format = "<reserved_106>{{query}}<reserved_107>"
    elif model_name=="qwen-14B-chat":
        input_format = """
        <|im_start|>system 
        你是一个乐于助人的助手。<|im_end|>
        <|im_start|>user
        {{query}}<|im_end|>
        <|im_start|>assistant"""
    elif model_name=="chatglm-6B":
        input_format = """{{query}}"""
    elif model_name=="chatglm3-6B":
        input_format = """
        <|system|>
        You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
        <|user|>
        {{query}}
        <|assistant|>
        """
    return input_format
  1. 加载模型
# 加载模型和分词器
@st.cache_resource
def load_model(model_name):
    if model_name=="baichuan2-13B-chat":
        model = AutoModelForCausalLM.from_pretrained("baichuan-inc/Baichuan2-13B-Chat",trust_remote_code=True)
        lora_path = ""
        tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan2-13B-Chat",trust_remote_code=True)
        model.to("cuda:0")
    elif model_name=="qwen-14B-chat":
        model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-14B-Chat",trust_remote_code=True)
        lora_path = ""
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-14B-Chat",trust_remote_code=True)
        model.to("cuda:1")
    elif model_name=="chatglm-6B":
        model = ChatGLMForConditionalGeneration.from_pretrained('THUDM/chatglm-6b',trust_remote_code=True)
        lora_path = ""
        tokenizer = AutoTokenizer.from_pretrained('THUDM/chatglm-6b',trust_remote_code=True)
        model.to("cuda:2")
    elif model_name=="chatglm3-6B":
        model = AutoModelForCausalLM.from_pretrained('THUDM/chatglm3-6b',trust_remote_code=True)
        lora_path = ""
        tokenizer = AutoTokenizer.from_pretrained('THUDM/chatglm3-6b',trust_remote_code=True)
        model.to("cuda:3")
        
    # 加载lora包
    model = PeftModel.from_pretrained(model,lora_path)
    return model,tokenizer
  1. 推理参数设置

def llm_chat(model_name,model,tokenizer,model_config,query):
    response = ''
    top_k = model_config['top_k']
    top_p = model_config['top_p']
    max_length = model_config['max_length']
    do_sample = model_config['do_sample']
    temperature = model_config['temperature']
    
    if model_name=="baichuan2-13B-chat" or model_name=='baichuan-7B-chat':
        messages = []
        messages.append({"role": "user", "content": query})
        response = model.chat(tokenizer, messages)
        
    elif model_name=="qwen-14B-chat":
        response, history = model.chat(tokenizer, query, history=None, top_p=top_p, max_new_tokens=max_length, do_sample=do_sample, temperature=temperature)
        
    elif model_name=="chatglm-6B":
        response, history = model.chat(tokenizer, query, history=None, top_p=top_p, max_length=max_length, do_sample=do_sample, temperature=temperature)
    
    elif model_name=="chatglm3-6B":
        response, history= model.chat(tokenizer, query, top_p=top_p, max_length=max_length, do_sample=do_sample, temperature=temperature)
        
    return response
  1. 主程序
if __name__=="__main__":
    
    #对话的图标
    user_avator = "🧑‍💻"
    robot_avator = "🤖"
    
    if "messages" not in st.session_state:
        st.session_state.messages = []
        
    torch.cuda.empty_cache()
    base_config,model_config = set_config()
    model_name = base_config['model_name']
    use_ref = base_config['use_ref']
    
    model,tokenizer = load_model(model_name=model_name)
    
    input_format = set_input_format(model_name=model_name)

    header_text = f'Large Language Model :{model_name}'
    st.header(header_text)
    
    if use_ref=="是":
        col1, col2 = st.columns([5, 3])  
        with col1:
            for message in st.session_state.messages:
                with st.chat_message(message["role"], avatar=message.get("avatar")):
                    st.markdown(message["content"])
        
        if user_query := st.chat_input("请输入内容..."):
            with col1:  
                with st.chat_message("user", avatar=user_avator):
                    st.markdown(user_query)
                st.session_state.messages.append({"role": "user", "content": user_query, "avatar": user_avator})
                
                with st.chat_message("robot", avatar=robot_avator):
                    message_placeholder = st.empty()
                    use_top_k = base_config['use_topk']
                    
                    if use_top_k:
                        top_k = base_config['top_k']
                        use_similar_score = base_config['use_similar_score']
                        ref_list = get_reference(user_query,use_top_k=use_top_k,top_k=top_k,use_similar_score=use_similar_score) 
                    else:
                        use_top_k = base_config['use_topk']
                        use_similar_score = base_config['use_similar_score']
                        threshold = base_config['max_similar_score']
                        ref_list = get_reference(user_query,use_top_k=use_top_k,use_similar_score=use_similar_score,threshold=threshold)
                    
                    if ref_list:
                        context = ""
                        for ref in ref_list:
                            context = context+ref['para']+"\n"
                        context = context.strip('\n')
                        query = f'''
                        上下文:
                        【
                        {context} 
                        】
                        只能根据提供的上下文信息,合理回答下面的问题,不允许编造内容,不允许回答无关内容。
                        问题:
                        【
                        {user_query}
                        】
                        '''
                    else:
                        query = user_query
                    query = input_format.replace("{{query}}",query)
                    print('输入:',query)
                    max_len = model_config['max_length']
                    if len(query)>max_len:
                        cur_response = f'字数超过{max_len},请调整max_length。'
                    else:
                        cur_response = llm_chat(model_name,model,tokenizer,model_config,query)
                    fs.write(f'输入:{query}')
                    fs.write('\n')
                    fs.write(f'输出:{cur_response}')
                    fs.write('\n')
                    sys.stdout.flush()

                    if len(query)<max_len:
                        if ref_list:
                            cur_response = f"""
                            大模型将根据外部知识库回答您的问题:{cur_response}
                            """
                        else:
                            cur_response = f"""
                            大模型将根据预训练时的知识回答您的问题,存在编造事实的可能性。因此以下输出仅供参考:{cur_response}
                            """
                            
                    message_placeholder.markdown(cur_response)
                st.session_state.messages.append({"role": "robot", "content": cur_response, "avatar": robot_avator})
                
            with col2:
                ref_list = get_reference(user_query)
                if ref_list:
                    for ref in ref_list:
                        ques = ref['ques']
                        answer = ref['para']
                        score = ref['prob']
                        question = f'{ques}--->score: {score}'
                        with st.expander(question):
                            st.write(answer)
    
    else:
        for message in st.session_state.messages:
            with st.chat_message(message["role"], avatar=message.get("avatar")):
                st.markdown(message["content"])
        if user_query := st.chat_input("请输入内容..."):
            with st.chat_message("user", avatar=user_avator):
                st.markdown(user_query)
            st.session_state.messages.append({"role": "user", "content": user_query, "avatar": user_avator})
            with st.chat_message("robot", avatar=robot_avator):
                message_placeholder = st.empty()
                query = input_format.replace("{{query}}",user_query)
                max_len = model_config['max_length']
                if len(query)>max_len:
                    cur_response = f'字数超过{max_len},请调整max_length。'
                else:
                    cur_response = llm_chat(model_name,model,tokenizer,model_config,query)
                fs.write(f'输入:{query}')
                fs.write('\n')
                fs.write(f'输出:{cur_response}')
                fs.write('\n')
                sys.stdout.flush()
                cur_response = f"""
                大模型将根据预训练时的知识回答您的问题,存在编造事实的可能性。因此以下输出仅供参考:{cur_response}
                """
                message_placeholder.markdown(cur_response)
                st.session_state.messages.append({"role": "robot", "content": cur_response, "avatar": robot_avator})
                    
  1. 可视化界面展示

总结

Streamlit工具使用非常方便,说明文档清晰。

这个可视化界面集成了多个大模型+外部知识检索,同时可以在线调整模型参数,使用方便。

完整代码:https://github.com/hjandlm/Streamlit_LLM_QA

参考

[1] https://docs.streamlit.io/
[2] http://cw.hubwiz.com/card/c/streamlit-manual/
[3] https://github.com/hiyouga/LLaMA-Factory/tree/9093cb1a2e16d1a7fde5abdd15c2527033e33143

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1166926.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

UPS设备还只知道人工巡检?这个神器你一定要试试!

随着电子设备在我们的生活和工作中扮演越来越重要的角色&#xff0c;电力的可靠性变得至关重要。不间断电源系统作为一种关键设备&#xff0c;可以提供电力备份&#xff0c;以保障设备在电力中断或波动的情况下能够正常运行。然而&#xff0c;UPS设备的有效监控和管理对于确保其…

unity 使用TriLib插件动态读取外部模型

最近在做动态加载读取外部模型的功能使用了triLib插件&#xff0c;废话不多说直接干货。 第一步下载导入插件&#xff0c;直接分享主打白嫖共享&#xff0c;不搞花里胡哨的。 链接&#xff1a;https://pan.baidu.com/s/1DK474wSrIZ0R6i0EBh5V8A 提取码&#xff1a;tado 导入后第…

录屏软件免费版,精选3款,助你轻松录制!

录屏软件在今天的数字时代中扮演着重要的角色&#xff0c;无论是为了创建教育教程、分享游戏成就&#xff0c;还是记录计算机上的操作步骤。然而&#xff0c;许多用户面临付费和高级功能的限制&#xff0c;很难找到合适的免费录屏软件。那录屏软件免费版都有哪些呢&#xff1f;…

SpringBoot整合Activiti7——全局监听器(八)

文章目录 一、全局监听器事件类型配置方式(选)日志监听器代码实现xml文件创建全局监听器全局配置类测试流程部署流程启动流程 一、全局监听器 它是引擎范围的事件监听器&#xff0c;可以捕获所有的Activiti事件。 事件类型 ActivitiEventType 枚举类中包含全部事件类型 配置方…

【接口测试】Postman登录接口鉴权实战案例,跟着大牛通关...

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 在做接口测试的时…

双十一快递“当天达”?宏电助力物流分拣系统高效运行

​众所周知&#xff0c;每年双11都是快递业务的高峰期&#xff0c;是对各大物流企业运输能力的一次大考。为了持续提升快递配送的速度&#xff0c;自动化物流仓储建设的速度也在不断的加快&#xff0c;而在一个完整的自动化物流仓储系统中&#xff0c;输送分拣设备是物流自动化…

掌握视频剪辑技巧:高手教您如何批量减少片头时长并调整播放倍速

随着社交媒体的普及&#xff0c;视频已经成为人们传递信息、表达观点的重要方式。而视频剪辑则是在这个过程中不可或缺的一环。在视频剪辑过程中&#xff0c;时长是一个重要的因素。有时候&#xff0c;我们需要对视频进行裁剪&#xff0c;以减少其时长。今天&#xff0c;我们讲…

分享一下在微信小程序里怎么做一个投票链接

在当今信息化社会&#xff0c;投票已成为各行各业收集意见、汇聚智慧的重要手段。传统的投票方式往往需要投入大量人力物力&#xff0c;而如今&#xff0c;借助微信小程序&#xff0c;我们可以在几分钟内创建一个高效、便捷的投票平台。本文将详细介绍如何在微信小程序中添加投…

OpenCV检测圆(Python版本)

文章目录 示例代码示例结果调参 示例代码 import cv2 import numpy as np# 加载图像 image_path DistanceComparison/test_image/1.png image cv2.imread(image_path, cv2.IMREAD_COLOR)# 将图像转换为灰度 gray cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)# 使用高斯模糊消除…

喜讯!云起无垠获评GEEKCON 2023“前沿突破奖“

近日&#xff0c;“GEEKCON 2023”中国站的比赛在上海西岸艺术中心成功举办。这场活动围绕着人工智能与专业安全的前沿技术展开了深入的探讨和实践活动。本次活动特设五大系列专场&#xff1a;“对抗研判 AVSS挑战赛”、“深蓝洞察之特别披露”、“年度主题大碰撞&#xff1a;G…

分享74个Python管理系统源代码总有一个是你想要的

分享74个Python管理系统源代码总有一个是你想要的 源码链接&#xff1a;https://pan.baidu.com/s/1JXFePOOk_ot4Jdu6_RylOg?pwd8888 提取码&#xff1a;8888 项目名称 ais系统后台项目&#xff0c;基于python flask框架 BNUZ教务系统认证爬虫Python语言实现&#xff0c;你…

java服务器环境配置以及项目搭建

一. 内容简介 使用Mavn聚合工程&#xff0c;springboot整合spring,springmvc,mybatis框架&#xff0c;完成项目搭建 二. 软件环境 2.1 java 1.8.0_144 2.2 mysql Ver 8.0.30( 8.10的好像出问题&#xff0c;我给重装了) 2.3 IntelliJ IDEA 2023.1 2.4 Apache Maven 3.9.5 …

冬天女儿的羽绒服就选它了,哈哈很喜欢

长款设计感满满的羽绒服 真的一下子就戳中了我的心巴 90白鸭绒&#xff0b;杜邦三防工艺&#xff0b;精细压线 厚实保暖不臃肿&#xff0c;粉色撞色甜美又可爱

【MongoDB】索引 - 单字段索引

MongoDB支持在集合文档中的任意字段上创建索引&#xff0c;默认情况下所有的集合都有一个_id字段的索引&#xff0c;用户和应用可以新增索引用于查询和操作。 一、准备工作 这里准备一些学生数据 db.students.insertMany([{ _id: 1, name: "张三", age: 20, clas…

springboot整合七牛云oss操作文件

文章目录 springboot整合七牛云oss操作文件核心代码&#xff08;记得修改application.yml配置参数⭐&#xff09;maven依赖QiniuOssProperties配置类UploadControllerResponseResult统一封装响应结果ResponseType响应类型枚举OssUploadService接口QiniuOssUploadServiceImpl实现…

劳务派遣派遣人员如何缴纳保险

《劳动合同法》规定&#xff1a;劳务派遣协议应当约定派遣人员的社会保险费的数额与支付方式以及违反协议的责任。可见&#xff0c;由哪一方为劳务派遣人员缴纳各项社会保险是由劳务派遣公司和实际用工单位协商确定的。但不管如何约定&#xff0c;劳务派遣单位或用工单位都必须…

倾斜摄影三维模型的根节点合并的文件大小与质量关系分析

倾斜摄影三维模型的根节点合并的文件大小与质量关系分析 倾斜摄影三维模型的根节点合并过程涉及大量的数据&#xff0c;包括高分辨率图像和点云信息。在进行根节点合并时&#xff0c;文件大小和质量之间存在一定的关系。本文将分析倾斜摄影三维模型的根节点合并的文件大小与质量…

机器学习笔记:RNN值Teacher Forcing

1 基本介绍 Teacher forcing是一种在训练循环神经网络&#xff08;RNN&#xff09;时使用的技术&#xff0c;尤其是在序列生成任务中&#xff0c;如机器翻译、文本生成或语音合成。这种方法的目的是更有效地训练网络预测下一个输出&#xff0c;给定一系列先前的观察结果。 1.…

专访 SPACE ID:通往 Web3 无许可域名服务协议之路

Web3 行业发展风起云涌&#xff0c;对于初创项目而言&#xff0c;如何寻找适合自己的赛道是首要问题。当前伴随用户交互和跨平台操作需求日渐兴起&#xff0c;如何更迅速地使用一站式域名实现便捷验证成为大众的心头期盼。 这一背景下&#xff0c;SPACE ID 于众星林立的 Web3 …

MFC 窗体插入图片

1.制作BMP图像1.bmp 放到res文件夹下&#xff0c;资源视图界面导入res文件夹下的1.bmp 2.添加控件 控件类型修改为Bitmap 图像&#xff0c;选择IDB_BITMAP1 3.效果