基于ChatYuan-large-v2 语言模型 Fine-tuning 微调训练 广告生成 任务

news2024/12/23 0:04:09

一、ChatYuan-large-v2

ChatYuan-large-v2是一个开源的支持中英双语的功能型对话语言大模型,与其他 LLM 不同的是模型十分轻量化,并且在轻量化的同时效果相对还不错,仅仅通过0.7B参数量就可以实现10B模型的基础效果,正是其如此的轻量级,使其可以在普通显卡、 CPU、甚至手机上进行推理,而且 INT4 量化后的最低只需 400M

v2 版本相对于以前的 v1 版本,是使用了相同的技术方案,但在指令微调、人类反馈强化学习、思维链等方面进行了优化。

在本专栏前面文章介绍了 ChatYuan-large-v2langchain 相结合的使用,地址如下:

LangChain 本地化方案 - 使用 ChatYuan-large-v2 作为 LLM 大语言模型

本篇文章以 ChatYuan-large-v2 模型为基础 Fine-tuning 广告生成 任务。

二、数据集处理

数据集这里使用 ChatGLM 官方在 Fine-tuning 中使用到的 广告生成 数据集。

下载地址如下:

https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view

数据已 JSON 的形式存放,分为了 traindev 两种类型:

在这里插入图片描述
数据格式如下所示:

{
    "content":"类型#裤*版型#宽松*风格#性感*图案#线条*裤型#阔腿裤",
    "summary":"宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。毕竟好穿时尚,谁都能穿出腿长2米的效果宽松的裤腿,当然是遮肉小能手啊。上身随性自然不拘束,面料亲肤舒适贴身体验感棒棒哒。系带部分增加设计看点,还让单品的设计感更强。腿部线条若隐若现的,性感撩人。颜色敲温柔的,与裤子本身所呈现的风格有点反差萌。"
}
{
    "content":"类型#裤*版型#宽松*风格#性感*图案#线条*裤型#阔腿裤",
    "summary":"宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。毕竟好穿时尚,谁都能穿出腿长2米的效果宽松的裤腿,当然是遮肉小能手啊。上身随性自然不拘束,面料亲肤舒适贴身体验感棒棒哒。系带部分增加设计看点,还让单品的设计感更强。腿部线条若隐若现的,性感撩人。颜色敲温柔的,与裤子本身所呈现的风格有点反差萌。"
}
{
    "content":"类型#裙*风格#简约*图案#条纹*图案#线条*图案#撞色*裙型#鱼尾裙*裙袖长#无袖",
    "summary":"圆形领口修饰脖颈线条,适合各种脸型,耐看有气质。无袖设计,尤显清凉,简约横条纹装饰,使得整身人鱼造型更为生动立体。加之撞色的鱼尾下摆,深邃富有诗意。收腰包臀,修饰女性身体曲线,结合别出心裁的鱼尾裙摆设计,勾勒出自然流畅的身体轮廓,展现了婀娜多姿的迷人姿态。"
}
{
    "content":"类型#上衣*版型#宽松*颜色#粉红色*图案#字母*图案#文字*图案#线条*衣样式#卫衣*衣款式#不规则",
    "summary":"宽松的卫衣版型包裹着整个身材,宽大的衣身与身材形成鲜明的对比描绘出纤瘦的身形。下摆与袖口的不规则剪裁设计,彰显出时尚前卫的形态。被剪裁过的样式呈现出布条状自然地垂坠下来,别具有一番设计感。线条分明的字母样式有着花式的外观,棱角分明加上具有少女元气的枣红色十分有年轻活力感。粉红色的衣身把肌肤衬托得很白嫩又健康。"
}
{
    "content":"类型#裙*版型#宽松*材质#雪纺*风格#清新*裙型#a字*裙长#连衣裙",
    "summary":"踩着轻盈的步伐享受在午后的和煦风中,让放松与惬意感为你免去一身的压力与束缚,仿佛要将灵魂也寄托在随风摇曳的雪纺连衣裙上,吐露出<UNK>微妙而又浪漫的清新之意。宽松的a字版型除了能够带来足够的空间,也能以上窄下宽的方式强化立体层次,携带出自然优雅的曼妙体验。"
}
{
    "content":"类型#上衣*材质#棉*颜色#蓝色*风格#潮*衣样式#polo*衣领型#polo领*衣袖长#短袖*衣款式#拼接",
    "summary":"想要在人群中脱颖而出吗?那么最适合您的莫过于这款polo衫短袖,采用了经典的polo领口和柔软纯棉面料,让您紧跟时尚潮流。再配合上潮流的蓝色拼接设计,使您的风格更加出众。就算单从选料上来说,这款polo衫的颜色沉稳经典,是这个季度十分受大众喜爱的风格了,而且兼具舒适感和时尚感。"
}

