8G 显存玩转书生大模型 Demo

news2024/9/8 7:13:14

创建可用环境

# 创建环境
conda create -n demo python=3.10 -y
# 激活环境
conda activate demo
# 安装 torch
conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=12.1 -c pytorch -c nvidia -y
# 安装其他依赖
pip install transformers==4.38
pip install sentencepiece==0.1.99
pip install einops==0.8.0
pip install protobuf==5.27.2
pip install accelerate==0.33.0
pip install streamlit==1.37.0

如果没有( InternLM2-Chat-1.8B模型)下载 InternLM2-Chat-1.8B模型

创建download_hf.py文件用于下载模型

import os

# 设置环境变量
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

# 下载模型
os.system('huggingface-cli download --resume-download InternLM2-Chat-1.8B --local-dir /root/model/InternLM2-Chat-1.8B')

激活对应的环境并运行下载模型脚本

conda activate demo
python download_hf.py

运行InternLM2-Chat-1.8B模型

创建一个文件touch cli_demo.py写入如下代码

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

#此处填写自己的模型下载在什么地方就写什么地方
model_name_or_path = "/root/model/internlm2-chat-1_8b"

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, device_map='cuda:0')
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map='cuda:0')
model = model.eval()

system_prompt = """You are an AI assistant whose name is InternLM (书生·浦语).
- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.
- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.
"""

messages = [(system_prompt, '')]

print("=============Welcome to InternLM chatbot, type 'exit' to exit.=============")

while True:
    input_text = input("\nUser  >>> ")
    input_text = input_text.replace(' ', '')
    if input_text == "exit":
        break

    length = 0
    for response, _ in model.stream_chat(tokenizer, input_text, messages):
        if response is not None:
            print(response[length:], flush=True, end="")
            length = len(response)

运行cli_demo.py

在上面创建的环境中运行

conda activate demo

python cli_demo.py

直接用脚本执行的效果如图

用浏览器来进行对话之Streamlit Web Demo 部署 InternLM2-Chat-1.8B 模型

创建streamlit_demo.py来运行InternLM2-Chat-1.8B 模型

# isort: skip_file

import copy

import warnings

from dataclasses import asdict, dataclass

from typing import Callable, List, Optional

import streamlit as st

import torch

from torch import nn

from transformers.generation.utils import (LogitsProcessorList,

                                           StoppingCriteriaList)

from transformers.utils import logging

from transformers import AutoTokenizer, AutoModelForCausalLM  # isort: skip

logger = logging.get_logger(__name__)


 

@dataclass

class GenerationConfig:

    # this config is used for chat to provide more diversity

    max_length: int = 32768

    top_p: float = 0.8

    temperature: float = 0.8

    do_sample: bool = True

    repetition_penalty: float = 1.005


 

@torch.inference_mode()

