[论文阅读] Knowledge Fusion of Large Language Models

news2025/2/11 9:54:26

Knowledge Fusion of Large Language Models (FuseLLM)


Methodology

整体Pipeline如下图所示
在这里插入图片描述

不同的动物代表不同的LLM。左边第一,第二分别是Ensemble以及Weight Merging方法。最右侧为本文提出的FuseLLM。

  • Ensemble: 融合多个models的预测结果,比如求加权平均等。
  • Weight Merging:在权重/参数层面融合,但通常仅限于相同架构的模型。
  • FuseLLM 主要思想为:融合多个LLMs(可以是不同架构的)的probabilistic matrices,得到Fused Matrix后,喂给Target Model,起到知识蒸馏的作用

这里面会涉及到一个关键:

  • 不同LLM,使用的Tokenizer可能不同,设置也可能不一样(如 model_max_length ),分词结果可能不一样(比如对同一个句子分词,tokens总数不同),使用的Vocabulary也可能不一样,因此生成的probabilistic matrix在维度上可能有所不同,如何解决对齐问题?这个实际上就是 token alignment 问题,本文中着重描述了解决方案。

Definition of Problem

假设我们有一个语料库 C \mathcal{C} C K K K个source LLMs, 对于文本 t ∈ C t \in \mathcal{C} tC,经过 K K K个LLM处理,可以得到对应的概率分布矩阵 probabilistic distribution matrix { P t θ j } j = 1 K \{\mathbf{P}^{\theta_j}_t\}^K_{j=1} {Ptθj}j=1K,其中 θ j \theta_j θj表示第 j j j个LLM的参数。我们要做的就是将这 K K K概率分布矩阵融合,然后送入Target LLM中辅助训练:
P t = F u s i o n ( P t θ 1 , P t θ 2 , … , P t θ K ) , \begin{align} \mathbf{P}_t=\mathbb{F}\mathrm{usion}(\mathbf{P}_t^{\theta_1},\mathbf{P}_t^{\theta_2},\ldots,\mathbf{P}_t^{\theta_K}), \end{align} Pt=Fusion(Ptθ1,Ptθ2,,PtθK),
P t \mathbf{P}_t Pt即得到的融合概率分布矩阵(Fused Representation Matrix)。

为了将 P t \mathbf{P}_t Pt迁移至target model中,我们假设 Q t \mathbf{Q}_t Qt为其输出的representation matrix,则Knowledge Fusion的训练目标为:
L F u s i o n = − E t ∼ C [ D ( Q t , P t ) ] . \begin{align} \mathcal{L}_{\mathrm{Fusion}}=-\mathbb{E}_{t\sim\mathcal{C}}\left[\mathbb{D}(\mathbf{Q}_t,\mathbf{P}_t)\right]. \end{align} LFusion=EtC[D(Qt,Pt)].
其中 D ( ⋅ , ⋅ ) \mathbb{D}(\cdot, \cdot) D(,)表示差异性函数,具体实现可以是KL散度。
整体的模型损失如下:
L = λ L C L M + ( 1 − λ ) L F u s i o n . \begin{align}\mathcal{L}=\lambda\mathcal{L}_{\mathrm{CLM}}+(1-\lambda)\mathcal{L}_{\mathrm{Fusion}}.\end{align} L=λLCLM+(1λ)LFusion.
其中 L C L M \mathcal{L}_{\mathrm{CLM}} LCLM表示最原始的ground-truth之间的损失, λ \lambda λ为系数。

实现细节

Token Alignment

我们假设有两个LLM,使用不同的tokenizer。对同一段文本分词,得到的token序列不同,长度也不同:
在这里插入图片描述
如上图,用DeepSeek和TinyLlama各自的分词器分词,得到的结果完全不一样。最终预测的概率分布矩阵也不一样。

Token-Level Alignment

