IA3源码分析

news2025/1/15 23:35:09

IA3源码分析

PEFT 开源包中的模型代码实现

class IA3Model(BaseTuner):
    """
    Example:
        ```py
        >>> from transformers import AutoModelForSeq2SeqLM, ia3Config
        >>> from peft import IA3Model, IA3Config

        >>> config = IA3Config(
        ...     peft_type="IA3",
        ...     task_type="SEQ_2_SEQ_LM",
        ...     target_modules=["k", "v", "w0"],
        ...     feedforward_modules=["w0"],
        ... )

        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
        >>> ia3_model = IA3Model(config, model)
        ```
    """

    def __init__(self, model, config, adapter_name):
        super().__init__(model, config, adapter_name)

    @staticmethod
    def _create_new_module(ia3_config, adapter_name, target, **kwargs):
        bias = hasattr(target, "bias") and target.bias is not None
        loaded_in_8bit = kwargs.pop("loaded_in_8bit", False)
        is_feedforward = kwargs.pop("is_feedforward", False)

        if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
            eightbit_kwargs = kwargs.copy()
            eightbit_kwargs.update(
                {
                    "has_fp16_weights": target.state.has_fp16_weights,
                    "memory_efficient_backward": target.state.memory_efficient_backward,
                    "threshold": target.state.threshold,
                    "index": target.index,
                }
            )
            new_module = Linear8bitLt(
                adapter_name,
                target.in_features,
                target.out_features,
                is_feedforward,
                bias=bias,
                **eightbit_kwargs,
            )
        else:
            #  Create a new Linear module with (IA)^3 parameters for torch.nn.Linear
            # or Conv1D modules
            if isinstance(target, torch.nn.Linear):
                in_features, out_features = target.in_features, target.out_features
                if kwargs["fan_in_fan_out"]:
                    warnings.warn(
                        "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
                        "Setting fan_in_fan_out to False."
                    )
                    kwargs["fan_in_fan_out"] = ia3_config.fan_in_fan_out = False
            elif isinstance(target, Conv1D):
                in_features, out_features = (
                    target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
                )
                if not kwargs["fan_in_fan_out"]:
                    warnings.warn(
                        "fan_in_fan_out is set to False but the target module is `Conv1D`. "
                        "Setting fan_in_fan_out to True."
                    )
                    kwargs["fan_in_fan_out"] = ia3_config.fan_in_fan_out = True
            else:
                raise ValueError(
                    f"Target module {target} is not supported. "
                    f"Currently, only `torch.nn.Linear` and `Conv1D` are supported."
                )
            new_module = Linear(
                adapter_name, in_features, out_features, is_feedforward=is_feedforward, bias=bias, **kwargs
            )
        return new_module

    @staticmethod
    def _check_target_module_exists(ia3_config, key):
        if isinstance(ia3_config.target_modules, str):
            target_module_found = re.fullmatch(ia3_config.target_modules, key)
        else:
            target_module_found = any(_is_valid_match(key, target_key) for target_key in ia3_config.target_modules)
        return target_module_found

    def _mark_only_adapters_as_trainable(self) -> None:
        for n, p in self.model.named_parameters():
            if "ia3_" not in n:
                p.requires_grad = False

    def _create_and_replace(
        self,
        ia3_config,
        adapter_name,
        target,
        target_name,
        parent,
        **optionnal_kwargs,
    ):
        loaded_in_8bit = optionnal_kwargs["loaded_in_8bit"]
        current_key = optionnal_kwargs["current_key"]

        # check if target module is in feedforward_modules
        if isinstance(ia3_config.feedforward_modules, str):
            is_feedforward = re.fullmatch(ia3_config.feedforward_modules, current_key)
        else:
            is_feedforward = any(current_key.endswith(target_key) for target_key in ia3_config.feedforward_modules)

        kwargs = {
            "fan_in_fan_out": ia3_config.fan_in_fan_out,
            "init_ia3_weights": ia3_config.init_ia3_weights,
            "loaded_in_8bit": loaded_in_8bit,
            "is_feedforward": is_feedforward,
        }

        if isinstance(target, IA3Layer):
            target.update_layer(
                adapter_name,
                ia3_config.init_ia3_weights,
            )
        else:
            new_module = self._create_new_module(ia3_config, adapter_name, target, **kwargs)
            self._replace_module(parent, target_name, new_module, target)

    @staticmethod
    def _replace_module(parent, child_name, new_module, child):
        setattr(parent, child_name, new_module)
        new_module.weight = child.weight
        if child.bias is not None:
            new_module.bias = child.bias
        if getattr(child, "state", None) is not None:
            new_module.state = child.state
            new_module.to(child.weight.device)

        # dispatch to correct device
        for name, module in new_module.named_modules():
            if "ia3_" in name:
                module.to(child.weight.device)

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self.model, name)

    def get_peft_config_as_dict(self, inference: bool = False):
        config_dict = {}
        for key, value in self.peft_config.items():
            config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()}
            if inference:
                config["inference_mode"] = True
        config_dict[key] = config
        return config

    def _set_adapter_layers(self, enabled=True):
        for module in self.model.modules():
            if isinstance(module, IA3Layer):
                module.disable_adapters = False if enabled else True
            elif isinstance(module, ModulesToSaveWrapper):
                module.disable_adapters = False if enabled else True

    def enable_adapter_layers(self):
        self._set_adapter_layers(enabled=True)

    def disable_adapter_layers(self):
        self._set_adapter_layers(enabled=False)

    def set_adapter(self, adapter_name):
        for module in self.model.modules():
            if isinstance(module, IA3Layer):
                if module.merged:
                    warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
                    module.unmerge()
                module.active_adapter = adapter_name

    def _prepare_adapter_config(self, peft_config, model_config):
        if peft_config.target_modules is None:
            if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING:
                raise ValueError("Please specify `target_modules` in `peft_config`")
            peft_config.target_modules = TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING[model_config["model_type"]]
        if peft_config.feedforward_modules is None:
            if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING:
                raise ValueError("Please specify `feedforward_modules` in `peft_config`")
            peft_config.feedforward_modules = TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING[
                model_config["model_type"]
            ]
        return peft_config

    def merge_and_unload(self):
        r"""
        This method merges the (IA)^3 layers into the base model. This is needed if someone wants to use the base model
        as a standalone model.
        """
        if getattr(self.config, "model_type", None) == "gpt2":
            raise ValueError("GPT2 models are not supported for merging ia3 layers")

        if getattr(self.model, "is_loaded_in_8bit", False):
            raise ValueError("Cannot merge ia3 layers when the model is loaded in 8-bit mode")

        key_list = [key for key, _ in self.model.named_modules() if "ia3" not in key]
        for key in key_list:
            try:
                parent, target, target_name = _get_submodules(self.model, key)
            except AttributeError:
                continue
            if isinstance(target, IA3Layer):
                bias = target.bias is not None
                new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias)
                target.merge()
                self._replace_module(parent, target_name, new_module, target)

            # save any additional trainable modules part of `modules_to_save`
            if isinstance(target, ModulesToSaveWrapper):
                setattr(parent, target_name, target.modules_to_save[target.active_adapter])

        return self.model