def generate_interactive(

    model,

    tokenizer,

    prompt,

    generation_config: Optional[GenerationConfig] = None,

    logits_processor: Optional[LogitsProcessorList] = None,

    stopping_criteria: Optional[StoppingCriteriaList] = None,

    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],

                                                List[int]]] = None,

    additional_eos_token_id: Optional[int] = None,

    **kwargs,

):

    inputs = tokenizer([prompt], padding=True, return_tensors='pt')

    input_length = len(inputs['input_ids'][0])

    for k, v in inputs.items():

        inputs[k] = v.cuda()

    input_ids = inputs['input_ids']

    _, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]

    if generation_config is None:

        generation_config = model.generation_config

    generation_config = copy.deepcopy(generation_config)

    model_kwargs = generation_config.update(**kwargs)

    bos_token_id, eos_token_id = (  # noqa: F841  # pylint: disable=W0612

        generation_config.bos_token_id,

        generation_config.eos_token_id,

    )

    if isinstance(eos_token_id, int):

        eos_token_id = [eos_token_id]

    if additional_eos_token_id is not None:

        eos_token_id.append(additional_eos_token_id)

    has_default_max_length = kwargs.get(

        'max_length') is None and generation_config.max_length is not None

    if has_default_max_length and generation_config.max_new_tokens is None:

        warnings.warn(

            f"Using 'max_length''s default \

                ({repr(generation_config.max_length)}) \

                to control the generation length. "

            'This behaviour is deprecated and will be removed from the \

                config in v5 of Transformers -- we'

            ' recommend using `max_new_tokens` to control the maximum \

                length of the generation.',

            UserWarning,

        )

    elif generation_config.max_new_tokens is not None:

        generation_config.max_length = generation_config.max_new_tokens + \

            input_ids_seq_length

        if not has_default_max_length:

            logger.warn(  # pylint: disable=W4902

                f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) "

                f"and 'max_length'(={generation_config.max_length}) seem to "

                "have been set. 'max_new_tokens' will take precedence. "

                'Please refer to the documentation for more information. '

                '(https://huggingface.co/docs/transformers/main/'

                'en/main_classes/text_generation)',

                UserWarning,

            )

    if input_ids_seq_length >= generation_config.max_length:

        input_ids_string = 'input_ids'

        logger.warning(

            f'Input length of {input_ids_string} is {input_ids_seq_length}, '

            f"but 'max_length' is set to {generation_config.max_length}. "

            'This can lead to unexpected behavior. You should consider'

            " increasing 'max_new_tokens'.")

    # 2. Set generation parameters if not already defined

    logits_processor = logits_processor if logits_processor is not None \

        else LogitsProcessorList()

    stopping_criteria = stopping_criteria if stopping_criteria is not None \

        else StoppingCriteriaList()

    logits_processor = model._get_logits_processor(

        generation_config=generation_config,

        input_ids_seq_length=input_ids_seq_length,

        encoder_input_ids=input_ids,

        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,

        logits_processor=logits_processor,

    )

    stopping_criteria = model._get_stopping_criteria(

        generation_config=generation_config,

        stopping_criteria=stopping_criteria)

    logits_warper = model._get_logits_warper(generation_config)

    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

    scores = None

    while True:

        model_inputs = model.prepare_inputs_for_generation(

            input_ids, **model_kwargs)

        # forward pass to get next token

        outputs = model(

            **model_inputs,

            return_dict=True,

            output_attentions=False,

            output_hidden_states=False,

        )

        next_token_logits = outputs.logits[:, -1, :]

        # pre-process distribution

        next_token_scores = logits_processor(input_ids, next_token_logits)

        next_token_scores = logits_warper(input_ids, next_token_scores)

        # sample

        probs = nn.functional.softmax(next_token_scores, dim=-1)

        if generation_config.do_sample:

            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        else:

            next_tokens = torch.argmax(probs, dim=-1)

        # update generated ids, model inputs, and length for next step

        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

        model_kwargs = model._update_model_kwargs_for_generation(

            outputs, model_kwargs, is_encoder_decoder=False)

        unfinished_sequences = unfinished_sequences.mul(

            (min(next_tokens != i for i in eos_token_id)).long())

        output_token_ids = input_ids[0].cpu().tolist()

        output_token_ids = output_token_ids[input_length:]

        for each_eos_token_id in eos_token_id:

            if output_token_ids[-1] == each_eos_token_id:

                output_token_ids = output_token_ids[:-1]

        response = tokenizer.decode(output_token_ids)

        yield response

        # stop when each sentence is finished

        # or if we exceed the maximum length

        if unfinished_sequences.max() == 0 or stopping_criteria(

                input_ids, scores):

            break


 

def on_btn_click():

    del st.session_state.messages


 

@st.cache_resource

def load_model():

    model = (AutoModelForCausalLM.from_pretrained(

        '/share/new_models/Shanghai_AI_Laboratory/internlm2-chat-1_8b',

        trust_remote_code=True).to(torch.bfloat16).cuda())

    tokenizer = AutoTokenizer.from_pretrained(

        #下载的模型在哪里就填写哪里

        '/internlm2-chat-1_8b',

        trust_remote_code=True)

    return model, tokenizer


 

def prepare_generation_config():

    with st.sidebar:

        max_length = st.slider('Max Length',

                               min_value=8,

                               max_value=32768,

                               value=32768)

        top_p = st.slider('Top P', 0.0, 1.0, 0.8, step=0.01)

        temperature = st.slider('Temperature', 0.0, 1.0, 0.7, step=0.01)

        st.button('Clear Chat History', on_click=on_btn_click)

    generation_config = GenerationConfig(max_length=max_length,

                                         top_p=top_p,

                                         temperature=temperature)

    return generation_config


 

user_prompt = '<|im_start|>user\n{user}<|im_end|>\n'

robot_prompt = '<|im_start|>assistant\n{robot}<|im_end|>\n'

cur_query_prompt = '<|im_start|>user\n{user}<|im_end|>\n\

    <|im_start|>assistant\n'


 

def combine_history(prompt):

    messages = st.session_state.messages

    meta_instruction = ('You are InternLM (书生·浦语), a helpful, honest, '

                    'and harmless AI assistant developed by Shanghai '

                    'AI Laboratory (上海人工智能实验室).')

    total_prompt = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n'

    for message in messages:

        cur_content = message['content']

        if message['role'] == 'user':

            cur_prompt = user_prompt.format(user=cur_content)

        elif message['role'] == 'robot':

            cur_prompt = robot_prompt.format(robot=cur_content)

        else:

            raise RuntimeError

        total_prompt += cur_prompt

    total_prompt = total_prompt + cur_query_prompt.format(user=prompt)

    return total_prompt


 

