GPT建模与预测实战

news2025/1/13 13:55:36

代码链接见文末

效果图:

1.数据样本生成方法

训练配置参数:

--epochs 40 --batch_size 8 --device 0 --train_path data/train.pkl

其中train.pkl是处理后的文件

因此,我们首先需要执行preprocess.py进行预处理操作,配置参数:

--data_path data/novel --save_path data/train.pkl --win_size 200 --step 200

其中--vocab_file是语料表,一般不用修改,--log_path是日志路径

预处理流程如下:

  • 首先,初始化tokenizer
  • 读取作文数据集目录下的所有文件,预处理后,对于每条数据,使用滑动窗口对其进行截断
  • 最后,序列化训练数据 

代码如下:

# 初始化tokenizer
    tokenizer = CpmTokenizer(vocab_file="vocab/chinese_vocab.model")#pip install jieba
    eod_id = tokenizer.convert_tokens_to_ids("<eod>")   # 文档结束符
    sep_id = tokenizer.sep_token_id

    # 读取作文数据集目录下的所有文件
    train_list = []
    logger.info("start tokenizing data")
    for file in tqdm(os.listdir(args.data_path)):
        file = os.path.join(args.data_path, file)
        with open(file, "r", encoding="utf8")as reader:
            lines = reader.readlines()
            title = lines[1][3:].strip()    # 取出标题
            lines = lines[7:]   # 取出正文内容
            article = ""
            for line in lines:
                if line.strip() != "":  # 去除换行
                    article += line
            title_ids = tokenizer.encode(title, add_special_tokens=False)
            article_ids = tokenizer.encode(article, add_special_tokens=False)
            token_ids = title_ids + [sep_id] + article_ids + [eod_id]
            # train_list.append(token_ids)

            # 对于每条数据,使用滑动窗口对其进行截断
            win_size = args.win_size
            step = args.step
            start_index = 0
            end_index = win_size
            data = token_ids[start_index:end_index]
            train_list.append(data)
            start_index += step
            end_index += step
            while end_index+50 < len(token_ids):  # 剩下的数据长度,大于或等于50,才加入训练数据集
                data = token_ids[start_index:end_index]
                train_list.append(data)
                start_index += step
                end_index += step

    # 序列化训练数据
    with open(args.save_path, "wb") as f:
        pickle.dump(train_list, f)

2.模型训练过程

 (1) 数据与标签

        在训练过程中,我们需要根据前面的内容预测后面的内容,因此,对于每一个词的标签需要向后错一位。最终预测的是每一个位置的下一个词的token_id的概率。

(2)训练过程

        对于每一轮epoch,我们需要统计该batch的预测token的正确数与总数,并计算损失,更新梯度。

训练配置参数:

--epochs 40 --batch_size 8 --device 0 --train_path data/train.pkl
def train_epoch(model, train_dataloader, optimizer, scheduler, logger,
                epoch, args):
    model.train()
    device = args.device
    ignore_index = args.ignore_index
    epoch_start_time = datetime.now()

    total_loss = 0  # 记录下整个epoch的loss的总和
    epoch_correct_num = 0   # 每个epoch中,预测正确的word的数量
    epoch_total_num = 0  # 每个epoch中,预测的word的总数量

    for batch_idx, (input_ids, labels) in enumerate(train_dataloader):
        # 捕获cuda out of memory exception
        try:
            input_ids = input_ids.to(device)
            labels = labels.to(device)
            outputs = model.forward(input_ids, labels=labels)
            logits = outputs.logits
            loss = outputs.loss
            loss = loss.mean()

            # 统计该batch的预测token的正确数与总数
            batch_correct_num, batch_total_num = calculate_acc(logits, labels, ignore_index=ignore_index)
            # 统计该epoch的预测token的正确数与总数
            epoch_correct_num += batch_correct_num
            epoch_total_num += batch_total_num
            # 计算该batch的accuracy
            batch_acc = batch_correct_num / batch_total_num

            total_loss += loss.item()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()
            # 梯度裁剪
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            # 进行一定step的梯度累计之后,更新参数
            if (batch_idx + 1) % args.gradient_accumulation_steps == 0:
                # 更新参数
                optimizer.step()
                # 更新学习率
                scheduler.step()
                # 清空梯度信息
                optimizer.zero_grad()

            if (batch_idx + 1) % args.log_step == 0:
                logger.info(
                    "batch {} of epoch {}, loss {}, batch_acc {}, lr {}".format(
                        batch_idx + 1, epoch + 1, loss.item() * args.gradient_accumulation_steps, batch_acc, scheduler.get_lr()))

            del input_ids, outputs

        except RuntimeError as exception:
            if "out of memory" in str(exception):
                logger.info("WARNING: ran out of memory")
                if hasattr(torch.cuda, 'empty_cache'):
                    torch.cuda.empty_cache()
            else:
                logger.info(str(exception))
                raise exception

    # 记录当前epoch的平均loss与accuracy
    epoch_mean_loss = total_loss / len(train_dataloader)
    epoch_mean_acc = epoch_correct_num / epoch_total_num
    logger.info(
        "epoch {}: loss {}, predict_acc {}".format(epoch + 1, epoch_mean_loss, epoch_mean_acc))

    # save model
    logger.info('saving model for epoch {}'.format(epoch + 1))
    model_path = join(args.save_model_path, 'epoch{}'.format(epoch + 1))
    if not os.path.exists(model_path):
        os.mkdir(model_path)
    model_to_save = model.module if hasattr(model, 'module') else model
    model_to_save.save_pretrained(model_path)
    logger.info('epoch {} finished'.format(epoch + 1))
    epoch_finish_time = datetime.now()
    logger.info('time for one epoch: {}'.format(epoch_finish_time - epoch_start_time))

    return epoch_mean_loss

