nlp培训重点-5

news2025/3/10 10:14:42

1. LoRA微调

loader:

# -*- coding: utf-8 -*-

import json
import re
import os
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
"""
数据加载
"""


class DataGenerator:
    def __init__(self, data_path, config):
        self.config = config
        self.path = data_path
        self.index_to_label = {0: '家居', 1: '房产', 2: '股票', 3: '社会', 4: '文化',
                               5: '国际', 6: '教育', 7: '军事', 8: '彩票', 9: '旅游',
                               10: '体育', 11: '科技', 12: '汽车', 13: '健康',
                               14: '娱乐', 15: '财经', 16: '时尚', 17: '游戏'}
        self.label_to_index = dict((y, x) for x, y in self.index_to_label.items())
        self.config["class_num"] = len(self.index_to_label)
        if self.config["model_type"] == "bert":
            self.tokenizer = BertTokenizer.from_pretrained(config["pretrain_model_path"])
        self.vocab = load_vocab(config["vocab_path"])
        self.config["vocab_size"] = len(self.vocab)
        self.load()


    def load(self):
        self.data = []
        with open(self.path, encoding="utf8") as f:
            for line in f:
                line = json.loads(line)
                tag = line["tag"]
                label = self.label_to_index[tag]
                title = line["title"]
                if self.config["model_type"] == "bert":
                    input_id = self.tokenizer.encode(title, max_length=self.config["max_length"], pad_to_max_length=True)
                else:
                    input_id = self.encode_sentence(title)
                input_id = torch.LongTensor(input_id)
                label_index = torch.LongTensor([label])
                self.data.append([input_id, label_index])
        return

    def encode_sentence(self, text):
        input_id = []
        for char in text:
            input_id.append(self.vocab.get(char, self.vocab["[UNK]"]))
        input_id = self.padding(input_id)
        return input_id

    #补齐或截断输入的序列,使其可以在一个batch内运算
    def padding(self, input_id):
        input_id = input_id[:self.config["max_length"]]
        input_id += [0] * (self.config["max_length"] - len(input_id))
        return input_id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

def load_vocab(vocab_path):
    token_dict = {}
    with open(vocab_path, encoding="utf8") as f:
        for index, line in enumerate(f):
            token = line.strip()
            token_dict[token] = index + 1  #0留给padding位置,所以从1开始
    return token_dict


#用torch自带的DataLoader类封装数据
def load_data(data_path, config, shuffle=True):
    dg = DataGenerator(data_path, config)
    dl = DataLoader(dg, batch_size=config["batch_size"], shuffle=shuffle)
    return dl

if __name__ == "__main__":
    from config import Config
    dg = DataGenerator("valid_tag_news.json", Config)
    print(dg[1])

model:

import torch.nn as nn
from config import Config
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
from torch.optim import Adam, SGD

TorchModel = AutoModelForSequenceClassification.from_pretrained(Config["pretrain_model_path"])


def choose_optimizer(config, model):
    optimizer = config["optimizer"]
    learning_rate = config["learning_rate"]
    if optimizer == "adam":
        return Adam(model.parameters(), lr=learning_rate)
    elif optimizer == "sgd":
        return SGD(model.parameters(), lr=learning_rate)

evaluate:

# -*- coding: utf-8 -*-
import torch
from loader import load_data

"""
模型效果测试
"""

class Evaluator:
    def __init__(self, config, model, logger):
        self.config = config
        self.model = model
        self.logger = logger
        self.valid_data = load_data(config["valid_data_path"], config, shuffle=False)
        self.stats_dict = {"correct":0, "wrong":0}  #用于存储测试结果

    def eval(self, epoch):
        self.logger.info("开始测试第%d轮模型效果:" % epoch)
        self.model.eval()
        self.stats_dict = {"correct": 0, "wrong": 0}  # 清空上一轮结果
        for index, batch_data in enumerate(self.valid_data):
            if torch.cuda.is_available():
                batch_data = [d.cuda() for d in batch_data]
            input_ids, labels = batch_data   #输入变化时这里需要修改,比如多输入,多输出的情况
            with torch.no_grad():
                pred_results = self.model(input_ids)[0]
            self.write_stats(labels, pred_results)
        acc = self.show_stats()
        return acc

    def write_stats(self, labels, pred_results):
        # assert len(labels) == len(pred_results)
        for true_label, pred_label in zip(labels, pred_results):
            pred_label = torch.argmax(pred_label)
            # print(true_label, pred_label)
            if int(true_label) == int(pred_label):
                self.stats_dict["correct"] += 1
            else:
                self.stats_dict["wrong"] += 1
        return

    def show_stats(self):
        correct = self.stats_dict["correct"]
        wrong = self.stats_dict["wrong"]
        self.logger.info("预测集合条目总量:%d" % (correct +wrong))
        self.logger.info("预测正确条目:%d,预测错误条目:%d" % (correct, wrong))
        self.logger.info("预测准确率:%f" % (correct / (correct + wrong)))
        self.logger.info("--------------------")
        return correct / (correct + wrong)

 main:

# -*- coding: utf-8 -*-

import torch
import os
import random
import os
import numpy as np
import torch.nn as nn
import logging
from config import Config
from model import TorchModel, choose_optimizer
from evaluate import Evaluator
from loader import load_data
from peft import get_peft_model, LoraConfig, \
    PromptTuningConfig, PrefixTuningConfig, PromptEncoderConfig 


#[DEBUG, INFO, WARNING, ERROR, CRITICAL]
logging.basicConfig(level=logging.INFO, format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

"""
模型训练主程序
"""


seed = Config["seed"]
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)



def main(config):
    #创建保存模型的目录
    if not os.path.isdir(config["model_path"]):
        os.mkdir(config["model_path"])
    #加载训练数据
    train_data = load_data(config["train_data_path"], config)
    #加载模型
    model = TorchModel

    #大模型微调策略
    tuning_tactics = config["tuning_tactics"]
    if tuning_tactics == "lora_tuning":
        peft_config = LoraConfig(
            r=8,
            lora_alpha=32,
            lora_dropout=0.1,
            target_modules=["query", "key", "value"]
        )
    elif tuning_tactics == "p_tuning":
        peft_config = PromptEncoderConfig(task_type="SEQ_CLS", num_virtual_tokens=10)
    elif tuning_tactics == "prompt_tuning":
        peft_config = PromptTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=10)
    elif tuning_tactics == "prefix_tuning":
        peft_config = PrefixTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=10)
    
    
    model = get_peft_model(model, peft_config)
    # print(model.state_dict().keys())

    if tuning_tactics == "lora_tuning":
        # lora配置会冻结原始模型中的所有层的权重,不允许其反传梯度
        # 但是事实上我们希望最后一个线性层照常训练,只是bert部分被冻结,所以需要手动设置
        for param in model.get_submodule("model").get_submodule("classifier").parameters():
            param.requires_grad = True

    # 标识是否使用gpu
    cuda_flag = torch.cuda.is_available()
    if cuda_flag:
        logger.info("gpu可以使用,迁移模型至gpu")
        model = model.cuda()

    #加载优化器
    optimizer = choose_optimizer(config, model)
    #加载效果测试类
    evaluator = Evaluator(config, model, logger)
    #训练
    for epoch in range(config["epoch"]):
        epoch += 1
        model.train()
        logger.info("epoch %d begin" % epoch)
        train_loss = []
        for index, batch_data in enumerate(train_data):
            if cuda_flag:
                batch_data = [d.cuda() for d in batch_data]

            optimizer.zero_grad()
            input_ids, labels = batch_data   #输入变化时这里需要修改,比如多输入,多输出的情况
            output = model(input_ids)[0]
            loss = nn.CrossEntropyLoss()(output, labels.view(-1))
            loss.backward()
            optimizer.step()

            train_loss.append(loss.item())
            if index % int(len(train_data) / 2) == 0:
                logger.info("batch loss %f" % loss)
        logger.info("epoch average loss: %f" % np.mean(train_loss))
        acc = evaluator.eval(epoch)
    model_path = os.path.join(config["model_path"], "%s.pth" % tuning_tactics)
    save_tunable_parameters(model, model_path)  #保存模型权重
    return acc

def save_tunable_parameters(model, path):
    saved_params = {
        k: v.to("cpu")
        for k, v in model.named_parameters()
        if v.requires_grad
    }
    torch.save(saved_params, path)


if __name__ == "__main__":
    main(Config)

pred:

import torch
import logging
from model import TorchModel
from peft import get_peft_model, LoraConfig, PromptTuningConfig, PrefixTuningConfig, PromptEncoderConfig

from evaluate import Evaluator
from config import Config


logging.basicConfig(level=logging.INFO, format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

#大模型微调策略
tuning_tactics = Config["tuning_tactics"]

print("正在使用 %s"%tuning_tactics)

if tuning_tactics == "lora_tuning":
    peft_config = LoraConfig(
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=["query", "key", "value"]
    )
elif tuning_tactics == "p_tuning":
    peft_config = PromptEncoderConfig(task_type="SEQ_CLS", num_virtual_tokens=10)
elif tuning_tactics == "prompt_tuning":
    peft_config = PromptTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=10)