与LORA类似,修改了可训练层神经网络


通过 print_trainable_parameters 方法可以查看到 IA3 可训练参数的数量(仅为172,032)以及占比(仅为0.0307%)。

trainable params: 172,032 || all params: 559,386,624 || trainable%: 0.0307536849504646

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1154594.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

实力验证 | 求臻医学满分通过CAP及NCCL组织的国内外三项室间质评

近日,求臻医学以满分的优异成绩通过了由美国病理学家协会(College of American Pathologists,CAP)组织的NGS−A 2023:Next−Generation Sequencing (NGS) – Germline、NEO-B 2023 Neoplastic Cellularity能力验证项目…

User-Agent防爬虫与应对策略

引题 最近在看爬虫,也准备学习一下防爬的策略,世上莫大之事就是,我可以爬别人网站,别人不许爬我网站。 正文 什么是User-Agent User-Agent是一个HTTP请求头的一部分,它向Web服务器提供关于客户端(通常是…

使用 RBAC 鉴权实战

使用 RBAC 鉴权实战 官方文档 创建名称 deployment-clusterrole 的 ClusterRole,该⻆⾊具备创建 Deployment、Statefulset、Daemonset 的权限,在命名空间 rbac-test 中创建名称为 cicd-token 的 ServiceAccount,绑定 ClusterRole 到 Service…

Linux学习第28天:Platform设备驱动开发(二): 专注与分散

Linux版本号4.1.15 芯片I.MX6ULL 大叔学Linux 品人间百味 思文短情长 三、硬件原理图分析 四、驱动开发 1、platform设备与驱动程序开发 53 /* 54 * 设备资源信息,也就是 LED0 所使用的所有寄存器 55 */ 56 static str…

揭秘!产品经理提升效率的秘密武器:10款AI生成PPT工具

AI的爆炸式增长表现令人惊艳,现有的各类AI工具正在重塑各行各业,不同程度地提高人们的工作效率,并有望创造新的职业机会。但是,面对市面上数量众多的AI工具,且每周都会蹦出新的产品,即便是以好奇心富称的产…

巴黎奥运会将基于阿里云实现云上转播

10月31日,2023杭州云栖大会,奥林匹克广播服务公司与奥林匹克频道服务公司首席技术官索蒂里斯萨拉穆里斯(Sotiris SALAMOURIS)表示,过去5年阿里云作为奥运会转播的基础设施,让奥运故事触达了更多全球观众。 …

c++实现策略模式

前言 看了一会儿大话设计模式,我感觉平常的话,策略模式还挺常用的,记录一下。个人理解策略模式,就是抽象一个算法,然后你可以有很多不同的实现,这些实现去重写抽象算法的虚方法。然后在一个上下文类中有一…

IT服务管理中怎样选择ITSM软件?

对于什么是一个新ITSM工具最重要的选择标准,业界都有不同的看法。其中67%的服务台用户认为是产品的特性和功能, 65%认为是自助服务功能,53%的人认为是轻松配置和定制的能力,45%的人认为是获得高质量的支持,45%的人认为…

Java入门篇 之 逻辑控制

博主的文章希望对大家有所帮助 今日份励志文案:凌空虚度,难成千秋伟业;求真务实,方能善作善成 冲冲冲!!!!! 目录 一.if~else语句 1.1.if-else语句基本用法: 1.2.代码…

C语言字符串详解

字符串详解 定义 输入输出 思考一: 思考二: 思考三 字符串的转义字符 思考四 常见的字符串函数 strcpy 拷贝数组 strlen 输出字符串长度 strcat 连接俩个字符串 strcmp 比较俩个字符串的大小 strupr 把字符串里面的小写转换成大写形式 s…

[ZenTao]禅道邮件通知设置

代码增加通知设置节点 module/message/config.php

正则表达式的使用实例

正则表达式的使用实例 1- 表示2- 实例 1- 表示 1, [:digit:] 表示0-9全部十个数字 //等价于 0123456789, 而不等价于[0123456789] 2, [[:digit:]] 表示任意一个数字 \{m,n\} 表示其前面的字符出现最少m次,最多n次的情况 \{3,\} 其前面的字符出…

git命令清单

一、设置和配置 1.初始化一个新的仓库&#xff1a; git init2.克隆&#xff08;Clone&#xff09;一个远程仓库到本地&#xff1a; git clone <repository_url>3.配置用户信息&#xff1a; git config --global user.name "Your Name" git config --global…

SpringBoot / Vue 对SSE的基本使用

一、SSE是什么&#xff1f; SSE技术是基于单工通信模式&#xff0c;只是单纯的客户端向服务端发送请求&#xff0c;服务端不会主动发送给客户端。服务端采取的策略是抓住这个请求不放&#xff0c;等数据更新的时候才返回给客户端&#xff0c;当客户端接收到消息后&#xff0c;再…

深入内核buddy分配器(芯驰X9/杰发8015 buddy系统明明还有几十M到100多M内存,却分配4k内存失败)

如上图内核打印分配4K内存失败&#xff0c;但是normal 类型的buddy系统还有大量内存。居然分配失败。源码分析&#xff1a; 根据logfaddr2line定位到&#xff0c;调用栈为__alloc_pages_slowpath——》get_page_from_freelist——》zone_watermark_fast 可以看到buddy内存低于…

node使用fs模块(一)—— 写入文件的基本使用

文章目录 前言一、写入文件的使用&#xff08;fs.writeFile&#xff09;1.参数说明2.基本使用(1)新建app.js 文件(2)代码如下(3)执行命令(4&#xff09;效果 3.写入文件的同步和异步&#xff08;1&#xff09;默认异步&#xff08;2&#xff09; 同步方法&#xff08;writeFile…

【HeidiSql_01】python在heidisql当中创建新表的注意事项

python在heidisql当中创建新表的注意事项 假设你已经在python当中弄好了所有的结果&#xff0c;并且保存在df_all这个dataframe当中&#xff0c;然后要将其导入数据库当中并创建一张新的表进行保存。 # 构建数据库连接,将merged_df写回数据库 from sqlalchemy import create_e…

5000张照片怎么快速发给别人?分享三个简单的方法!

有的时候我们不得不一次性发送很多图片&#xff0c;一张一张发实在让人头疼&#xff0c;这个时候就需要借助一些图片压缩工具打包成文件压缩包发送。下面介绍了三种好用的方法&#xff0c;一起来看看吧&#xff5e; 方法一&#xff1a;使用微信助手 可以使用微信助手&#xff…

设计思想培养:装饰者模式下的RecyclerView添加头、尾

用一个设计模式培养高复用、低耦合思想 前言Android中的装饰者代码实现第一步&#xff1a;创建装饰器DecorateAdapter第二步&#xff1a;处理头部、中间内容、尾部的绑定关系第三步&#xff1a;装饰器的使用第四步&#xff1a;改进、直接封装一个View出来 总结 前言 一个高复用…

操作系统备考学习 day11 (4.1.1~4.1.9)

操作系统备考学习 day11 第四章 文件管理4.1文件系统基础4.1.1 文件的基本概念文件的属性文件的逻辑结构操作系统向上提供的功能文件如何存放在外存 4.1.2 文件的逻辑结构顺序文件索引文件索引顺序文件 4.1.3 文件目录文件控制块单级目录结构两级目录结构多级目录结构 又称树形…