其中任务的方式为根据输入(content)生成一段广告词(summary)。

train.json 共有 114599 条记录,这里为了演示效果取前 50000 条数据进行训练、5000 条数据进行验证:

import os

# 将训练集进行提取
def doHandle(json_path, train_size, val_size, out_json_path):
    train_count = 0
    val_count = 0

    train_f = open(os.path.join(out_json_path, "train.json"), "a", encoding='utf-8')
    val_f = open(os.path.join(out_json_path, "val.json"), "a", encoding='utf-8')

    with open(json_path, "r", encoding='utf-8') as f:
        for line in f:
            if train_count < train_size:
                train_f.writelines(line)
                train_count = train_count + 1
            elif val_count < val_size:
                val_f.writelines(line)
                val_count = val_count + 1
            else:
                break

    print("数据处理完毕!")
    train_f.close()
    val_f.close()


if __name__ == '__main__':
    json_path = "./data/AdvertiseGen/train.json"
    out_json_path = "./data/"
    train_size = 50000
    val_size = 5000
    doHandle(json_path, train_size, val_size, out_json_path)

处理之后可以看到两个生成的文件:

在这里插入图片描述

下面基于上面的数据格式构建 Dataset

from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
import torch
import json


class SummaryDataSet(Dataset):

    def __init__(self, json_path: str, tokenizer, max_length=300):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.content_data = []
        self.summary_data = []
        with open(json_path, "r", encoding='utf-8') as f:
            for line in f:
                if not line or line == "":
                    continue
                json_line = json.loads(line)
                content = json_line["content"]
                summary = json_line["summary"]
                self.content_data.append(content)
                self.summary_data.append(summary)
        print("data load , size:", len(self.content_data))

    def __len__(self):
        return len(self.content_data)

    def __getitem__(self, index):
        source_text = str(self.content_data[index])
        target_text = str(self.summary_data[index])

        source = self.tokenizer.batch_encode_plus(
            [source_text],
            max_length=self.max_length,
            pad_to_max_length=True,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        target = self.tokenizer.batch_encode_plus(
            [target_text],
            max_length=self.max_length,
            pad_to_max_length=True,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

        source_ids = source["input_ids"].squeeze()
        source_mask = source["attention_mask"].squeeze()
        target_ids = target["input_ids"].squeeze()
        target_mask = target["attention_mask"].squeeze()

        return {
            "source_ids": source_ids.to(dtype=torch.long),
            "source_mask": source_mask.to(dtype=torch.long),
            "target_ids": target_ids.to(dtype=torch.long),
            "target_ids_y": target_ids.to(dtype=torch.long),
        }

三、模型训练

下载 ChatYuan-large-v2 模型:

https://huggingface.co/ClueAI/ChatYuan-large-v2/tree/main

在这里插入图片描述
在这里插入图片描述

下面基于 ChatYuan-large-v2 进行训练:

import pandas as pd
import torch
from torch.utils.data import DataLoader
import os, time
from transformers import T5Tokenizer, T5ForConditionalGeneration
from gen_dataset import SummaryDataSet


def train(epoch, tokenizer, model, device, loader, optimizer):
    model.train()
    time1 = time.time()
    for _, data in enumerate(loader, 0):
        y = data["target_ids"].to(device, dtype=torch.long)
        y_ids = y[:, :-1].contiguous()
        lm_labels = y[:, 1:].clone().detach()
        lm_labels[y[:, 1:] == tokenizer.pad_token_id] = -100
        ids = data["source_ids"].to(device, dtype=torch.long)
        mask = data["source_mask"].to(device, dtype=torch.long)

        outputs = model(
            input_ids=ids,
            attention_mask=mask,
            decoder_input_ids=y_ids,
            labels=lm_labels,
        )
        loss = outputs[0]
        # 每100步打印日志
        if _ % 100 == 0 and _ != 0:
            time2 = time.time()
            print(_, "epoch:" + str(epoch) + "-loss:" + str(loss) + ";each step's time spent:" + str(
                float(time2 - time1) / float(_ + 0.0001)))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def validate(epoch, tokenizer, model, device, loader, max_length):
    model.eval()
    predictions = []
    actuals = []
    with torch.no_grad():
        for _, data in enumerate(loader, 0):
            y = data['target_ids'].to(device, dtype=torch.long)
            ids = data['source_ids'].to(device, dtype=torch.long)
            mask = data['source_mask'].to(device, dtype=torch.long)

            generated_ids = model.generate(
                input_ids=ids,
                attention_mask=mask,
                max_length=max_length,
                num_beams=2,
                repetition_penalty=2.5,
                length_penalty=1.0,
                early_stopping=True
            )
            preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in
                     generated_ids]
            target = [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True) for t in y]
            if _ % 1000 == 0:
                print(f'Completed {_}')

            predictions.extend(preds)
            actuals.extend(target)
    return predictions, actuals