elif tuning_tactics == "prefix_tuning":
    peft_config = PrefixTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=10)

#重建模型
model = TorchModel
# print(model.state_dict().keys())
# print("====================")

model = get_peft_model(model, peft_config)
# print(model.state_dict().keys())
# print("====================")

state_dict = model.state_dict()

#将微调部分权重加载
if tuning_tactics == "lora_tuning":
    loaded_weight = torch.load('output/lora_tuning.pth')
elif tuning_tactics == "p_tuning":
    loaded_weight = torch.load('output/p_tuning.pth')
elif tuning_tactics == "prompt_tuning":
    loaded_weight = torch.load('output/prompt_tuning.pth')
elif tuning_tactics == "prefix_tuning":
    loaded_weight = torch.load('output/prefix_tuning.pth')

print(loaded_weight.keys())
state_dict.update(loaded_weight)

#权重更新后重新加载到模型
model.load_state_dict(state_dict)

#进行一次测试
model = model.cuda()
evaluator = Evaluator(Config, model, logger)
evaluator.eval(0)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2312636.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

电子学会—2024年月6青少年软件编程(图形化)四级等级考试真题——水仙花数

水仙花数 如果一个三位数等于它各个数位上的数字的立方和,那么这个数就是水仙花数,例如:153 111 555 333,153就是一个水仙花数。 1.准备工作 (1)保留默认角色小猫; (2)白色背景。 2.功能实现 (1)使用循环遍历所有三位数,把所…

JetBrains学生申请

目录 JetBrains学生免费授权申请 IDEA安装与使用 第一个JAVA代码 1.利用txt文件和cmd命令运行 2.使用IDEA新建项目 JetBrains学生免费授权申请 本教程采用学生校园邮箱申请,所以要先去自己的学校申请校园邮箱。 进入JetBrains官网 点击立即申请,然…

langchain系列(终)- LangGraph 多智能体详解

目录 一、导读 二、概念原理 1、智能体 2、多智能体 3、智能体弊端 4、多智能体优点 5、多智能体架构 6、交接(Handoffs) 7、架构说明 (1)网络 (2)监督者 (3)监督者&…

侯捷 C++ 课程学习笔记:深入理解智能指针

文章目录 每日一句正能量一、引言二、智能指针的核心概念(一)std::unique_ptr(二)std::shared_ptr(三)std::weak_ptr 三、学习心得四、实际应用案例五、总结 每日一句正能量 如果说幸福是一个悖论&#xff…

访问不了 https://raw.githubusercontent.com 怎么办?