为了解决这个问题,FuseLLM采用基于最小编辑距离Minimal Edit Distance(MinED)的动态规划策略,在token-level实现对齐,以下图为例:
在这里插入图片描述
具体实现的源代码other.py如下:


def dtw(series_1, series_2, norm_func=np.linalg.norm):
    """Use dynamic time wrapping to align to tokenizers, modified from:
    https://github.com/talcs/simpledtw/blob/master/simpledtw.py"""

    """

    Parameters
    ----------
    series_1: List[str]
        blending_input_tokens
    series_2: List[str]
        base_input_tokens
    norm_func: function
        edit distance evaluation between 2 tokens


    Return Values
    ----------
    matches: List[Tuple]
        matched pairs between a base token and a blending token
    matrix[-1, -1]: int 
        the total cost for mapping the two series of tokens
    mappings_series_1: List[List]
        mapping from blending tokens to base tokens
        eg: [0], [1, 2], [3, 4, 5], [6], ...
    mappings_series_2: List[List]
        mapping from base tokens to blending tokens
    matrix: List[int]
        the dtw matrix

    """

    matrix = np.zeros((len(series_1) + 1, len(series_2) + 1))
    matrix[0, :] = np.inf
    matrix[:, 0] = np.inf
    matrix[0, 0] = 0
    for i, vec1 in enumerate(series_1):
        for j, vec2 in enumerate(series_2):
            cost = norm_func(vec1, vec2)
            matrix[i + 1, j + 1] = cost + min(
                matrix[i, j + 1], matrix[i + 1, j], matrix[i, j]
            )
    matrix = matrix[1:, 1:]
    i = matrix.shape[0] - 1
    j = matrix.shape[1] - 1
    matches = []
    mappings_series_1 = [list() for v in range(matrix.shape[0])]
    mappings_series_2 = [list() for v in range(matrix.shape[1])]
    while i > 0 or j > 0:
        matches.append((i, j))
        mappings_series_1[i].append(j)
        mappings_series_2[j].append(i)
        option_diag = matrix[i - 1, j - 1] if i > 0 and j > 0 else np.inf
        option_up = matrix[i - 1, j] if i > 0 else np.inf
        option_left = matrix[i, j - 1] if j > 0 else np.inf
        move = np.argmin([option_diag, option_up, option_left])
        if move == 0:
            i -= 1
            j -= 1
        elif move == 1:
            i -= 1
        else:
            j -= 1
    matches.append((0, 0))
    mappings_series_1[0].append(0)
    mappings_series_2[0].append(0)
    matches.reverse()
    for mp in mappings_series_1:
        mp.reverse()
    for mp in mappings_series_2:
        mp.reverse()

    return matches, matrix[-1, -1], mappings_series_1, mappings_series_2, matrix


Logit-Level Alignment

利用该对齐结果,将不同LLMs得到的representation matrix对齐。关键代码other.py如下:


def transform_step_logits(
    base_model_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase,
    blending_model_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase,
    base_model_vocab: Dict[str, int],
    base_model_input_ids: List[int],
    blending_model_input_ids: List[int],
    blending_model_per_step_logits: List[List[float]],
    blending_model_per_step_indices: List[List[int]],
    vocab_align_type: str = "hard",
    blending_to_base_mapping: Dict[str, str] = None,
):
    """Align blending model per step logits & indices with base model."""


    """

    Parameters
    ----------
    base_model_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase
    blending_model_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase
    base_model_vocab: Dict[str, int]
        mapping token to id using vocabulary of base model
    base_model_input_ids: List[int]
        ids of base_model_input_tokens
    blending_model_input_ids: List[int]
        ids of blending_model_input_tokens
    blending_model_per_step_logits: List[List[float]]
        logits for each token in blending_model_input_tokens 
    blending_model_per_step_indices: List[List[int]]
        indices corresponding to logits for each token in blending_model_input_tokens 
    vocab_align_type: str = "hard"
    blending_to_base_mapping: Dict[str, str] = None
        mapping each blending token to its corresponding base token 


    Return Values
    ----------
    aligned_blending_model_per_step_logits: List[List[float]]
        aligned logits for each token in base_model_input_tokens for the FuseLLM training
    aligned_blending_model_per_step_indices: List[List[int]]
        aligned indices corresponding aligned logits for each token in base_model_input_tokens for the FuseLLM training. 
        Use the base model vocabulary to look up the token.
    """



    base_model_tokens = base_model_tokenizer.convert_ids_to_tokens(base_model_input_ids)
    blending_model_tokens = blending_model_tokenizer.convert_ids_to_tokens(
        blending_model_input_ids
    )
    base_model_special_token = TOKENIZER_TO_SPECIAL_TOKEN[
        base_model_tokenizer.__class__
    ]
    blending_model_special_token = TOKENIZER_TO_SPECIAL_TOKEN[
        blending_model_tokenizer.__class__
    ]


    def dist_fn(a, b):
        """Calculate editdistance between two tokens, a is from blending model, b is from base model."""
        aa = a.replace(blending_model_special_token, "")
        bb = b.replace(base_model_special_token, "")
        dist = editdistance.eval(aa, bb)
        return dist

    _, _, _, base_to_blending, _ = dtw(
        blending_model_tokens, base_model_tokens, norm_func=dist_fn
    )
    aligned_blending_model_per_step_logits, aligned_blending_model_per_step_indices = (
        [],
        [],
    )
    for i, blending_idx in enumerate(base_to_blending):
        aligned_blending_model_per_step_logit = []
        aligned_blending_model_per_step_index = []
        if len(blending_idx) == 1:  # one base token map to one blending token
            j = blending_idx[0]
            base_token = base_model_tokens[i]
            blending_token = blending_model_tokens[j].replace(
                blending_model_special_token, base_model_special_token
            )
            if (
                (
                    blending_model_tokenizer.__class__
                    == transformers.GPTNeoXTokenizerFast
                    or blending_model_tokenizer.__class__
                    == transformers.GPT2TokenizerFast
                )
                and i == 0
                and base_token.startswith(base_model_special_token)
                and not blending_token.startswith(base_model_special_token)
            ):
                blending_token = (
                    base_model_special_token + blending_token
                )  # special case for mpt
            if vocab_align_type == "hard":
                if (
                    base_token == blending_token
                ):  # find the aligned mapping, use the corresponding logits
                    # the logits and indices at this step
                    for blending_logit, blending_index in zip(
                        blending_model_per_step_logits[j],
                        blending_model_per_step_indices[j],
                    ):
                        # the token corresponds to the logit and indices
                        blending_t = blending_model_tokenizer.convert_ids_to_tokens(
                            [blending_index]
                        )[0].replace(
                            blending_model_special_token, base_model_special_token
                        )
                        if blending_t in base_model_vocab:
                            aligned_index = base_model_vocab[
                                blending_t
                            ]  # the index of the token in base model vocab
                            if (
                                aligned_index
                                not in aligned_blending_model_per_step_index
                            ):
                                aligned_blending_model_per_step_index.append(
                                    aligned_index
                                )
                                aligned_blending_model_per_step_logit.append(
                                    blending_logit
                                )
                else:  # find error aligned mapping, use the one-hot logits
                    aligned_blending_model_per_step_index.append(
                        base_model_vocab[base_token]
                    )
                    aligned_blending_model_per_step_logit.append(1.0)
            elif vocab_align_type == "soft":
                if (base_token == blending_token) or (
                    blending_token in blending_to_base_mapping
                    and base_token == blending_to_base_mapping[blending_token]
                ):  # find the aligned mapping, use the corresponding logits
                    # the logits and indices at this step
                    for blending_logit, blending_index in zip(
                        blending_model_per_step_logits[j],
                        blending_model_per_step_indices[j],
                    ):
                        # the token corresponds to the logit and indices
                        blending_t = blending_model_tokenizer.convert_ids_to_tokens(
                            [blending_index]
                        )[0].replace(
                            blending_model_special_token, base_model_special_token
                        )
                        blending_t = blending_to_base_mapping[blending_t]
                        if blending_t in base_model_vocab:
                            aligned_index = base_model_vocab[
                                blending_t
                            ]  # the index of the token in base model vocab
                            if (
                                aligned_index
                                not in aligned_blending_model_per_step_index
                            ):
                                aligned_blending_model_per_step_index.append(
                                    aligned_index
                                )
                                aligned_blending_model_per_step_logit.append(
                                    blending_logit
                                )
                        else:
                            logger.warning(
                                f"blending_t: {blending_t} not in base_model_vocab!"
                            )
                else:  # find error aligned mapping, use the one-hot logits
                    aligned_blending_model_per_step_index.append(
                        base_model_vocab[base_token]
                    )
                    aligned_blending_model_per_step_logit.append(1.0)
            else:
                logger.warning(
                    f"The vocab_align_type: '{vocab_align_type}' is not support!"
                )
                raise NotImplementedError
        else:  # one base token map to multiple blending token, in this case only fit base token. use the one-hot logits
            base_token = base_model_tokens[i]
            aligned_blending_model_per_step_index.append(base_model_vocab[base_token])
            aligned_blending_model_per_step_logit.append(1.0)
        aligned_blending_model_per_step_indices.append(
            aligned_blending_model_per_step_index
        )
        aligned_blending_model_per_step_logits.append(
            aligned_blending_model_per_step_logit
        )
    return (
        aligned_blending_model_per_step_logits,
        aligned_blending_model_per_step_indices,
    )

