MindSpore-TOOD模型权重迁移推理对齐实录

news2025/1/2 4:12:50

准备工作

环境:
wsl2 Ubuntu 20.04
mindspore 2.0.0
python 3.8
pytorch 2.0.1 cpu

基于自己编写的mindspore TOOD项目和MMDetection实现的pytorch权重来做迁移,

  • TOOD论文pytorch mmdetection实现
    tood_r50_fpn_1x_coco权重
    论文中的代码也是用mmdetection实现的
  • TOOD mmdetection实现
    观察上面两个实现的配置文件,区别只是分类损失用的不同,我们先对照TOOD mmdetection实现。
  • MindSpore TOOD项目链接
    该代码基于FCOS mindspore实现的,对网络命名进行了优化,更靠近官方的pytorch风格

基于MindSpore实现TOOD forward 结构

先搭模型,结构就是resnet50+fpn+toodhead。除了模型结构,还要注意head以及fpn部分的权值初始化要与mmdetection中的实现对齐,这个在后续训练时会有影响

  • 两种框架下pad的区别需要注意,区别见MindSpore官方的迁移指南 ,我尽量使用显式表达,防止出错
  • resent50 backbone在训练时加载预训练权重进行初始化
  • mmdetection中FPN部分的初始化为xavier初始化,我在mindspore中采用更好的kaiming初始化
  • head部分卷积和一般性的偏置使用normal初始化以及zeros初始化
  • head部分的分类分支偏置采用的prob初始化
  • 其他部分(BN,GN)的初始化两个框架相同

权重转换

迁移其实就是在做权重的键值映射对齐,有了FCOS的迁移经验,且对网络模型部分做了命名优化,做这个会快很多。

可参考的经验:

  • FCOS权重迁移经验
  • https://gitee.com/lirongxi4/pt2ms_convert
    一个迁移脚本,通用性一般
  • MindSpore官方的迁移指南

打印两种框架的权重的名称及shape进行比对,
利用文本对比网站进行对比:
在这里插入图片描述
根据shape可以看到顺序完全对齐了,注意scale在pt中是一个浮点数,而在ms中是一个1x1的tensor。FPN实现的运算顺序也在代码中专门调试过,只需完成名称转换即可。

虽然可以根据顺序直接转换,但为了稳定性,还是用字典映射的方法,总结的名称转换方式如下(pytorch的名称改为mindspore的):

def tood_pth2ckpt():
    ms_ckpt = ms.load_checkpoint('tood_ms.ckpt')  # mindspore FCOS保存的随机权重
    pth = torch.load("/mnt/f/pretrain_weight/tood_r50_fpn_1x_coco.pth", map_location=torch.device('cpu'))  # pytorch FCOS权重
    match_pt_kv = {}  # 匹配到的pt权重的name及value的字典
    match_pt_kv_mslist = []  # 匹配到的pt权重的name及value的字典, mindspore加载权重需求的格式
    not_match_pt_kv = {}  # 未匹配到的pt权重的name及value
    matched_ms_k = []  # 被匹配到的ms权重名称

    '''一般性的转换规则'''
    pt2ms = {'backbone': 'tood_body.backbone',  # backbone部分
             'neck': 'tood_body.fpn',
             'bbox_head': 'tood_body.head',
             'downsample': 'down_sample_layer',
             }

    '''conv层的转换规则, 一致,可忽略'''
    pt2ms_conv = {
        "weight": "weight",
        "bias": "bias",
    }

    '''downsample层的转换规则, 有卷积层和bn层, 分别为0,1命名,在torch中weight重复'''
    pt2ms_down = {
        "0.weight": "0.weight",
        "1.weight": "1.gamma",

        "1.bias": "1.beta",
        "running_mean": "moving_mean",
        "running_var": "moving_variance",
    }

    '''BN层的转换规则'''
    pt2ms_bn = {
        "running_mean": "moving_mean",
        "running_var": "moving_variance",
        "weight": "gamma",
        "bias": "beta",
    }

    '''GN层的转换规则'''
    pt2ms_gn = {
        "weight": "gamma",
        "bias": "beta",
    }

    for i, v in pth['state_dict'].items():
        pt_name = copy.deepcopy(i)
        pt_value = copy.deepcopy(v)
        '''一般性的处理'''
        for k, v in pt2ms.items():
            if k in pt_name:
                pt_name = pt_name.replace(k, v)

        '''conv层的转换规则, 一致,可忽略'''

        '''FPN部分特别处理'''
        if 'fpn' in pt_name:
            pt_name = pt_name.replace('.conv', '')

        '''下采样层特别处理'''
        if 'down' in pt_name:
            for k, v in pt2ms_down.items():
                if k in pt_name:
                    pt_name = pt_name.replace(k, v)

        '''BN层处理'''
        if 'bn' in pt_name:
            for k, v in pt2ms_bn.items():
                if k in pt_name:
                    pt_name = pt_name.replace(k, v)

        '''GN层处理'''
        if 'gn' in pt_name:
            for k, v in pt2ms_gn.items():
                if k in pt_name:
                    pt_name = pt_name.replace(k, v)

        '''改名成功,匹配到ms中的权重了,记录'''
        if pt_name in ms_ckpt.keys():
            if 'scale' in pt_name:
                pt_value = torch.tensor([pt_value])
            assert pt_value.shape == ms_ckpt[pt_name].shape
            match_pt_kv[pt_name] = pt_value
            match_pt_kv_mslist.append({'name': pt_name, 'data': ms.Tensor(pt_value.numpy(), ms_ckpt[pt_name].dtype)})
            matched_ms_k.append(pt_name)
        else:
            not_match_pt_kv[i + '   ' + pt_name] = pt_value

    '''打印未匹配的pt权重名称'''
    print('\n\n-----------------------------未匹配的pt权重名称----------------------------')
    print('----------原名称--------                        ----------转换后名称---------')
    for j, v in not_match_pt_kv.items():
        print(j, np.array(v.shape))

    '''打印未被匹配到的ms权重名称'''
    print('\n\n---------------------------未被匹配到的ms权重名称----------------------------')
    for j, v in ms_ckpt.items():
        if j not in matched_ms_k:
            print(j, np.array(v.shape))
    print('end')
    return match_pt_kv_mslist

