Stable Diffusion的微调方法原理总结

news2025/2/26 19:47:06

目录

1、Textural Inversion(简易)

2、DreamBooth(完整)

3、LoRA(灵巧)

4、ControlNet(彻底)

5、其他


1、Textural Inversion(简易)

        不改变网络结构,仅改变CLIP中token embedding的字典。在字典中新增一个伪词的embedding,fine-tune这个embedding的值。其他所有可调参数都冻结。

优点:训练量极小,需要的素材就是一张图。完全不改变神经网络中的任何参数。

缺点:效果一般。

TI的简洁激发了很多研究者的灵感,基于TI思路的研究出现了很多。

2、DreamBooth(完整)

        具体做法是,加入一个新词(sks)代表subject,embedding初始值继承原类型的词的embedding。调整了模型中全部可调参数,彻底的让模型学会subject。损失函数加入了监督功能,去监控漂移现象,防止灾难性遗忘“学会新的忘了旧的”。

在LoRA出现前,训练DreamBooth是潮流,但代价较大。

3、LoRA(灵巧)

        LoRA的网络是一种additional network,LoRA训练不改变基础模型的任何参数,只对附加网络内部参数进行调整。在生成图像时,附加网络输出与原网络输出融合,从而改变生成效果。

        由于LoRA是将矩阵压缩到低秩后训练,所以LoRA网络的参数量很小(千分之一),训练速度快。实验发现,低维矩阵对高维矩阵的替代损失不大。所以即便训练的矩阵小,训练效果仍然很好,已成为一种customization image generation范式。LoRA后来在结构上改进出不同的版本,例如LoHA,LyCORIS等。

LoRA详解:https://zhuanlan.zhihu.com/p/632159261

Self-Attention的LoRA微调代码:GitHub - owenliang/pytorch-diffusion: pytorch复现stable diffusion

代码分析:

用于替换的线性层 (Wq, Wk, Wv矩阵):

class CrossAttention(nn.Module):
    def __init__(self,channel,qsize,vsize,fsize,cls_emb_size):
        super().__init__()
        # Wq, Wk, Wv 矩阵使用LoRA微调降低参数量, W + WA * WB
        self.w_q=nn.Linear(channel,qsize)
        self.w_k=nn.Linear(cls_emb_size,qsize)
        self.w_v=nn.Linear(cls_emb_size,vsize)
        self.softmax=nn.Softmax(dim=-1)
        self.z_linear=nn.Linear(vsize,channel)
        self.norm1=nn.LayerNorm(channel)
        # feed-forward结构
        self.feedforward=nn.Sequential(
            nn.Linear(channel,fsize),
            nn.ReLU(),
            nn.Linear(fsize,channel)
        )
        self.norm2=nn.LayerNorm(channel)

找到模型中所有的Wq, Wk, Wv线性层并将其替换为Lora:

if __name__=='__main__':   # 加入LoRA微调的训练过程
    # 预训练模型
    model=torch.load('model.pt')

    # 向nn.Linear层注入Lora
    for name,layer in model.named_modules():
        name_cols=name.split('.')
        # 过滤出cross attention使用的linear权重
        filter_names=['w_q','w_k','w_v']
        if any(n in name_cols for n in filter_names) and isinstance(layer,nn.Linear):   # module名字中存在w_q, w_k, w_v且属于线性层
            # print(name)   # enc_convs.0.crossattn.w_q,enc_convs.0.crossattn.w_k,enc_convs.0.crossattn.w_v,……
            inject_lora(model,name,layer)

Lora具体实现与替换过程:

# Lora实现,封装linear,替换到父module里
class LoraLayer(nn.Module):
    def __init__(self,raw_linear,in_features,out_features,r,alpha):
        super().__init__()
        self.r=r   # 秩数
        self.alpha=alpha   # LoRA分支的权重比例系数
        self.lora_a=nn.Parameter(torch.empty((in_features,r)))   # 可训练参数
        self.lora_b=nn.Parameter(torch.zeros((r,out_features)))
    
        nn.init.kaiming_uniform_(self.lora_a,a=math.sqrt(5))   # WA 矩阵参数需要进行初始化

        self.raw_linear=raw_linear   # 原始模型权重 W
    
    def forward(self,x):    # x:(batch_size,in_features)
        raw_output=self.raw_linear(x)   
        lora_output=x@((self.lora_a@self.lora_b)*self.alpha/self.r)    # LoRA分支:x * (WA * WB * α/r)
        return raw_output+lora_output   # W + LoRA

def inject_lora(model,name,layer):
    name_cols=name.split('.')   # [enc_convs, 0, crossattn, w_q]

    # 逐层下探到linear归属的module
    children=name_cols[:-1]   # [enc_convs, 0, crossattn]
    cur_layer=model 
    for child in children:
        cur_layer=getattr(cur_layer,child)   # 逐层深入得到w_q, w_k, w_v层的属性
    
    #print(layer==getattr(cur_layer,name_cols[-1]))
    lora_layer=LoraLayer(layer,layer.in_features,layer.out_features,LORA_R,LORA_ALPHA)
    setattr(cur_layer,name_cols[-1],lora_layer)   # 把 crossattn 的 w_q/w_k/w_v层 的属性替换为LoraLayer

模型训练过程:冻结非Lora分支的所有参数

    # lora权重的加载
    try:
        restore_lora_state=torch.load('lora.pt')   # 加载训练好的Lora权重(lora_a, lora_b矩阵),enc_convs.0.crossattn.w_q.lora_a等
        model.load_state_dict(restore_lora_state,strict=False)
    except:
        pass 

    model=model.to(DEVICE)

    # 冻结非Lora参数
    for name,param in model.named_parameters():
        if name.split('.')[-1] not in ['lora_a','lora_b']:  # 非LoRA部分不计算梯度
            param.requires_grad=False
        else:
            param.requires_grad=True

模型推理过程:将Lora分支参数合并到原始模型参数中(相加)

if __name__=='__main__':
    # 加载模型
    model=torch.load('model.pt')

    USE_LORA=True

    if USE_LORA:   # 使用LoRA推理
        # 把Linear层替换为Lora
        for name,layer in model.named_modules():
            name_cols=name.split('.')
            # 过滤出cross attention使用的linear权重
            filter_names=['w_q','w_k','w_v']
            if any(n in name_cols for n in filter_names) and isinstance(layer,nn.Linear):
                inject_lora(model,name,layer)

        # lora权重的加载
        try:
            restore_lora_state=torch.load('lora.pt')
            model.load_state_dict(restore_lora_state,strict=False)
        except:
            pass 

        model=model.to(DEVICE)

        # lora权重合并到主模型(把LoRA权重加到原始模型权重中)
        for name,layer in model.named_modules():
            name_cols=name.split('.')

            if isinstance(layer,LoraLayer):   # 找到模型中所有的 LoraLayer 层
                children=name_cols[:-1]
                cur_layer=model 
                for child in children:
                    cur_layer=getattr(cur_layer,child)    # cur_layer = cross attention对象(包含修改过的wq, wk, wv)
                lora_weight=(layer.lora_a@layer.lora_b)*layer.alpha/layer.r   # 计算得到lora分支权重
                before_weight=layer.raw_linear.weight.clone()   # 原始模型权重W
                layer.raw_linear.weight=nn.Parameter(layer.raw_linear.weight.add(lora_weight.T)).to(DEVICE)    # 把Lora参数加到base model的linear weight上
                setattr(cur_layer,name_cols[-1],layer.raw_linear)   # 使用新的合并分支替换原来的两分支Lora结构

4、ControlNet(彻底)

        将神经网络快的不同权重,分别复制到“锁定”副本(locked copy)和“可训练”副本(trainable copy)中。按制定规则集成原图特征并生成新的内容,不会导致生成图和原图看起来毫无关系。