Fusion Strategies:

得到对齐的representation matrix以后,由于不同的LLM具有不同的性能,可以使用概率分布矩阵与ground-truth之间的交叉熵损失(CE loss)评估LLM的优劣,再根据此判断选择哪些LLM参与知识融合。CE loss越低,证明模型效果更好。具体而言,作者提出了两种Fusion Strategy:

  1. MinCE: 仅选择CE loss最小的representation matrix用于知识融合。
  2. AvgCE: 基于各个模型的CE loss,采用多个representation matrices的加权平均,用于知识融合。

整体的算法流程如下:
在这里插入图片描述

  • 注:这里Eq.5实际是本文中上述的Eq.3

一些思考

本文的思路是将多个LLMs输出的概率分布矩阵视为知识,将知识融合后,送入target LLM进行训练,以达到融合多种模型知识,提升目标模型性能的目的。但在实际的实现当中我们会发现,logit-level的alignment,要么是直接采用blending_model_per_step_logits/indices,要么直接用ground-truth one-hot作为融合后的知识,而没有充分评估logit-level中,blending/base_model_per_step_logits之间的差异性。为此,Probabilistic Token Alignment for Large Language Model Fusion提出采用Probabilistic Token Alignment方法,在logit-level实现alignment。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2295417.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

elementui:el-table支持搜索、切换分页多选功能,以及数据回显

1、el-table相关代码&#xff0c;需注意:row-key"(row) > { return row.id }" 以及 :reserve-selection"true" <div class"boxList"><div class"search-form"><!-- 搜索表单 --><el-form :inline"true&q…

(ICLR=2025)生成的表征对齐:训练扩散Transformer比你想象的更简单

生成的表征对齐&#xff1a;训练扩散Transformer比你想象的更简单 paper是KAIST发表在ICLR 2025的工作 paper title:REPRESENTATION ALIGNMENT FOR GENERATION: TRAINING DIFFUSION TRANSFORMERS IS EASIER THAN YOU THINK Code&#xff1a;链接 ABSTRACT 最近的研究表明&…

