MindSpore-FCOS模型权重迁移推理对齐实录

news2024/9/24 9:23:28

准备工作

环境:
wsl2 Ubuntu 20.04
mindspore 2.0.0
python 3.8
pytorch 2.0.1 cpu

基于已有的mindspore FCOS项目和FCOS官方pytorch权重来做迁移,

  • FCOS官方pytorch实现
    FCOS_imprv_R_50_FPN_1x权重
  • MindSpore FCOS项目链接
    该代码是mindspore1.6实现的,用新版本运行会有很多warning,warning的接口要更改为新的。
    而且没提供训练好的权重,所以用官方的pytorch权重进行迁移,但其中发现MindSpore相比官方有许多地方不同。

权重转换

迁移其实就是在做权重的键值映射对齐,这其中有一些规律可寻,但不多,更多需要自己的分析比对,建立映射字典。

可参考的经验:

  • https://gitee.com/lirongxi4/pt2ms_convert
    一个迁移脚本,通用性一般
  • https://mindspore.cn/docs/zh-CN/r2.0/migration_guide/overview.html
    MindSpore官方的迁移指南

根据上述迁移经验,打印两种框架的权重的名称及shape进行比对,总结名称转换方式如下(pytorch的名称改为mindspore的):

import copy, torch
import mindspore as ms

def fcos_pth2ckpt():
    m = ms.load_checkpoint('test.ckpt')  # mindspore FCOS保存的随机权重
    t = torch.load('./weights/FCOS_imprv_R_50_FPN_1x.pth', map_location=torch.device('cpu'))  # pytorch FCOS权重
    match_pt_kv = {}  # 匹配到的pt权重的name及value的字典
    match_pt_kv_mslist = []  # 匹配到的pt权重的name及value的字典, mindspore加载权重需求的格式
    not_match_pt_kv = {}  # 未匹配到的pt权重的name及value
    matched_ms_k = []  # 被匹配到的ms权重名称
    
    '''一般性的转换规则'''
    pt2ms = {'module': 'fcos_body',  # backbone部分
             'stem.': '',
             '.body': '',
             '.rpn': '',
             'downsample': 'down_sample_layer',

             'backbone.fpn': 'fpn',  # FPN部分
             'fpn_inner4': 'prj_5',
             'fpn_layer4': 'conv_5',

             'fpn_inner3': 'prj_4',
             'fpn_layer3': 'conv_4',

             'fpn_inner2': 'prj_3',
             'fpn_layer2': 'conv_3',

             'top_blocks.p': 'conv_out',

             'bbox_tower': 'reg_conv',  # head部分
             'cls_tower': 'cls_conv',
             'bbox_pred': 'reg_pred',

             'scales': 'scale_exp',
             'centerness': 'cnt_logits',

             "running_mean": "moving_mean",  # BN部分
             "running_var": "moving_variance",

             }

    '''BN层的特殊转换规则'''
    pt2ms_bn = {
        "weight": "gamma",
        "bias": "beta",
    }


    for i in t['model'].keys():
        pt_name = copy.deepcopy(i)
        pt_value = copy.deepcopy(t['model'][i])
        
        '''通用的处理'''
        for k, v in pt2ms.items():
            if k in pt_name:
                pt_name = pt_name.replace(k, v)
        '''BN层处理'''
        if 'bn' in pt_name:
            for k, v in pt2ms_bn.items():
                if k in pt_name:
                    pt_name = pt_name.replace(k, v)
        '''下采样层特别处理'''
        if 'down' in pt_name:
            if 'bias' in pt_name:
                pt_name = pt_name.replace('bias', 'beta')
            if 'down_sample_layer.1.weight' in pt_name:
                pt_name = pt_name.replace('weight', 'gamma')

        '''head部分的特殊处理'''
        if 'cls_conv' in pt_name or 'reg_conv' in pt_name:
            if '1' in pt_name or '4' in pt_name or '7' in pt_name or '10' in pt_name:
                pt_name = pt_name.replace('weight', 'gamma')
                pt_name = pt_name.replace('bias', 'beta')

        '''改名成功,匹配到ms中的权重了,记录'''
        if pt_name in m.keys():
            assert pt_value.shape == m[pt_name].shape
            match_pt_kv[pt_name] = pt_value
            match_pt_kv_mslist.append({'name': pt_name, 'data': ms.Tensor(pt_value.numpy(), m[pt_name].dtype)})
            matched_ms_k.append(pt_name)
        else:
            not_match_pt_kv[i + '   ' + pt_name] = pt_value

    '''打印未匹配的pt权重名称'''
    print('\n\n------------------未匹配的pt权重名称--------------------')
    for j in not_match_pt_kv.keys():
        print(j, np.array(not_match_pt_kv[j].shape))

    '''打印未被匹配到的ms权重名称'''
    print('\n\n------------------未被匹配到的ms权重名称--------------------')
    for j in m.keys():
        if j not in matched_ms_k:
            print(j, np.array(m[j].shape))
    print('end')
    return match_pt_kv_mslist