def T5Trainer(train_json_path, val_json_path, model_dir, batch_size, epochs, output_dir, max_length=300):

    tokenizer = T5Tokenizer.from_pretrained(model_dir)
    model = T5ForConditionalGeneration.from_pretrained(model_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    train_params = {
        "batch_size": batch_size,
        "shuffle": True,
        "num_workers": 0,
    }
    training_set = SummaryDataSet(train_json_path, tokenizer, max_length=max_length)
    training_loader = DataLoader(training_set, **train_params)

    val_params = {
        "batch_size": batch_size,
        "shuffle": False,
        "num_workers": 0,
    }

    val_set = SummaryDataSet(val_json_path, tokenizer, max_length=max_length)
    val_loader = DataLoader(val_set, **val_params)

    optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

    for epoch in range(epochs):
        train(epoch, tokenizer, model, device, training_loader, optimizer)
        print("保存模型")
        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
        # 验证
        with torch.no_grad():
            predictions, actuals = validate(epoch, tokenizer, model, device, val_loader, max_length)
            # 验证结果存储
            final_df = pd.DataFrame({"Generated Text": predictions, "Actual Text": actuals})
            final_df.to_csv(os.path.join(output_dir, "predictions.csv"))


if __name__ == '__main__':
    train_json_path = "./data/train.json"
    val_json_path = "./data/val.json"
    # 下载模型目录位置
    model_dir = "chatyuan_large_v2"
    batch_size = 5
    epochs = 1
    max_length = 300
    output_dir = "./model"
	
	# 开始训练
    T5Trainer(
        train_json_path,
        val_json_path,
        model_dir,
        batch_size,
        epochs,
        output_dir,
        max_length
    )

运行后可以看到如下日志打印,训练大概占用 33G 的显存,如果显存不够可以调低些 batch_size 的大小:

在这里插入图片描述

等待训练结束后:

在这里插入图片描述

可以在 model 下看到保存的模型:

在这里插入图片描述

这里可以先看下 predictions.csv 验证集的效果:

在这里插入图片描述

可以看到模型生成的结果有点不太好,这里仅对前 50000 条进行了训练,并且就训练了一个 epoch ,后面可以增加数据集大小和增加 epoch 应该能达到更好的效果,下面通过调用模型测试一下生成的文本效果。

四、模型测试

# -*- coding: utf-8 -*-
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch

# 这里是模型下载的位置
model_dir = './model'

tokenizer = T5Tokenizer.from_pretrained(model_dir)
model = T5ForConditionalGeneration.from_pretrained(model_dir)

while True:
    text = input("请输入内容: \n ")
    if not text or text == "":
        continue
    if text == "q":
        break

    encoded_input = tokenizer(text, padding="max_length", truncation=True, max_length=300)
    input_ids = torch.tensor([encoded_input['input_ids']])
    attention_mask = torch.tensor([encoded_input['attention_mask']])

    generated_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=300,
        num_beams=2,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True
    )

    reds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in
            generated_ids]

    print(reds)