白嫖RTX 4090?Stable Diffusion:如何给线稿人物快速上色?

大家都知道&#xff0c;在设计的初期&#xff0c;我们通常会先绘制草图&#xff0c;然后再进行上色处理&#xff0c;最终才开始进行最终的设计工作。在这个上色的过程中&#xff0c;配色是至关重要的一环。这不仅方便了内部同事的评审&#xff0c;也让产品方和客户可以直观地了…

Bash (Bourne-Again Shell)、Zsh (Z Shell)

文章目录 1. 历史背景2. 主要区别3. 功能对比自动补全插件和主题路径扩展提示符定制 4. 性能5. 使用场景6. 如何切换 Shell7. 总结 以下是 Bash 和 Zsh 之间的主要区别&#xff0c;列成表格方便对比&#xff1a; 特性BashZsh默认Shell大多数Linux发行版默认ShellmacOS默认She…

pikachu[皮卡丘] 靶场全级别通关教程答案 以及 学习方法 如何通过渗透测试靶场挑战「pikachu」来精通Web渗透技巧? 一篇文章搞完这些问题

目录 Pikachu靶场 部署 暴力破解漏洞 学习地址: 靶场练习: 基于表单的暴力破解 验证码绕过(on server) 验证码绕过(on Client) token防爆破? XSS跨站脚本攻击 学习地址: 靶场练习&#xff1a; 反射型xss(get) 反射性xss(post) 存储型xss DOM型xss xss盲打 x…

汽车零部件工厂如何借助安灯呼叫按钮盒提升生产响应速度

在现代汽车零部件工厂的生产环境中&#xff0c;高效的信息传递和快速的响应速度是确保生产顺畅运行的关键。然而&#xff0c;传统的口头呼喊或现场沟通方式往往存在信息传递慢、现场嘈杂、责任人难以及时找到等问题&#xff0c;尤其在设备故障或缺料时&#xff0c;这些问题会导…

Idea 2024.3 使用CodeGPT插件整合Deepseek

哈喽&#xff0c;大家好&#xff0c;我是浮云&#xff0c;最近国产大模型Deepseek异常火爆&#xff0c;作为程序员我也试着玩了一下&#xff0c;首先作为简单的使用&#xff0c;大家进入官网&#xff0c;点击开始对话即可进行简单的聊天使用&#xff0c;点击获取手机app即可安装…

「vue3-element-admin」告别 vite-plugin-svg-icons!用 @unocss/preset-icons 加载本地 SVG 图标

&#x1f680; 作者主页&#xff1a; 有来技术 &#x1f525; 开源项目&#xff1a; youlai-mall ︱vue3-element-admin︱youlai-boot︱vue-uniapp-template &#x1f33a; 仓库主页&#xff1a; GitCode︱ Gitee ︱ Github &#x1f496; 欢迎点赞 &#x1f44d; 收藏 ⭐评论 …

docker /var/lib/docker/overlay2目录把磁盘空间占满问题

1、查看服务器磁盘空间 df -h果然100%了,docker系统文件把磁盘空间占满了。 2、进入overlay2目录&#xff0c;查找那个容器工作目录占用最高 cd /var/lib/docker/overlay2du -h --max-depth1详见下图 好家伙占用110G&#xff01;复制目录名称2c3c48ccac533c5d4a366d45a19bb9…

Redis深入学习

目录 Redis是什么&#xff1f; Redis使用场景 Redis线程模型 Redis执行命令是单线程的为什么还这么快&#xff1f; Redis持久化 Redis 事务 Key 过期策略 Redis 和 mysql 如何保证数据一致&#xff1f; 缓存穿透 缓存击穿 缓存雪崩 Redis是什么&#xff1f; redis是一…

EasyExcel 导出合并层级单元格

