AIGC笔记--基于PEFT库使用LoRA

news2024/9/23 3:29:46

1--相关讲解

LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS

LoRA 在 Stable Diffusion 中的三种应用:原理讲解与代码示例

PEFT-LoRA

2--基本原理

        固定原始层,通过添加和训练两个低秩矩阵,达到微调模型的效果;

3--简单代码

import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model, LoraModel
from peft.utils import get_peft_model_state_dict

# 创建模型
class Simple_Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(64, 128)
        self.linear2 = nn.Linear(128, 256)
    def forward(self, x: torch.Tensor):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

if __name__ == "__main__":
    # 初始化原始模型
    origin_model = Simple_Model()

    # 配置lora config
    model_lora_config = LoraConfig(
        r = 32, 
        lora_alpha = 32, # scaling = lora_alpha / r 一般来说,lora_alpha的参数初始化为与r相同,即scale=1
        init_lora_weights = "gaussian", # 参数初始化方式
        target_modules = ["linear1", "linear2"], # 对应层添加lora层
        lora_dropout = 0.1
    )

    # Test data
    input_data = torch.rand(2, 64)
    origin_output = origin_model(input_data)

    # 原始模型的权重参数
    origin_state_dict = origin_model.state_dict() 

    # 两种方式生成对应的lora模型,调用后会更改原始的模型
    new_model1 = get_peft_model(origin_model, model_lora_config)
    new_model2 = LoraModel(origin_model, model_lora_config, "default")

    output1 = new_model1(input_data)
    output2 = new_model2(input_data)
    # 初始化时,lora_B矩阵会初始化为全0,因此最初 y = WX + (alpha/r) * BA * X == WX
    # origin_output == output1 == output2

    # 获取lora权重参数,两者在key_name上会有区别
    new_model1_lora_state_dict = get_peft_model_state_dict(new_model1)
    new_model2_lora_state_dict = get_peft_model_state_dict(new_model2)

    # origin_state_dict['linear1.weight'].shape -> [output_dim, input_dim]
    # new_model1_lora_state_dict['base_model.model.linear1.lora_A.weight'].shape -> [r, input_dim]
    # new_model1_lora_state_dict['base_model.model.linear1.lora_B.weight'].shape -> [output_dim, r]
    print("All Done!")

4--权重保存和合并