5、其他

  • Custom Diffusion基本建立在DreamBooth的基础上,通过消融实验证明了即使只训练交叉注意力层中的部分矩阵,也有非常好的fine-tune效果,不需要像DreamBooth那样全部参数调整。这种思路也引领了后续的一系列研究,但DreamBooth仍然是当时的范式。
  • 与ControlNet同期有一种方法叫做T-2-l adapter,微调的参数更少,效果较CN差些,比CN发布晚了一点,被ControlNet的光芒遮挡了。
  • LORA的典型修改方案是LyCORIS,这个以二次元人物命名的方法把LoRA的思想应用在卷积层做改进,并且结合了一些其他算法进行了参数调整。
  • 微调方法只是打包起来的tricks。模型建模研究是建构的过程,而不是发现的过程,有很大的自由度,不要被已有做法的说法限制自己的想象。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2064431.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Ciallo~(∠・ω・ )⌒☆第二十五篇 Redis

Redis 是一个高性能的键值存储数据库,它能够在内存中快速读写数据,并且支持持久化到磁盘。它被广泛应用于缓存、队列、实时分析等场景。 一、启动redis服务器 要打开redis服务器,需要在终端中输入redis-server命令。确保已经安装了redis&…

【Java】/* 链式队列 和 循环队列 - 底层实现 */