输出:

-----------------------------未匹配的pt权重名称----------------------------
----------原名称--------                        ----------转换后名称---------
backbone.layer4.1.bn3.num_batches_tracked   tood_body.backbone.layer4.1.bn3.num_batches_tracked []
backbone.layer4.2.bn1.num_batches_tracked   tood_body.backbone.layer4.2.bn1.num_batches_tracked []
backbone.layer4.2.bn2.num_batches_tracked   tood_body.backbone.layer4.2.bn2.num_batches_tracked []
backbone.layer4.2.bn3.num_batches_tracked   tood_body.backbone.layer4.2.bn3.num_batches_tracked []
......

---------------------------未被匹配到的ms权重名称----------------------------
end

剩下一些bn层的num_batches_tracked状态,不需要管

接下来进行输出对齐,推理到需要padding的卷积时发现了一些问题,
mindspore中

nn.Conv2d(64, 64, kernel_size=3, stride=1,
                     padding=1, pad_mode='pad', has_bias=False)

不等价于pytorch的

nn.Conv2d(64, 64, kernel_size=3, stride=1,
                     padding=1)

查阅资料按道理应该等价的啊,结果不等价
发现是跟ms中这样等价的, 先pad,再valid卷积:

pad1 = ms.nn.Pad(((0,0),(0,0),(1,1),(1,1)))
conv2 = ms.nn.Conv2d(64, 64, kernel_size=3, stride=1,
                      pad_mode='valid')

不解。。。

未完待续。。。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/675559.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

浅谈前后端交互的基本原理

本文受众人群: 前端/后端开发工程师;Web应用程序设计师;项目经理;产品经理等。 为什么要去了解? 了解前后端交互的基本原理对于从事与Web开发相关的角色的人群是非常重要的。这包括前端开发工程师、后端开发工程师、全…

【Java高级语法】(十三)注解:解码程序设计中的元数据利器,在小小的@符里挖呀挖呀挖~用小小的注解做强大的开发...

Java高级语法详解之注解 1️⃣ 概念2️⃣ 优势和缺点3️⃣ 使用3.1 元注解3.2 自定义注解3.3 常用内置注解 4️⃣ 应用场景5️⃣ 扩展:那些流行框架中的注解🌾 总结 1️⃣ 概念 Java 注解(Annotation) 是Java语言中一种元数据形式…

chatgpt赋能python:Python爬虫速度分析:如何加速你的爬虫?

Python爬虫速度分析:如何加速你的爬虫? Python作为一种优秀的胶水语言,被广泛应用于web开发、数据处理等众多领域。在众多应用场景中,Python爬虫无疑是其中之一。然而,在爬取海量数据时,爬虫的速度往往成为…

Arthas原理分析

在日常开发中,经常会使用到arthas排查线上问题,觉得arthas的功能非常强大,所以打算花了点时间了解一下其实现原理。并试着回答一下使用Arthas时存在的一些疑问。 Arthas主要基于是Instrumentation JavaAgent Attach API ASM 反射 OGNL等…

chatgpt赋能python:Python点的用法

Python点的用法 作为一名有着10年Python编程经验的工程师,我发现很多初学者对Python的点(.)用法存在疑惑。因此,在这篇文章中,我将详细介绍Python点的用法,并希望能够对这个问题有一个全面的认识。 什么是点 在Python中&#x…

Linux Xshell配置public key实现免密登录linux服务器

linux服务器安装成功后,登录linux服务器的工具有很多中,例如:Xshell、SecureCRT等等。而我所服务的用户使用xshell工具来对linux服务器进行运维。 当使用xshell登录linux服务器时,xshell提供了三种身份验证方式: 1.P…

实战:Maven构建工具实践-2023.6.21(测试成功)