核心公式是:new_weights = origin_weights + alpha* (BA)

    # 借助diffuser的save_lora_weights保存模型权重
    from diffusers import StableDiffusionPipeline
    save_path = "./"
    global_step = 0
    StableDiffusionPipeline.save_lora_weights(
            save_directory = save_path,
            unet_lora_layers = new_model1_lora_state_dict,
            safe_serialization = True,
            weight_name = f"checkpoint-{global_step}.safetensors",
        )

    # 加载lora模型权重(参考Stable Diffusion),其实可以重写一个简单的版本
    from safetensors import safe_open
    alpha = 1. # 参数融合因子
    lora_path = "./" + f"checkpoint-{global_step}.safetensors"
    state_dict = {}
    with safe_open(lora_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            state_dict[key] = f.get_tensor(key)

    all_lora_weights = []
    for idx,key in enumerate(state_dict):
        # only process lora down key
        if "lora_B." in key: continue

        up_key    = key.replace(".lora_A.", ".lora_B.") # 通过lora_A直接获取lora_B的键名
        model_key = key.replace("unet.", "").replace("lora_A.", "").replace("lora_B.", "")
        layer_infos = model_key.split(".")[:-1]

        curr_layer = new_model1

        while len(layer_infos) > 0:
            temp_name = layer_infos.pop(0)
            curr_layer = curr_layer.__getattr__(temp_name)

        weight_down = state_dict[key].to(curr_layer.weight.data.device)
        weight_up   = state_dict[up_key].to(curr_layer.weight.data.device)
        # 将lora参数合并到原模型参数中 -> new_W = origin_W + alpha*(BA)
        curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
        all_lora_weights.append([model_key, torch.mm(weight_up, weight_down).t()])
        print('Load Lora Done')

5--完整代码

PEFT_LoRA

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1706795.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

web自动化-下拉框操作/键鼠操作/文件上传

在我们做UI自动化测试的时候,会有一些元素需要特殊操作,比如下拉框操作/键鼠操作/文件上传。 下拉框操作 在我们很多页面里有下拉框的选择,这种元素怎么定位呢?下拉框分为两种类型:我们分别针对这两种元素进行定位和…

答应我!养猫就一定要入手的七款好物!养猫真的会开心

养猫是一件让人愉悦的事情,猫咪的陪伴能让我们感到温暖和满足。然而,想要让猫咪健康快乐地成长,除了关心它们的饮食和健康,还需要为它们准备一些必要的生活用品。今天,我将为大家推荐几个养猫必备的好物,让…

黑马头条day6总结

1、wemedian错误 一开始没加EnableFeignClients(basePackages "com.heima.apis")导致获取ischeduleClient错误,找不到bean。 我看教程的代码中没有,【ComponentScan({"com.heima.apis","com.heima.wemedia"})】&#x…

11款必备IP地址管理软件,你都用过吗?

1、LightMesh IPAM 产品描述:LightMesh IPAM 是一款功能强大的工具,可简化和自动化互联网协议网络的管理。它提供可扩展性、子网规划器、即时云发现、IP 和网络管理以及 IP 规划和可视化,以帮助您优化效率、可见性和安全性。 特征&#xff1…

强化学习——学习笔记

一、什么是强化学习? 强化学习 (Reinforcement Learning, RL) 是一种通过与环境交互来学习决策策略的机器学习方法。它的核心思想是让智能体 (Agent) 在执行动作 (Action)、观察环境 (Environment) 反馈的状态 (State) 和奖励 (Reward) 的过程中,学习到…

C++音视频开发面试题集锦

老规矩,先上面试题目: 1、iOS 中系统 API 提供了哪些视频编码的方式?2、VideoToolbox 视频帧解码失败以后应该如何重试?3、如何使用 PSNR 对视频转码质量进行评估?4、什么是 VAO,什么是 VBO,它…

【图书推荐】《机器学习实战(视频教学版)》

本书用处 快速入门Python机器学习基础算法。 最后3个综合实战项目(包括新闻内容分类实战、泰坦尼克号获救预测实战、中药数据分析项目实战)可以作为研究可以的素材。 内容简介 本书基于Python语言详细讲解机器学习算法及其应用,用于读者快…

Java 五种内部类演示及底层原理详解

内部类 什么是内部类 在A类的内部定义B类,B类就被称为内部类 发动机类单独存在没有意义 发动机为独立个体 可以在外部其他类里创建内部类的对象去调用方法 类的五大成员 属性 方法 构造方法 代码块 内部类 内部类的访问特点 内部类可以直接访问外部类的成员&a…

Java处理CSV文件示例

Java处理CSV文件示例 1. 导入依赖 <dependency><groupId>org.apache.commons</groupId><artifactId>commons-csv</artifactId><version>1.10.0</version></dependency>文件示例 下面是示例文件文件数据 vscode和idea都有解析…

第二证券股市资讯:连续3天20%涨停!A股这一赛道,牛股批量出现!

今日&#xff0c;A股小幅轰动调整&#xff0c;上证指数下试3100点支撑。 两市成交7453亿元&#xff0c;创近4个月来新低&#xff0c;超4000只个股下跌。盘面上&#xff0c;电力、芯片、煤炭、石油等板块涨幅居前&#xff0c;铜缆高速衔接、房地产、工程机械、网络游戏等板块跌幅…

word-主控文档、文档拆分及标书编写技巧建议

一、主控文档 视图-大纲视图-显示文档-插入子文档 子文档一旦更新&#xff0c;主文档也会更新。更新主文档&#xff0c;子文档也会更新 需要注意&#xff0c;不可修改子文档名字 二、上交文件 显示文档-折叠子文档-只显示一级-取消链接-关闭大纲视图-保存 三、文档拆分 根…

Transformer 从attention到grouped query attention (GQA)

Attention原理和理解 attention原理参考&#xff1a; Attention Is All You Need The Illustrated Transformer – Jay Alammar – Visualizing machine learning one concept at a time. Transformer图解 - 李理的博客 Attention首先对输入x张量乘以WQ, WK, WV得到query,…

本地开发正常 线上CI/CD构建项目过程报错文件未能正确引用

问题快照 原因分析&#xff1a; 一般遇到这样的错误就是 文件路径或者文件名称未能正确匹配 或者文件不存在 会报这样的错误 以为很好解决 但这次 都排查 了 就是 没发现原因 不管怎么说还是要感谢 GPT的能力(分析问题的能力) 先上图 当我看到 第四步的时候 我立马 去仓库里查…

没开玩笑!高速信号不能参考电源网络这条规则,其实很难做到

高速先生成员--黄刚 看到这篇文章的题目&#xff0c;我相信大家心里都呈现出了这么一个场景&#xff1a;高速信号线在L20层&#xff0c;我只要把L19和L21层都铺上完整的地平面&#xff0c;这不就满足了高速信号线不能参考电源平面这条规则了吗&#xff1f;这难道很难做到吗&…

Windows 使用技巧

Windows 使用技巧 ①局域网内共享文件 ②CTRL Y 和 CTRL Z ①局域网内共享文件 第一步&#xff1a; 选择要共享的文件&#xff08;分享方操作&#xff09; 第二步&#xff1a; 右键打开属性&#xff0c;选择共享&#xff08;分享方操作&#xff09; 第三步&#xff1a; …

Spring使用的设计模式

Spring 框架是一个广泛使用的 Java 框架&#xff0c;它内部使用了多种设计模式来简化开发过程、提高代码的可维护性和扩展性。 以下是一些在 Spring 框架中常见的设计模式&#xff0c;以及用代码示例来解释它们&#xff1a; 一、工厂模式&#xff08;Factory Pattern&#xff…

C#开发上位机应用:基础与实践

C#是一种流行的面向对象编程语言&#xff0c;常用于Windows应用程序的开发。上位机应用是一种用于监控和控制设备或系统的应用程序&#xff0c;通常与下位机&#xff08;如传感器、执行器等&#xff09;进行通信。在本文中&#xff0c;我们将介绍C#开发上位机应用的基础知识和实…

Vue3 之 动态组件和KeepAlive组件

一、动态组件 1、简介 ​ 在某些业务场景下&#xff0c;页面的某模块具有多个组件但在同一时间只显示一个&#xff0c;需要在多个组件之间进行频繁的切换&#xff0c;如&#xff1a;tab切换等场景。除了可以使用v-if、v-show根据不同条件显示不同组件之外&#xff0c;还可以通…

Element-Plus中表格及分页功能

导入Element-Plus 具体步骤如下&#xff1a;&#xff08;内容参照官网&#xff1a;安装 | Element Plus&#xff09; # 选择一个你喜欢的包管理器# NPM $ npm install element-plus --save# Yarn $ yarn add element-plus# pnpm $ pnpm install element-plus 在main.js文件的…

【论文阅读笔记】The Google File System

1 简介 Google File System (GFS) 是一个可扩展的分布式文件系统&#xff0c;专为快速增长的Google数据处理需求而设计。这篇论文发表于2003年&#xff0c;此前已在Google内部大规模应用。 GFS不仅追求性能、可伸缩性、可靠性和可用性等传统分布式文件系统的设计目标&#xf…