def main():

    # torch.cuda.empty_cache()

    print('load model begin.')

    model, tokenizer = load_model()

    print('load model end.')

    st.title('InternLM2-Chat-1.8B')

    generation_config = prepare_generation_config()

    # Initialize chat history

    if 'messages' not in st.session_state:

        st.session_state.messages = []

    # Display chat messages from history on app rerun

    for message in st.session_state.messages:

        with st.chat_message(message['role'], avatar=message.get('avatar')):

            st.markdown(message['content'])

    # Accept user input

    if prompt := st.chat_input('What is up?'):

        # Display user message in chat message container

        with st.chat_message('user'):

            st.markdown(prompt)

        real_prompt = combine_history(prompt)

        # Add user message to chat history

        st.session_state.messages.append({

            'role': 'user',

            'content': prompt,

        })

        with st.chat_message('robot'):

            message_placeholder = st.empty()

            for cur_response in generate_interactive(

                    model=model,

                    tokenizer=tokenizer,

                    prompt=real_prompt,

                    additional_eos_token_id=92542,

                    **asdict(generation_config),

            ):

                # Display robot response in chat message container

                message_placeholder.markdown(cur_response + '▌')

            message_placeholder.markdown(cur_response)

        # Add robot response to chat history

        st.session_state.messages.append({

            'role': 'robot',

            'content': cur_response,  # pylint: disable=undefined-loop-variable

        })

        torch.cuda.empty_cache()


 

if __name__ == '__main__':

    main()

运行streamlit来部署,在对应的虚拟环境中

streamlit run streamlit_demo.py --server.address 127.0.0.1 --server.port 6006

浏览器运行效果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1961396.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Moving Object Segmentation: All You Need Is SAM(and Flow) 论文详解

系列文章目录 文章目录 系列文章目录前言摘要1 引言2 相关工作3 SAM Preliminaries4 帧级分割Ⅰ&#xff1a;以流作为输入5 帧级分割Ⅱ&#xff1a;以流为提示6 序列级掩膜关联7 实验7.1 数据集7.2 评价指标7 .3 实施细节7.4 消融实验7.5 定量结果7 .定性可视化 8 结论致谢附录…

01 - 计算机组成原理与体系结构

文章目录 一&#xff0c;计算机系统硬件基本组成硬件软件 二&#xff0c;CPU的功能与组成功能组成运算器控制器 三&#xff0c;数据表示计算机的基本单位进制转换原码&#xff0c;反码&#xff0c;补码&#xff0c;移码数值表示范围浮点数表示 四&#xff0c;寻址五&#xff0c…

【Unity模型】古代亚洲建筑

在Unity Asset Store上&#xff0c;一款名为"Ancient Asian Buildings Pack"&#xff08;古代亚洲建筑包&#xff09;的3D模型资源包&#xff0c;为广大开发者和设计师提供了一个将古代亚洲建筑风格融入Unity项目的机会。本文将详细介绍这款资源包的特点、使用方式以…

如何选择合适的自动化测试工具!

选择合适的自动化测试工具是一个涉及多方面因素的决策过程。以下是一些关键步骤和考虑因素&#xff0c;帮助您做出明智的选择&#xff1a; 一、明确测试需求和目标 测试范围&#xff1a;确定需要自动化的测试类型&#xff08;如单元测试、集成测试、UI测试等&#xff09;和测试…

AI视频实战教程:DiffIR2VR-Zero-模糊视频8K高清修复技术

〔探索AI的无限可能&#xff0c;微信关注“AIGCmagic”公众号&#xff0c;让AIGC科技点亮生活〕 本文作者&#xff1a;AIGCmagic社区 猫先生 一、简 介 DiffIR2VR-Zero&#xff1a;一种创新的零样本视频恢复技术&#xff0c;该技术利用预训练的图像恢复模型&#xff0c;解决…

C++高性能通信:图形简述高性能中间件Iceoryx

文章目录 1. 概述2. 支持一个发布者多个订阅者2.2 Iceoryx为何不支持多个发布者发布到同一个主题 3. Iceoryx的架构和数据传输示意图3.1 发布者与订阅者的通信机制3.2 零拷贝共享内存通信机制 4. 使用事件驱动机制4.1 WaitSet机制4.2 Listener机制 5. 已知限制6. 参考 1. 概述 …

sci-hub下载不了的文献去哪里获取全文

我们在查找外文文献时经常会用到sci-hub&#xff0c;但sci-hub也有没有收录的文献&#xff0c;遇到这种情况我们可以用另一个途径来获取该文献。 例如这篇Wiley数据库中的文献&#xff1a;Unveiling Gating Behavior in Piezoionic Effect: toward Neuromimetic Tactile Sensin…