效果测试:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/841347.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于YOLOv7开发构建MSTAR雷达影像目标检测系统

MSTAR&#xff08;Moving and Stationary Target Acquisition and Recognition&#xff09;数据集是一个基于合成孔径雷达&#xff08;Synthetic Aperture Radar&#xff0c;SAR&#xff09;图像的目标检测和识别数据集。它是针对目标检测、机器学习和模式识别算法的研究和评估…

Visual Studio 2019 详细安装教程(图文版)

前言 Visual Studio 2019 安装包的下载教程、安装教程 教程 博主博客链接&#xff1a;https://blog.csdn.net/m0_74014525 关注博主&#xff0c;后期持续更新系列文章 ********文章附有百度网盘安装包链接********* 系列文章 第一篇&#xff1a;Visual Studio 2019 详细安装教…

MobiSys 2023 | 多用户心跳监测的双重成形声学感知

注1:本文系“无线感知论文速递”系列之一,致力于简洁清晰完整地介绍、解读无线感知领域最新的顶会/顶刊论文(包括但不限于 Nature/Science及其子刊; MobiCom, Sigcom, MobiSys, NSDI, SenSys, Ubicomp; JSAC, 雷达学报 等)。本次介绍的论文是:<<MobiSys’23,Multi-User A…

特殊符号的制作 台风 示例 使用第三方工具 Photoshop 地理信息系统空间分析实验教程 第三版

特殊符号的制作 首先这是一个含有字符的&#xff0c;使用arcgis自带的符号编辑器制作比较困难。所以我们准备采用Adobe Photoshop 来进行制作符号&#xff0c;然后直接导入符号的图片文件作为符号 我们打开ps&#xff0c;根据上面的图片的像素长宽比&#xff0c;设定合适的高度…

高中生python零基础怎么学,python高中生自学行吗

这篇文章主要介绍了高中学历学python好找工作吗&#xff0c;具有一定借鉴价值&#xff0c;需要的朋友可以参考下。希望大家阅读完这篇文章后大有收获&#xff0c;下面让小编带着大家一起了解一下。 学习python的第九天 根据我们前面这几天的学习&#xff0c;我们掌握了Python的…

nginx环境部署

目录 一、yum安装 二、源码安装 三、测试结果 一、yum安装 1、先查找本地yum源上有没有nginx包 yum list | grep nginx 2、rpm安装 rpm -Uvh http://nginx.org/packages/centos/7/x86_64/RPMS/nginx-1.14.2-1.el7_4.ngx.x86_64.rpm 3、查看安装是否成功 rpm -pa | grep…

Pytorch迁移学习使用MobileNet v3网络模型进行猫狗预测二分类

目录 1. MobileNet 1.1 MobileNet v1 1.1.1 深度可分离卷积 1.1.2 宽度和分辨率调整 1.2 MobileNet v2 1.2.1 倒残差模块 1.3 MobileNet v3 1.3.1 MobieNet V3 Block 1.3.2 MobileNet V3-Large网络结构 1.3.3 MobileNet V3预测猫狗二分类问题 送书活动 1. MobileNet …

【力扣每日一题】2023.8.6 两两交换链表中的节点

目录 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 代码&#xff1a; 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 题目给我们一个链表&#xff0c;让我们两两交换相邻节点的值&#xff0c;并且不能通过修改节点内部的值来达到这一目的&#xff0c;如果可…

NERFS 与现实捕捉 - 弥合现实世界与数字世界之间的差距

NERF介绍 近年来&#xff0c;计算机视觉和图形领域取得了显着的进步&#xff0c;催生了革命性的技术&#xff0c;改变了各个行业。 NERFS&#xff08;神经辐射场&#xff09;和现实捕捉是两项备受关注的重要技术。 NERFS 和现实捕捉都是以数字形式捕捉和重建现实世界的强大工具…

无涯教程-Perl - dbmclose函数

描述 此函数关闭哈希和DBM文件之间的绑定。将领带功能与合适的模块配合使用。 语法 以下是此函数的简单语法- dbmclose HASH返回值 如果失败,此函数返回0,如果成功,则返回1。 请注意,在大型DBM文件上使用键和值之类的功能时,它们可能会返回巨大的列表。您可能更喜欢使用e…