EasyExcel 导出合并层级单元格 一、案例 案例一 1.相同订单号单元格进行合并 合并结果 案例二 1.相同订单号的单元格进行合并2.相同订单号的总数和总金额进行合并 合并结果 案例三 1.相同订单号的单元格进行合并2.相同订单号的商品分类进行合并3.相同订单号的总数和总金额…

青少年编程与数学 02-009 Django 5 Web 编程 01课题、概要

青少年编程与数学 02-009 Django 5 Web 编程 01课题、概要 一、Django 5Django 5 的主要特性包括&#xff1a; 二、MVT模式三、官方网站四、内置功能数据库 ORM&#xff08;对象关系映射&#xff09;用户认证和授权表单处理模板引擎URL 路由缓存框架国际化和本地化安全性功能管…

2.7学习

crypto buu-还原大师 仔细阅读题目&#xff0c;这里有一段字符串&#xff0c;但是其中有四个大写字母被替换成了‘&#xff1f;’&#xff0c;那么我们写脚本&#xff1a;首先将四个问号均换成26个大写字母并且组成不同的组合&#xff0c; 所以有四个循环让四个问号都遍历26个…

oracle ORA-27054报错处理

现象 在oracle执行expdp&#xff0c;rman备份&#xff0c;xtts的时候,由于没有足够的本地空间&#xff0c;只能使用到NFS的文件系统但有时候会出现如下报错 ORA-27054: NFS file system where the file is created or resides is not mounted with correct options根据提示信…

使用LLaMA Factory踩坑记录

前置条件&#xff1a;电脑显卡RTX 4080 问题&#xff1a;LLaMA-Factory在运行的时候&#xff0c;弹出未检测到CUDA的报错信息 结论&#xff1a;出现了以上的报错&#xff0c;主要可以归结于以下两个方面&#xff1a; 1、没有安装GPU版本的pytorch&#xff0c;下载的是CPU版本…

电路研究9.3——合宙Air780EP中的AT开发指南(含TCP 示例)

根据合宙的AT研发推荐&#xff0c; AT指令基本上也简单看完了&#xff0c;这里开始转到AT的开发了。 AT 命令采用标准串口进行数据收发&#xff0c;将以前复杂的设备通讯方式转换成简单的串口编程&#xff0c; 大大简化了产品的硬件设计和软件开发成本&#xff0c;这使得几乎所…

Reqable使用实践

一、背景 日常开发中&#xff0c;难免要抓取请求数据&#xff0c;查看接口数据&#xff0c;从而更好定位问题&#xff0c;基于这个原因&#xff0c;查找了一些抓包工具&#xff0c;例如&#xff1a; HttpCanary、 Steam 、Fiddler等&#xff0c;不是要钱&#xff0c;就是只对苹…

【蓝桥杯嵌入式】2_LED

全部代码网盘自取 链接&#xff1a;https://pan.baidu.com/s/1PX2NCQxnADxYBQx5CsOgPA?pwd3ii2 提取码&#xff1a;3ii2 1、电路图 74HC573是八位锁存器&#xff0c;当控制端LE脚为高电平时&#xff0c;芯片“导通”&#xff0c;LE为低电平时芯片“截止”即将输出状态“锁存”…

B树详解及其C语言实现

目录 一、B树的基本原理 二、B树操作过程图形化演示 三、B树的应用场景 四、C语言实现B树及示例 五、代码执行结果说明 六、应用实例&#xff1a;文件系统目录索引 七、总结 一、B树的基本原理 B树&#xff08;B-Tree&#xff09; 是一种自平衡的树数据结构&#xff0c;…

ARM64 Linux 内核学习指南:从基础到实践

前言 ARM64 作为当今主流的处理器架构&#xff0c;被广泛应用于移动设备、嵌入式系统和服务器领域。学习 ARM64 在 Linux 内核中的实现&#xff0c;不仅有助于深入理解操作系统底层机制&#xff0c;还能提升在内核开发、驱动编写、虚拟化等领域的专业能力。 本指南面向对 Lin…