Linux服务管理(四)Apache服务

Apache服务 1、基于IP的虚拟主机2、基于IP端口的虚拟主机3、基于域名的虚拟主机4、prefork模式5、worker模式6、event模式7、细说驱动工作模式和MPM&#xff08;多处理模块&#xff09;工作模式 新旧域名都保留&#xff0c;因为旧域名已有一定的知名度和流量&#xff0c;直接下…

Cocos Creator2D游戏开发(8)-飞机大战(6)-炸机

碰撞 飞机与飞机碰撞 子弹与飞机碰撞 ① 设置碰撞矩阵 设置碰撞矩阵,就是设置谁跟谁碰撞(添加Enemy,PlayerBullet,Player) ②设置刚体和碰撞体 两个预制体设置(Enemy和PlayerBullet) 注意点: 1. 都在预制体节点上,不在图片上; 2.碰撞体Collider2D中的Editing悬着好之后可以调整…

C#-读取测序数据的ABI文件并绘制svg格式峰图-施工中

本地环境&#xff1a;win10&#xff0c;visual studio 2022 community 目录 前言问题描述解决思路实现效果 前言 本文是在已有的代码基础上进行的开发&#xff0c;前期已经实现&#xff1a; ABI文件的解析峰图的简单绘制svg绘图 对于1&#xff0c;主要用到之前重写的struct包…

大模型面经之bert和gpt区别

BERT和GPT是自然语言处理&#xff08;NLP&#xff09;领域中的两种重要预训练语言模型&#xff0c;它们在多个方面存在显著的区别。以下是对BERT和GPT区别的详细分析。 一、模型基础与架构 BERT&#xff1a; 全称&#xff1a;Bidirectional Encoder Representations from Trans…

系统移植(九)Linux内核移植(未整理)

文章目录 一、概念二、在linux内核源码的arch/arm/configs目录下生成FSMP1A板子对应的默认配置文件三、将自己编写的驱动通过图形化界面的方式编译到内核的镜像文件uImage中&#xff08;一&#xff09;拷贝myled.c和myled.h文件到linux内核源码的drivers/char目录下&#xff08…

第15周 15.1 Zookeeper简介安装及基础使用

1. Zookeeper介绍 1.1 介绍 1.2 应用场景简介 1.3 zookeeper工作原理 1.4 zookeeper特点

Canva收购Leonardo.ai,增强生成式AI技术能力

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

MyBatis-Plus自动生成代码

目录 前言一. 什么是 MyBatis-Plus1. Mybatis-Plus 的特点2. Mybatis-Plus 结构二. MyBatis-Plus 自动生成步骤1. 数据库准备2. 环境准备(1) 创建一个空的 Spring Boot 工程(2) 导入pom依赖(3) 编辑application.yml文件(4) 在启动类加入 @MapperScan 注解3. 配置代码4. 运行三.…

Hutool SoapClient 调用使用@webservice 发布的webService接口,参数传递为空

一.发布webService接口 &#xff08;1&#xff09;接口声明 import javax.jws.WebService;WebService public interface Calculator {String add(String a, String b);int multi(int a, int b);}&#xff08;2&#xff09;实现方法 import com.maxnerva.cloud.webservice.ser…

初始mybatis

一、J D B C 编程和 ORM 模 型 1.JDBC回顾 加载驱动 &#xff1a;导入JDBC 连接数据库的 jar包&#xff0c;利用CLASS.forName 加载驱动&#xff1b; 获取连接 &#xff1a; 利用 DriverManager 获取 Connection&#xff0c;然后创建 Statement &#xff1b; 执行SQL语句 &…

算法刷题day20|回溯:39. 组合总和、40. 组合总和 II、131. 分割回文串

39. 组合总和 回溯 class Solution { private:vector<vector<int>> result;vector<int> path;void backtracking(vector<int>& candidates, int target, int sum, int startIndex) {if (sum > target) {return;}if (sum target) {result.push…

valideer,一个超强的 Python 库!

更多资料获取 &#x1f4da; 个人网站&#xff1a;ipengtao.com 大家好&#xff0c;今天为大家分享一个超强的 Python 库 - valideer。 Github地址&#xff1a;https://github.com/podio/valideer 在开发应用程序时&#xff0c;数据验证是一个至关重要的环节。它确保了输入数…

【SpringBoot】5 Swagger

官网 https://swagger.io/ 介绍 Swagger 是一套基于 OpenAPI 规范构建的开源工具&#xff0c;可以帮助开发者实现设计、构建、记录、使用 Rest API。 Swagger 是一款根据 Restful 风格生成的接口开发文档&#xff0c;并且支持做测试的一款中间软件。 Swagger主要包括三部分&…