实战:Maven构建工具实践-2023.6.21(测试成功) 目录 推荐文章 https://www.yuque.com/xyy-onlyone/aevhhf?# 《玩转Typora》 实验环境 gitlab/gitlab-ce:15.0.3-ce.0 jenkins/jenkins:2.346.3-2-lts-jdk11 apache-maven-3.9.2 openjdk 11.0.18实验软件 链接&…

对centOS的home目录进行扩容。

对centos的home目录进行扩容 1 首先要了解PV\VG\LV的含义1.1 基本概念1.2 基本命令行 2 实际操作2.1 盘符当前现状2.1实操 1 首先要了解PV\VG\LV的含义 1.1 基本概念 物理卷(Physical Volume,PV) 指磁盘分区或从逻辑上与磁盘分区具有同样功能…

SPSS统计教程:卡方检验

本文简要的介绍了卡方分布、卡方概率密度函数和卡方检验,并通过SPSS实现了一个卡方检验例子,不仅对结果进行了解释,而且还给出了卡方、自由度和渐近显著性的计算过程。本文用到的数据"2.2.sav"链接为: https://url39.ctfile.com/f/…

菲涅尔圆孔衍射matlab完整程序分享

根据惠更斯 - 菲涅耳原理,光的衍射是光束内部的次波之间的相干叠加,衍射光波场的光振动符合菲涅耳积分公式。但直接运用菲涅耳积分公式计算衍射光场是很困难的。对于夫琅和费衍射(远场衍射),在光源和接收屏距离衍射屏均为无穷远的…

实战:k8s证书续签-2023.6.19(测试成功)

实战:k8s证书续签-2023.6.19(测试成功) 目录 推荐文章 https://www.yuque.com/xyy-onlyone/aevhhf?# 《玩转Typora》 1、前言 k8s集群核心的证书有2套,还有1套非核心的(即使出问题也问题不大)。 ⚠️ 如果是kubeadm搭建的k8s集群,其有效期为…

chatgpt赋能python:Python烧录单片机:快速的开发工具

Python烧录单片机:快速的开发工具 简介 Python是一种高级的编程语言,被广泛应用于各种领域,包括机器学习、数据分析和物联网等领域。Python的易用性和简洁性已经成为其成功的关键因素之一。Python也能在烧录单片机时提供极大的方便性和灵活…

chatgpt赋能python:用Python自动爬取链接的内容——提升SEO效果的利器

用Python自动爬取链接的内容——提升SEO效果的利器 在当今数字化时代,SEO(搜索引擎优化)对于任何一个网站来说都至关重要。一种有用的SEO策略就是频繁地更新网站内容,吸引更多的访问者和搜索引擎爬虫。而最快捷的方法就是自动爬取…

chatgpt赋能python:Python爬虫解密:如何快速抓取网站数据

Python爬虫解密:如何快速抓取网站数据 在当今信息时代,人们越来越依赖互联网获取信息。不同的网站提供了大量数据,但是手动去抓取这些数据十分困难,效率也很低。Python爬虫技术是解决这一问题的有效工具之一,它可以帮…

使用npm安装pnpm包管理器

使用npm安装pnpm包管理器 一、安装 使用 npm install pnpm -g 命令安装pnpm npm install pnpm -g安装完成之后,使用pnpm -v命令查询是否成功安装,出现版本号即可 二、设置源 1.先查看源是否为淘宝的源 pnpm config get registry 2.设置源命令 pn…

TS:pip安装python库报ssl错误-2023.6.17(已解决)

2023.6.17-TS-pip安装python库报ssl错误(已解决) 目录 文章目录 2023.6.17-TS-pip安装python库报ssl错误(已解决)目录报错现象报错环境测试过程换其他源还是报错(失败)百度:替换为豆瓣源并加--trusted-host参数(成功) 参考文章关于我最后 报错…

一文理解多线程机制和多线程的优缺点

一文理解多线程机制 前言:多线程的优缺点。一、什么是多线程1.1、多线程的概念和基本原理1.2、多线程与单线程的区别 二、多线程的应用场景三、C 中的多线程3.1、C11 新增加的 thread 库3.2、C 线程同步机制(mutex、condition_variable) 四.、…

【Openvino01】Ubuntu安装inter的openvino2022.1以及遇到的各种错误解决

交代一下今天的文章背景: 于最近要使用inter的一款名为Intel Movidius™ Myriad™ X 的加速卡去实现对算法模型的加速推理能力,由于是就得第一步安装openvino,然后再使用卡去验证openvino是否安装ok,卡是否真的存在推理加速的能力…

python pytorch教程-带你从入门到实战(代码全部可运行)

python pytorch教程-带你从入门到实战(代码全部可运行) 其实这个教程以前博主写过一次,不过,这回再写一次,打算内容写的多一点,由浅入深,然后加入一些实践案例。 下面是我们的内容目录&#x…

2022(一等奖)D1073基于Himawari-8卫星遥感的黑龙江省地表水时空格局研究

作品介绍 1 项目简介 为探究黑龙江省地表水空间格局变化,本项目以黑龙江省为例,基于高时相Himawari-8号卫星数据,通过影像预处理、特征指数选择、自动阈值分类、集成学习和随机森林分类等步骤,融合IDL二次开发与GIS空间分析&…