LoRA 和 DoRA 代码笔记

news2024/9/20 8:12:57

Improving LoRA: Implementing Weight-Decomposed Low-Rank Adaptation (DoRA) from Scratch

LoRA

在这里插入图片描述
LoRA初始化时,A使用正态分布,B使用0.

class LoRALayer(nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.B = nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
        x = self.alpha * (x @ self.A @ self.B)
        return x


class LinearWithLoRA(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        )

    def forward(self, x):
        return self.linear(x) + self.lora(x)

在训练的时候,使用LinearWithLoRA 替换Linear层,并且freeze Linear的参数,只训练lora的参数。

model_lora = copy.deepcopy(model_pretrained)

model_lora.layers[0] = LinearWithLoRA(model_lora.layers[0], rank=4, alpha=8)

freeze_linear_layers(model_lora)

#然后就可以正常训练了

freeze Linear的参数

def freeze_linear_layers(model):
    for child in model.children():
        if isinstance(child, nn.Linear):
            for param in child.parameters():
                param.requires_grad = False
        else:
            # Recursively freeze linear layers in children modules
            freeze_linear_layers(child)

DoRA (Weight-Decomposed Low-Rank Adaptation)

权重weight矩阵W,可以分为 模向量m(magnitude vector)和方向矩阵V(directional matrix)。
把LoRA加入到方向矩阵V中,然后再和m计算出新的权重W。
使用方法和LoRA相同。
在这里插入图片描述

在这里插入图片描述

#训练时
class LinearWithDoRA(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)
        self.m = nn.Parameter(torch.ones(1, linear.out_features))

    def forward(self, x):
        linear_output = self.linear(x)
        lora_output = self.lora(x)
        lora_output_norm = lora_output / (lora_output.norm(p=2, dim=1, keepdim=True) + 1e-9)
        dora_modification = self.m * lora_output_norm
        return linear_output + dora_modification

#合并时
# Code inspired by https://github.com/catid/dora/blob/main/dora.py
class LinearWithDoRAMerged(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        )
        
        self.m = nn.Parameter(
            self.linear.weight.norm(p=2, dim=0, keepdim=True))

    def forward(self, x):
        lora = self.lora.A @ self.lora.B
        numerator = self.linear.weight + self.lora.alpha*lora.T
        denominator = numerator.norm(p=2, dim=0, keepdim=True)
        directional_component = numerator / denominator
        new_weight = self.m * directional_component
        return F.linear(x, new_weight, self.linear.bias)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2093149.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第L1周:机器学习-数据预处理

第L1周:机器学习-数据预处理 🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 学习要点: **** 学习如何处理缺损数据尝试进行Label编码使用train_test_split进行数据划分学习特征标准化…

EXO项目StandardNode;max_generate_tokens;buffered_token_output;is_finished;

目录 StandardNode max_generate_tokens buffered_token_output 构造函数参数 类属性 总结 is_finished max_generate_tokens self.buffered_token_output StandardNode _process_tensor result是一个np.ndarray ,result.size == 1是什么意思 StandardNode max_g…

【Python机器学习】NLP词频背后的含义——反馈及改进

之前学习的LSA方法都没有考虑文档之间的相似度信息,创建的主题对一组通用规则来说是最优的。在这些特征(主题)提取模型的无监督学习中,没有任何关于主题向量之间应该多么接近的数据。我们也不允许任何关于主题向量在哪里结束或者它…

力扣刷题(复习版1)

文章目录 题目:最大重复子字符串题解 题目: 面试题 16.07. 最大数值题解 题目: 最大字符串配对数目题解 题目: 字符串中第二大的数字题解 题目: 统计最大组的数目题解 题目: 删除每行中的最大值题解 题目&a…

记录|Chart控件使用

目录 前言一、Series集合1.1 什么是Series1.2 IsValueShownAsLabel1.3 Points 二、ChartArea2.1 轴- Axes2.1.1 Title2.1.2 刻度下的Maximum、Minimum2.1.3 间隔- Interval2.1.4 网格刻度线 2.2 游标- CursorX/CursorYMajorGrid属性中的Enabled 更新时间 前言 参考视频&#xf…

七、库存管理——盘点业务

第一节 库存管理盘点 1、盘点的目的 答:是指通过定期或临时对库存商品的实际数量进行盘查、清点,然后掌握货物的流动情况(入库、出库、调拨、在库等),最后对仓库现有物品的实际数量与保管账上记录的数量相核对&…

Spring Boot集成Stripe快速入门demo

1.什么是Stripe? 一体化全球支付平台,开启收入增长引擎,针对不同规模业务打造的支付解决方案,满足从初创公司到跨国企业的多维度需求,助力全球范围内线上线下付款。 转化更多客户: 通过内置的优化功能、100 多种支付…

QT QGraphicsView实现预览图片显示缩略图功能

QT QGraphicsView实现预览图片显示缩略图功能QT creator Qt5.15.2 头文件&#xff1a; #ifndef TGRAPHICSVIEW_H #define TGRAPHICSVIEW_H#include <QGraphicsView> #include <QMainWindow> #include <QObject> #include <QWidget>class TGraphicsVie…

vue页面自适应 动态 postcss postcss-pxtorem

vue页面自适应 动态 postcss postcss-pxtorem postcss-pxtorem实现页面自适应1、安装postcss-pxtorem2、根目录创建postcss.config.js&#xff0c;并配置以下内容3、创建rem.js&#xff0c;动态设置root px4、在main.js中引入rem.js5、在main.js中创建全局处理函数px2rem6、对…

【王树森】Vision Transformer (ViT) 用于图片分类(个人向笔记)

图片分类任务 给定一张图片&#xff0c;现在要求神经网络能够输出它对这个图片的分类结果。下图表示神经网络有40%的信心认定这个图片是狗 ResNet&#xff08;CNN&#xff09;曾经是是图像分类的最好模型在有足够大数据做预训练的情况下&#xff0c;ViT要强于ResNetViT 就是Tr…

S7-200编程软件STEP 7打开时界面乱码显示Translation Required

遇到的问题 如题&#xff0c;两个月没有打开过S7-200编程软件&#xff08;软件版本是V4.0 STEP 7 MicroWIN SP9&#xff0c;电脑系统是Windows 11&#xff09;&#xff0c;这一次打开就发现它的那个界面乱码了&#xff0c;原来时中文汉化的地方全都变成了Translation Required…

笔记整理—内核!启动!—uboot部分(1)

常规启动时&#xff0c;各镜像都在SD卡中的各种分区中&#xff0c;内核放在kernel分区&#xff0c;从SD卡到DDR的连接处&#xff08;内核不需要进行重定位&#xff0c;直接从链接处启动&#xff09;。uboot从sd卡分区读使用movi命令。 使用fastboot指令可以查看分区情况&#x…

通过Dot1q终结子接口实现VLAN间互访

如图1所示&#xff0c;SwitchA为支持配置子接口的三层交换机&#xff0c;SwitchB为二层交换机&#xff0c;SwitchA通过一个三层以太网接口与SwitchB互连。用户主机被划分到两个VLAN&#xff1a;VLAN2和VLAN3。由于业务需要&#xff0c;不同VLAN的用户要求互通。 图1 通过Dot1q…

AI革命:清华大学揭秘大模型工具学习的未来

&#x1f31f; 未来已来&#xff1a;大模型工具学习开启智能新时代 &#x1f31f; 清华大学THUNLP最新研究&#xff0c;探索AI工具使用的无限可能 文末有报告免费下载&#xff0c;需要的朋友自行下跳。 亲爱的读者朋友们&#xff0c;人工智能的浪潮已经不可阻挡地涌入我们的…

LabVIEW VI并行执行设置

要在多个程序中运行同一个VI&#xff08;Virtual Instrument&#xff09;&#xff0c;通常需要确保VI的重入性&#xff08;Reentrancy&#xff09;设置正确。在LabVIEW中&#xff0c;可以使用“重入性”&#xff08;Reentrancy&#xff09;选项来允许同一个VI同时在多个地方调用…

RAG噪声的设计及其对大模型问答的作用分析

有趣的大模型中RAG噪声的作用分析 大模型&#xff08;LLMs&#xff09;在多个任务上表现出色&#xff0c;但存在依赖过时知识、幻觉等问题。RAG作为一种提高LLM性能的方法&#xff0c;通过在推理过程中引入外部信息来缓解这些限制。 Figure 1 展示了一个来自 NoiserBench 的示…

Docker技术

一、Docker简介 1.什么是docker Docker是管理容器的引擎&#xff0c;为应用打包、部署平台&#xff0c;而非单纯的虚拟化技术。 它具有以下几个重要特点和优势&#xff1a; 1. 轻量级虚拟化 &#xff1a;Docker 容器相较于传统的虚拟机更加轻量和高效&#xff0c;能够快速启…

【高中数学/极值/判别式法】已知实数a和b,b在(0,1)区间,a-b=1,则1/(a-1)+1/(5-4b)的最小值是?

【问题】 已知实数a,b&#xff0c;b在(0,1)区间&#xff0c;a-b1,则1/(a-1)1/(5-4b)的最小值是&#xff1f; 【来源】 《解题卡壳怎么办 高中数学解题智慧点剖析》P34 余继光 苏德矿合著 浙江大学出版社出版 【破题点】 将a-1用b取代&#xff0c;发现结果是二次式相除&…

24-8-31-读书笔记(十六)-《契诃夫文集》(十一)([俄] 契诃夫 [译] 汝龙 )

文章目录 《契诃夫文集》&#xff08;十一&#xff09;&#xff08;[俄] 契诃夫 [译] 汝龙 &#xff09;目录阅读笔记记录总结 《契诃夫文集》&#xff08;十一&#xff09;&#xff08;[俄] 契诃夫 [译] 汝龙 &#xff09; 8月最后一天了&#xff0c;心里很多的感慨&#xff0…

Bluetooth: gatt server example 解读

在 core spec 中有 Example ATT Server contents,这里对此进行解读; Assigned_Numbers.pdf 需要提前准备,可以从 SIG 下载; Step-1 从这个服务看,server handle 是1, 但是第一个 characteristic clare handle是 4,所以不能预设handle 是按顺序连续的; Step-2 Servic…