HBase-写流程

写流程顺序正如API编写顺序&#xff0c;首先创建HBase的重量级连接 &#xff08;1&#xff09;读取本地缓存中的Meta表信息&#xff1b;&#xff08;第一次启动客户端为空&#xff09; &#xff08;2&#xff09;向ZK发起读取Meta表所在位置的请求&#xff1b; &#xff08;…

chapter13:springboot与任务

Spring Boot与任务视频 1. 异步任务 使用注解 Async 开启一个异步线程任务&#xff0c; 需要在主启动类上添加注解EnableAsync开启异步配置&#xff1b; Service public class AsyncService {Asyncpublic void hello() {try {Thread.sleep(3000);} catch (InterruptedExcept…

为什么很多人都在吹ChatGPT改变世界?一文全面了解

ChatGPT早已让全世界的互联网炸锅&#xff0c;它已经从一个新颖的聊天机器人演变成一项推动下一个创新时代到来的技术。已经很久没有出现一款让大家充满兴趣、兴奋、恐惧且具有争议的科技产品了。 如果你现在才接触到它&#xff0c;你可能会想知道这到底是怎么回事。这里建议你…

linux gcc __attribute__

__attribute__ 1. 函数属性1.1 __attribute__((noreturn))1.2 __attribute__((format))1.3 __attribute__((const)) 2. 变量属性2.1. __attribute__((aligned))2.2. __attribute__((packed)) 3. 类型属性 __attribute__ 是 GCC 编译器提供的一种特殊语法&#xff0c;它可以用于…

测试该知道的二三事:浅谈响应式网页设计

1.起因 最近几天正巧在帮朋友的公司团队做质量保障体系的培训&#xff0c;在此期间与几个测试人员闲聊&#xff0c;正是其中的一件事让我对今天的话题提起了兴趣&#xff1a;朋友公司里的研发团队招了一个应届毕业生&#xff0c;做了半年之后接了某个web项目的其中一个拓展功能…

质检工具(FindBugs、CheckStyle、Junit、Jmeter、Apifox)

1、Findbugs IDEA软件中可以装该插件,2018版本以前主要搜索FindBugs-IDEA 、2018版本以后主要搜索 SpotBugs。 1.1、FindBugs-IDEA安装及使用流程: 1.2、SpotBugs安装及使用流程: 2、Checkstyle IDEA软件中可以装该插件,所有版本的插件一致:CheckStyle 2.1、安装流程…

【C# 基础精讲】为什么选择C# ?

C#&#xff08;C Sharp&#xff09;是由微软开发的一种通用、面向对象的编程语言。它最初于2000年发布&#xff0c;自那时以来逐渐成为开发者的首选之一。C#的设计目标是提供一种简单、现代、可靠且安全的编程语言&#xff0c;使开发者能够轻松构建各种类型的应用程序。 为什么…

Apache Doris 助力中国联通万亿日志数据分析提速 10 倍

本文导读&#xff1a; 在数据安全管理体系的背后&#xff0c;离不开对安全日志数据的存储与分析。以终端设备为例&#xff0c;中国联通每天会产生百亿级别的日志数据&#xff0c;对于保障网络安全、提高系统稳定性和可靠性具有至关重要的作用。目前&#xff0c;Apache Doris 在…

解决树莓派“由于没有公钥,无法验证下列签名“

目录 简介&#xff1a;在换完国内源后&#xff0c;树莓派尝试更新同步/etc/apt/sources.list和/etc/apt/sources.list.d中列出的软件源的软件包版本也就是&#xff08;apt-get update&#xff09;和更新已安装的所有或者指定软件包&#xff08;也即是apt-get upgrade&#xff0…

java+springboot摄影作品竞赛报名系统 微信小程序--论文

随着Internet的发展&#xff0c;人们的日常生活已经离不开网络。未来人们的生活与工作将变得越来越数字化&#xff0c;网络化和电子化。网上管理&#xff0c;它将是直接管理摄影竞赛小程序的最新形式。本小程序是以构建摄影竞赛为目标&#xff0c;使用java技术制作&#xff0c;…