(3)部署与网页预测展示

        app.py既是模型预测文件,又能够在网页中展示,这需要我们下载一个依赖库:

pip install streamlit

        

生成下一个词流程,每次只根据当前位置的前context_len个token进行生成:

  • 第一步,先将输入文本截断成训练的token大小,训练时我们采用的200,截断为后200个词
  • 第二步,预测的下一个token的概率,采用温度采样和topk/topp采样

最终,我们不断的以自回归的方式不断生成预测结果

这里指定模型目录 

进入项目路径

执行streamlit run app.py 

 生成效果:

 数据与代码链接:https://pan.baidu.com/s/1XmurJn3k_VI5OR3JsFJgTQ?pwd=x3ci 
提取码:x3ci 

 

         

      

 

         

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1590564.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一个开源跨平台嵌入式USB设备协议:TinyUSB

概述 TinyUSB 是一个用于嵌入式系统的开源跨平台 USB 主机/设备堆栈&#xff0c;设计为内存安全&#xff0c;无需动态分配&#xff0c;线程安全&#xff0c;所有中断事件都被推迟&#xff0c;然后在非 ISR 任务函数中处理。查看在线文档以获取更多详细信息。 源码链接&#xff…

【U8+】打开固定资产卡片提示:运行时错误‘91’,未设置对象变量或with block变量。

【问题描述】 用友U8软件&#xff0c;固定资产模中打开某张卡片后&#xff0c; 提示&#xff1a;运行时错误‘91’&#xff0c;未设置对象变量或with block变量。 Ps&#xff1a;但不是所有卡片打开的时候都会提示&#xff0c;有的正常。 【解决方法】 跟踪数据库后&#xff…

基于单片机的智能模拟路灯控制系统

摘 要: 随着电力资源的紧缺,以及光污染和雾霾天气的影响,更智能化的路灯设计对人们的日常生活意义重大。本文的智能路灯控制系统是基于单片机的控制器,通过介绍该系统相应的硬件设计和软件设计,实现定时开关和依具体情况是否需要来开关路灯和进行亮度调节,并且具有自检功能…

Ubuntu 20.04 设置开启 root 远程登录连接

Ubuntu默认不设置 root 帐户和密码 Ubuntu默认不设置 root 帐户和密码 Ubuntu默认不设置 root 帐户和密码 如有需要&#xff0c;可在设置中开启允许 root 用户登录。具体操作步骤如下&#xff1a; 操作步骤 1、首先使用普通用户登录 2、设置root密码 macw:~$ sudo passwd …

Redis入门到通关之Redis通用命令

文章目录 Redis数据结构介绍Redis 通用命令命令演示KEYSDELEXISTSEXPIRE RedisTemplate 中的通用命令 Redis数据结构介绍 Redis 是一个key-value的数据库&#xff0c;key一般是String类型&#xff0c;不过value的类型多种多样&#xff1a; 贴心小建议&#xff1a;命令不要死…

Python编写一个抽奖小程序,新手入门案例,简单易上手!

“ 本篇文章将以简明易懂的方式引导小白通过Python编写一个简单的抽奖小程序&#xff0c;无需太多的编程经验。通过本文&#xff0c;将学习如何使用Python内置的随机模块实现随机抽奖&#xff0c;以及如何利用列表等基本数据结构来管理和操作参与抽奖的人员名单。无论你是Pytho…

时间序列分析 #ARMA模型的识别与参数估计 #R语言

掌握ARMA模型的识别和参数估计。 原始数据在文末&#xff01;&#xff01;&#xff01; 练习1、 根据某1915-2004年澳大利亚每年与枪支有关的凶杀案死亡率&#xff08;每10万人&#xff09;数据&#xff08;题目1数据.txt&#xff09;&#xff0c;求&#xff1a; 第1小题&…

计算机视觉——DiffYOLO 改进YOLO与扩散模型的抗噪声目标检测

概述 物体检测技术在图像处理和计算机视觉中发挥着重要作用。其中&#xff0c;YOLO 系列等型号因其高性能和高效率而备受关注。然而&#xff0c;在现实生活中&#xff0c;并非所有数据都是高质量的。在低质量数据集中&#xff0c;更难准确检测物体。为了解决这个问题&#xff…