修改 Hosts 文件(推荐)​ 原理:通过手动指定域名对应的 IP 地址,绕过 DNS 污染。 步骤: 1、访问 IPAddress.com,搜索 raw.githubusercontent.com,获取当前最新的 IPv4 地址(例如 1…

大模型工程师学习日记(十五):Hugging Face 模型微调训练(基于 BERT 的中文评价情感分析)

1. datasets 库核心方法 1.1. 列出数据集 使用 d atasets 库,你可以轻松列出所有 Hugging Face 平台上的数据集: from datasets import list_datasets# 列出所有数据集 all_datasets list_datasets()print(all_datasets)1.2. 加载数据集 你可以通过 l…

子数组、子串系列(典型算法思想)—— OJ例题算法解析思路

一、53. 最大子数组和 - 力扣&#xff08;LeetCode&#xff09; 算法代码&#xff1a; class Solution { public:int maxSubArray(vector<int>& nums) {// 1. 创建 dp 表// dp[i] 表示以第 i 个元素结尾的子数组的最大和int n nums.size();vector<int> dp(n…

Windows编程----进程的当前目录

进程的当前目录 Windows Api中有大量的函数在调用的时候&#xff0c;需要传递路径。比如创建文件&#xff0c;创建目录&#xff0c;删除目录&#xff0c;删除文件等等。拿创建文件的CreateFile函数做比喻&#xff0c;如果我们要创建的文件路径不是全路径&#xff0c;那么wind…

AVL树的介绍及实现

文章目录 &#xff08;一&#xff09;AVL的概念&#xff08;二&#xff09;AVL树的实现1.AVL树的结构2.AVL树的插入3.AVL树的查找 &#xff08;三&#xff09;检查一棵树是否是AVL树 &#xff08;一&#xff09;AVL的概念 AVL树是一棵高度平衡的二叉搜索树&#xff0c;通过控制…

快速生成viso流程图图片形式

我们在写详细设计文档的过程中总会不可避免的涉及到时序图或者流程图的绘制&#xff0c;viso这个软件大部分技术人员都会使用&#xff0c;但是想要画的好看&#xff0c;画的科学还是比较难的&#xff0c;现在我总结一套比较好的方法可以生成好看科学的viso图(图片格式)。主要思…

【极光 Orbit•STC8A-8H】03. 小刀初试:点亮你的LED灯

【极光 Orbit•STC8H】03. 小刀初试&#xff1a;点亮你的 LED 灯 七律 点灯初探 单片方寸藏乾坤&#xff0c;LED明灭见真章。 端口配置定方向&#xff0c;寄存器值细推敲。 高低电平随心控&#xff0c;循环闪烁展锋芒。 嵌入式门初开启&#xff0c;从此代码手中扬。 摘要 …

OSPF报文分析

OSPF报文分析 组播地址 224.0.0.0&#xff5e;224.0.0.255为预留的组播地址&#xff08;永久组地址&#xff09;&#xff0c;地址224.0.0.0保留不做分配&#xff0c;其它地址供路由协议使用&#xff1b; 224.0.1.0&#xff5e;238.255.255.255为用户可用的组播地址&#xff08;…

MySql性能(9)- mysql的order by的工作原理

全字段排序rowid排序全字段排序和rowid排序 3.1 联合索引优化 3.2 覆盖索引优化优先队列算法优化建议 5.1 修改系统参数 5.2 优化sql 1. 全字段排序 CREATE TABLE t ( id int(11) NOT NULL,city varchar(16) NOT NULL, name varchar(16) NOT NULL, age int(11) NOT NULL,addr v…

爬虫案例七Python协程爬取视频

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、Python协程爬取视频 前言 提示&#xff1a;这里可以添加本文要记录的大概内容&#xff1a; 爬虫案例七协程爬取视频 提示&#xff1a;以下是本篇文章正文…

智慧城市智慧社区项目建设方案

一、项目背景 在全球化进程加速的今天&#xff0c;城市化问题日益凸显&#xff0c;传统的城市管理模式已难以满足现代社会对高效、智能化管理的需求。智慧城市和智慧社区的概念应运而生&#xff0c;其核心目标是通过信息技术手段&#xff0c;提升城市资源的利用效率&#xff0…

RabbitMQ高级特性--消息确认机制

目录 一、消息确认 1.消息确认机制 2.手动确认方法 二、代码示例 1. AcknowledgeMode.NONE 1.1 配置文件 1.2 生产者 1.3 消费者 1.4 运行程序 2.AcknowledgeMode.AUTO 3.AcknowledgeMode.MANUAL 一、消息确认 1.消息确认机制 生产者发送消息之后&#xff0c;到达消…

Java EE 进阶:Spring IoCDI

IOC的简单介绍 什么是Spring&#xff1f;Spring是一个开源的框架&#xff0c;让我们的开发更加的简单&#xff0c;我们可以用一句更加具体的话来概括Spring&#xff0c;就是Spring是一个包含众多工具方法的IOC容器。 简单介绍一下IOC&#xff0c;我们之前说过通过ReqestContr…

Java数据结构第二十期:解构排序算法的艺术与科学(二)

专栏&#xff1a;Java数据结构秘籍 个人主页&#xff1a;手握风云 目录 一、常见排序算法的实现 1.1. 直接选择排序 1.2. 堆排序 1.3. 冒泡排序 1.4. 快速排序 一、常见排序算法的实现 1.1. 直接选择排序 每⼀次从待排序的数据元素中选出最小的⼀个元素&#xff0c;存放在…

【算法day5】最长回文子串——马拉车算法

最长回文子串 给你一个字符串 s&#xff0c;找到 s 中最长的 回文 子串。 https://leetcode.cn/problems/longest-palindromic-substring/description/ 算法思路&#xff1a; class Solution { public:string longestPalindrome(string s) {int s_len s.size();string tmp …

《如何排查Linux系统平均负载过高》

【系统平均负载导读】何为系统平均负载&#xff1f;假设一台云服务主机&#xff0c;突然之间响应用户请求的时间变长了&#xff0c;那么这个时候应该如何去排查&#xff1f;带着这个问题&#xff0c;我们对“平均负载”展开深入的探讨和研究。 何为Linux系统的平均负载&#xf…