一、链式队列 1. 使用双向链表实现队列,可以采用尾入,头出 也可以采用 头入、尾出 (LinkedList采用尾入、头出) 2. 下面代码实现的是尾入、头出: package bageight;/*** Created with IntelliJ IDEA.* Description:* User: tangyuxiu* Date: …

mOTA v2.0

mOTA v2.0 一、简介 本开源工程是一款专为 32 位 MCU 开发的 OTA 组件,组件包含了 bootloader 、固件打包器 (Firmware_Packager) 、固件发送器 三部分,并提供了基于多款 MCU (STM32F1 / STM32F407 / STM32F411 / STM32L4) 和 YModem-1K 协议的案例。基…

【文献及模型、制图分享】2000—2020年中国青饲料播种面积及供需驱动因素的时空格局

文献介绍 高产、优质的青饲料对于国家畜牧业发展和食物供给至关重要。然而,当前对于青饲料播种面积时空变化格局及其阶段性特征、区域差异以及影响因素等尚未清楚。 本文基于省级面板数据分析了2000—2020年青饲料种植的时空格局变化,结合MODIS-NPP产品…

Nginx 405 not allowed

问题原因:nginx不允许静态文件被post请求 解决:添加error_page 405 200 $request_uri;

白酒与家庭:团圆时刻的需备佳品

在中国传统文化中,家庭是社会的基石,是每个人心灵的港湾。而团圆,则是家庭生活中较美好的时刻。在这样一个特殊的日子里,白酒,尤其是豪迈白酒(HOMANLISM),成为了团圆时刻的需备佳品。…

了解JS数组元素及属性

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 1、定义数组并输出2、查询数组的长度3、访问数组的第一个元素4、访问数组中第一个元素的xxx属性5、从数组元素中提取ID并存储到搜索参数对象 提示:以下是…

C++设计模式1:单例模式(懒汉模式和饿汉模式,以及多线程问题处理)

饿汉单例模式 程序还没有主动获取实例对象&#xff0c;该对象就产生了&#xff0c;也就是程序刚开始运行&#xff0c;这个对象就已经初始化了。 class Singleton { public:~Singleton(){std::cout << "~Singleton()" << std::endl;}static Singleton* …

KUKA KR C2 中文操作指南 详情见目录

KUKA KR C2 中文操作指南 详情见目录

Selenium + Python 自动化测试22(PO+数据驱动)

我们的目标是&#xff1a;按照这一套资料学习下来&#xff0c;大家可以独立完成自动化测试的任务。 上一篇我们讨论了PO模式和unittest框架结合起来使用。 本篇文章我们综合一下之前学习的内容&#xff0c;如先将PO模式、数据驱动思想和我们生成HTML报告融合起来&#xff0c;综…

​2024年AI新蓝海:三门生意如何借AI之力,开启变现新篇章

【导语】在这个日新月异的时代&#xff0c;人工智能&#xff08;AI&#xff09;已不再是遥不可及的未来科技&#xff0c;而是正逐步渗透到我们生活的方方面面&#xff0c;成为推动产业升级的重要力量。你是否还在为传统行业的未来而忧虑&#xff1f;别担心&#xff0c;AI正以其…

Pandas DataFrame 数据转换处理和多条件查询

工作中需要处理一个比较大的数据&#xff0c;且当中需要分析的日期类型字段为字符串型&#xff0c;需要进行转换&#xff0c;获得一个新的字段用于时间统计。我们应用 datetime.datetime.strptime 函数进行转换。 数据读取与时间列补充代码如下&#xff1a; import pandas as…

原来ChatGPT是这么评价《黑神话:悟空》的啊?

《黑神话&#xff1a;悟空》一经上线便迅速吸引了全球的目光&#xff0c;成为了今日微博热搜榜上的焦点话题。作为中国首款现象级的中国3A大作&#xff0c;它的发布无疑引发了广泛的关注与讨论。 《黑神话&#xff1a;悟空》&#xff0c;这款3A国产游戏大作&#xff0c;由国内游…

根据状态的不同,显示不同的背景颜色

文章目录 前言HTML模板部分JavaScript部分注意&#xff1a;主要差异影响如何处理示例 总结 前言 提示&#xff1a;这里可以添加本文要记录的大概内容&#xff1a; 实现效果&#xff1a; 提示&#xff1a;以下是本篇文章正文内容&#xff0c;下面案例可供参考 根据给定的状态…

文件操作2(函数的专栏)

1、文件的打开和关闭 1.1文件指针 在缓冲文件系统中&#xff0c;关键的概念是“文件类型指针”&#xff0c;简称“文件指针”取名为FILE。 例如&#xff0c; VS2013编译环境提供的 stdio. h头文件中有以下的文件类型申明&#xff1a; struct _ iobuf { char *_ ptr; int _…

【YOLO5 项目实战】(6)YOLO5+StrongSORT 目标追踪

欢迎关注『youcans动手学模型』系列 本专栏内容和资源同步到 GitHub/youcans 【YOLO5 项目实战】&#xff08;1&#xff09;YOLO5 环境配置与检测 【YOLO5 项目实战】&#xff08;2&#xff09;使用自己的数据集训练目标检测模型 【YOLO5 项目实战】&#xff08;6&#xff09;Y…

数据库机器上停service360safe

发现有个数据库的负载较高&#xff0c;发现有360safe&#xff0c;就准备停了该服务再观察 [rootdb1 ~]# ps -ef |grep 360 root 970 1 0 15:12 ? 00:00:10 /opt/360safe/360entclient root 976 970 5 15:12 ? 00:18:42 /opt/360…

Linux之RabbitMQ集群部署

RabbitMQ 消息中间件 1、消息中间件 消息(message)&#xff1a; 指在服务之间传送的数据。可以是简单的文本消息&#xff0c;也可以是包含复杂的嵌入对象的消息 消息队列(message queue): 指用来存放消息的队列&#xff0c;一般采用先进先出的队列方式&#xff0c;即最先进入的…

关于springboot的异常处理以及源码分析(一)

一、什么是异常处理 1、文档定义 首先我们先来看springboot官方对于异常处理的定义。springboot异常处理 在文档的描述中&#xff0c;我们首先可以看到的一个介绍如下&#xff1a; By default, Spring Boot provides an /error mapping that handles all errors in a sensib…

优思学院|如何在30分钟内评审一家供应商?SQE必需知道的11点

在供应商评审中&#xff0c;特别是时间有限的情况下&#xff0c;SQE&#xff08;供应商质量工程师&#xff09;需要通过高效的观察和分析来快速评估供应商的能力。在《哈佛商业评论》中&#xff0c;R. Eugene Goodson 的一篇“Read a Plant—Fast”文章正好提供了一个极为实用的…