输出:

------------------未匹配的pt权重名称--------------------

------------------未被匹配到的ms权重名称--------------------
fcos_body.backbone.end_point.weight [1001 2048]
fcos_body.backbone.end_point.bias [1001]

这俩权重不参与模型forward,是冗余的。
match_pt_kv_mslist就是转换后的mindspore权重,加载后测试发现输出有很大出入,第一个原因是mindspore1.10的ops.sort算子有bug,已提交[issue]https://gitee.com/mindspore/mindspore/issues/I7EHKI),后续版本修复了,所以我升级到2.0.0版本了,其他原因就是网络实现未对齐,接下来主要讲这部分。

区别一:输入处理未对齐

MindSpore FCOS项目链接 输入处理方式就与FCOS官方pytorch实现不一样

  • offical pytorch FCOS:BGR 255 ,使用(mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.])进行归一化
  • MindSpore FCOS:RGB, 使用(mean=[0.40789654, 0.44719302, 0.47026115], std=[0.28863828, 0.27408164, 0.27809835])进行归一化

其他的裁剪,图像padding对推理结果影响不会很大。

归一化对齐为官方实现后仍发现图片值仍有不同(B通道的最大值不一样),可能Normalize的底层实现有区别?没有深究,后续直接用torch的Normalize结果张量输入到mindspore中以实现模型输入对齐。

输入对齐后的测试:使用coco2017验证集第一张图像(val/000000000139.jpg),resize到(800,1216)大小,两个框架的模型分别输入进去,输出有差别,

进行排查,发现模型第一个卷积的padding没对齐。

区别二:第一个7x7卷积padding方式未对齐

pytorch:

torch默认pad模式
在这里插入图片描述
卷积结果:
在这里插入图片描述

mindspore:

same模式下的卷积跟torch的pad模式下肯定不一样,且两种框架的same也不一样:算子区别
在这里插入图片描述
结果自然不一样:
在这里插入图片描述
原实现:

nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, pad_mode='same', weight_init=weight)

改为:

nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad', weight_init=weight)

在这里插入图片描述
结果这就对了。

其实发现设置mindsporefocs实现的resnet中的self.res_base=True就会调用正确的7x7卷积。

第一个卷积对了,但后面BN层就不对了,官方的BN层是一种frozenBN,没有使用eps,去除了eps按公式手动计算,但还是有误差,不知为何…

此外,mindspore实现的fcos的卷积pad_mode全选的same,这个肯定与官方的对不齐,pytorch官方的全使用的zeros模式,对应的mindspore应该是pad模式吧

FCOS对齐先放在这儿,后续再处理,已经有了一定的经验,先去做TOOD的迁移。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/672725.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【加强版】SAX解析XML返回对应格式的Map对象(解决元素递归嵌套)

SAX解析XML返回对应格式的Map对象_辛丑年正月十五的博客-CSDN博客 前言 上篇文章实现了xml元素节点的解析并返回了对应格式的Map对象,但是遗留了一个问题,就是当xml中的元素存在递归嵌套时就解析不了,因为qname属性会重复,导致后…

DDD软件架构领域驱动设计

目录 1. DDD概述1.1 软件开发的困境1.2 DDD的来源及简介1.2.1 DDD设计方法 1.3 DDD解决了什么问题1.3.1 沟通问题1.3.2 代码质量问题 1.4 模型和建模1.4.1 什么是模型 1.5 统一语言(UBIQUITOUS LANGUAGE)1.6 什么是DDD 2. 传统开发模式2.1 基础知识回顾2…

Debian12.0.0更换系统语言中文到英文

6月10号,Debian12.0.0更新,想尝尝鲜,在虚拟机里安装好,想将中文改为英文,因为Terminal下输入命令,中文切换麻烦。 一、步骤如下 #1、查看当前语言环境 env | grep LANG #2、en表示语言,US表示…

欧科云链在GEF论坛发起圆桌:监管科技与Web3合规发展图景与展望

6月15日,欧科云链在格林威治经济论坛发起了一场题为“监管科技与Web3合规发展图景与展望”的圆桌会议,此次会议由中国香港贸易发展局副执行董事PatrickLau博士主持。Stratford Finance首席执行官Angelina Kwan,BC科技集团有限公司董事会副主席…

[Web前端] Servlet及应用

文章目录 前言1、简介1.1、Servlet 架构1.1.1、Servlet 任务1.1.2、Servlet 包 1.2、Servlet 环境设置1.2.1、设置 Web 应用服务器:Tomcat 1.3、Servlet 生命周期1.3.1、init() 方法1.3.2、service() 方法1.3.3、doGet() 方法1.3.4、doPost() 方法1.3.5、destroy() …

采集发布到WordPress 特色图片(缩略图)无法显示

采集的数据发布到wordpress系统网站,文章内容是正常的,但是在列表页的缩略图(特色图片)却是显示失败。 这种情况有多种问题都可以造成的,可按照以下步骤逐一排查: 目录 1. 发布映射值是否正确 2. 与主题…