Python实现PDF页面的删除与添加

在处理PDF文档的过程中&#xff0c;我们时常会需要对PDF文档中的页面进行编辑操作的情况&#xff0c;如插入和删除页面。通过添加和删除PDF页面&#xff0c;我们可以增加内容或对不需要的内容进行删除&#xff0c;使文档内容更符合需求。而通过Python实现PDF文档中的插入和删除…

SSL数字证书

SSL数字证书产品提供商主要来自于国外&#xff0c;尤其是美国&#xff0c;原理和使用操作系统一样&#xff0c;区别在于SSL数字证书目前无法替代性&#xff0c;要想达到兼容性99%的机构目前全球才3-4家&#xff0c;目前国内的主流网站主要使用的是国际证书&#xff0c;除了考虑…

深度学习在三维点云处理与三维重建中的应用探索

目录 点云数据处理 数据清洗 数据降噪和简化 数据配准 特征提取 数据增强 数据组织 性能考量 PointNet PointNet 算法问题 改进方法 三维重建 重建算法 架构模块 流程步骤 标记说明 优点和挑战 点云数据处理 数据清洗 去噪&#xff1a;点云数据通常包含噪声…

使用clickhouse-backup备份和恢复数据

作者&#xff1a;俊达 介绍 clickhouse-backup是altinity提供的一个clickhouse数据库备份和恢复的工具&#xff0c;开源项目地址&#xff1a;https://github.com/Altinity/clickhouse-backup 功能上能满足日常数据库备份恢复的需求&#xff1a; 支持单表/全库备份支持备份上…

AI电影创作,AI影视创作全套完整课程

课程下载&#xff1a;https://download.csdn.net/download/m0_66047725/89064240 更多资源下载&#xff1a;关注我。 课程内容&#xff1a; 【试听课】AI发展的现状及对影视行业未来的影响.mp4 0【AI影视创作】流程与基本逻辑_1.mp4 1【AI基础课程】ChatGPT 注册安装流程.…

使用DSP28335在CCS中生成正弦波

DSP芯片支持数学库&#xff0c;那如何通过DSP芯片生成一个正弦波呢&#xff1f;通过几天研究&#xff0c;现在将我的方法分享一下&#xff0c;如有错误&#xff0c;希望大家及时指出&#xff0c;共同进步。 sin函数的调用 首先看下一sin函数 的使用。 //头文件的定义 #includ…

VSCode中 task.json 和 launch.json 的作用和参数解释以及配置教程

前言 由于 VS Code 并不是一个传统意义上的 IDE&#xff0c;所以初学者可能在使用过程中会有很多的疑惑&#xff0c;其中比较常见的一个问题就是 tasks.json和 launch.json两个文件分别有什么作用以及如何配置 tasks.json VSCode 官网提供的 tasks.json 配置教程 使用不同的…

Linux 系统解压缩文件

Linux系统&#xff0c;可以使用unzip命令来解压zip文件 方法如下 1. 打开终端&#xff0c;在命令行中输入以下命令来安装unzip&#xff1a; sudo apt-get install unzip 1 2. 假设你想要将zip文件解压缩到名为"target_dir"的目录中&#xff0c;在终端中切换到目标路…

【线段树】【区间更新】2916. 子数组不同元素数目的平方和 II

算法可以发掘本质&#xff0c;如&#xff1a; 一&#xff0c;若干师傅和徒弟互有好感&#xff0c;有好感的师徒可以结对学习。师傅和徒弟都只能参加一个对子。如何让对子最多。 二&#xff0c;有无限多1X2和2X1的骨牌&#xff0c;某个棋盘若干格子坏了&#xff0c;如何在没有坏…

基于ssm微信小程序的医院挂号预约系统

采用技术 基于ssm微信小程序的医院挂号预约系统的设计与实现~ 开发语言&#xff1a;Java 数据库&#xff1a;MySQL 技术&#xff1a;SpringMVCMyBatis 工具&#xff1a;IDEA/Ecilpse、Navicat、Maven 页面展示效果 用户管理 医院管理 医生管理 公告资讯管理 科室信息管…

目前深圳嵌入式单片机就业环境如何?

深圳作为中国的科技创新中心之一&#xff0c;嵌入式行业的就业环境相对较好。我这里有一套嵌入式入门教程&#xff0c;不仅包含了详细的视频讲解&#xff0c;项目实战。如果你渴望学习嵌入式&#xff0c;不妨点个关注&#xff0c;给个评论222&#xff0c;私信22&#xff0c;我在…

DDoS攻击类型与应对措施详解

攻击与防御简介 SYN Flood攻击 原理&#xff1a; SYN Flood攻击利用的是TCP协议的三次握手机制。在正常的TCP连接建立过程中&#xff0c;客户端发送一个SYN&#xff08;同步序列编号&#xff09;报文给服务器&#xff0c;服务器回应一个SYN-ACK&#xff08;同步和确认&#xf…