【Python 基础篇】Python 字符串以及字符串常用函数

文章目录 导言一、字符串基础二、字符串操作1、字符串拼接2、字符串格式化3、字符串常用函数len()lower()upper()strip()split()join()replace()find()count() 三、条件控制与字符串总结 导言 字符串是计算机编程中常用的数据类型之一。在 Python 中,字符串是由字符…

切换SVN登录的账号

更换SVN的账号 1、找到已登录的用户信息2、删除已登录的用户信息3、获取重输用户信息弹窗4、使用新的用户信息登录 1、找到已登录的用户信息 (1)在任何文件夹里面右键,找到TortoiseSVN,然后选择里面的Settings (2&am…

【Python 基础篇】Python 条件与循环控制

文章目录 导言一、条件语句1、if-elif-else 结构2、嵌套条件语句3、单行 if 语句 二、循环语句1、while 循环while 循环的高级用法 2、for 循环for 循环的高级用法 示例一:输出 1 到 10 的偶数示例二:获取 100 以内的质数结论 导言 Python 是一种简单而…

【Leetcode -2236.判断根节点是否等于子节点之和 -2331.计算布尔二叉树的值】

Leetcode Leetcode -2236.判断根节点是否等于子节点之和Leetcode -2331.计算布尔二叉树的值 Leetcode -2236.判断根节点是否等于子节点之和 题目:给你一个 二叉树 的根结点 root,该二叉树由恰好 3 个结点组成:根结点、左子结点和右子结点。 …

shell脚本自动化部署tomcat

前言 在一个月黑风高的晚上,在公司把程序打包好后,发给现场,结果又被告知不能登录命令行界面部署程序(tomcat部署),只能提供一个shell脚本实现自动化部署,于是拿出我0.5年的开发经验&#xff0…

Spring MVC获取参数和自定义参数类型转换器及编码过滤器

目录 一、使用Servlet原生对象获取参数 1.1 控制器方法 1.2 测试结果 二、自定义参数类型转换器 2.1 编写类型转换器类 2.2 注册类型转换器对象 2.3 测试结果 三、编码过滤器 3.1 JSP表单 3.2 控制器方法 3.3 配置过滤器 3.4 测试结果 往期专栏&文章相关导读…

MySQL数据库学习笔记二

数据库存储引擎 数据库存储引擎是数据库底层软件组织,数据库管理系统(DBMS)通过数据引擎,对数据进行创建、查询、修改和删除的操作。不同的存储引擎提供不同的存储机制、索引技巧、锁定水平等功能,使用不同的存储引擎…

Kubios HRV心率变异性分析软件

Kubios HRV是由东芬兰大学研究团队开发的一款心率变异性分析软件,目前在全球128个国家被1200所大学的科研人员使用。 PC端的Kubios HRV主要分免费版(Standard)和收费版(Premium)两个版本。 免费版仅支持RR间期时间序…

Git添加与提交文件与查看

目录 一、Git添加 二、Git提交文件 三、查看Git仓库的提交历史和当前状态 一、Git添加 1、在终端或命令提示符中,导航到你的Git项目所在的目录,使用 cd 命令切换目录。 2、在目标目录中,运行以下命令来初始化一个新的Git仓库,…

【跟小嘉学 Rust 编程】一、Rust 编程基础

系列文章目录 【跟小嘉学 Rust 编程】一、Rust 编程基础 文章目录 系列文章目录前言一、Rust是什么?二、Rust 开发环境搭建2.1、下载地址2.2、Windows 环境安装 可以参考2.3、Mac 环境安装2.3.1、安装步骤2.3.2、执行完上述命令之后,有如下提示 2.4、安…

深度学习----第J1周:ResNet50算法实战

深度学习----第J1周:ResNet50算法实战 🍨 本文为🔗365天深度学习训练营 中的学习记录博客** 参考文章:Pytorch实战 | 第P5周:运动鞋识别**🍖 原作者:K同学啊|接辅导、项目定制 文章目录 深度学习…

Elasticsearch 分词器

前奏 es的chinese、english、standard等分词器对中文分词十分不友好,几乎都是逐字分词,对英文分词比较友好。 在kibana的dev tools中测试分词: POST /_analyze {"analyzer": "standard","text": "你太…

chatgpt赋能python:Python文件导出方法详解

Python文件导出方法详解 Python是一种高级编程语言,广泛应用于各种数据科学、人工智能、Web开发等领域。在Python开发中,我们需要将处理好的数据与结果输出为合适的格式,文件导出是常见的输出方式之一。在本文中,我们将详细介绍P…

【C++篇】C++的输入和输出

友情链接:C/C系列系统学习目录 知识总结顺序参考C Primer Plus(第六版)和谭浩强老师的C程序设计(第五版)等,内容以书中为标准,同时参考其它各类书籍以及优质文章,以至减少知识点上的…