【Python】解决CNN中训练权重参数不匹配size mismatch for fc.weight,size mismatch for fc.bias

news2024/12/30 2:54:00

目录

1.问题描述

2.问题原因

3.问题解决

3.1思路1——忽视最后一层权重

额外说明:假如载入权重不写strict=False, 直接是model.load_state_dict(pre_weights, strict=False),会报错找不到key?

解决办法是:加上strict=False,这个语句就是指忽略掉模型和参数文件中不匹配的参数

3.2思路2——更改最后一层参数

额外说明:假如原有的model默认类别数 和 载入权重类别数不一致,代码如何更改?


1.问题描述

训练一个CNN时,比如ResNet, 借助迁移学习的方式使用预训练好的权重,在导入权重后报错:

RuntimeError: Error(s) in loading state_dict for ResNet:
        size mismatch for linear.weight: copying a param with shape torch.Size([100, 2048]) from checkpoint, the shape in current model is torch.Size([10, 2048]).
        size mismatch for linear.bias: copying a param with shape torch.Size([100]) from checkpoint, the shape in current model is torch.Size([10]).

RuntimeError: Error(s) in loading state_dict for ResNet:
        size mismatch for linear.weight: copying a param with shape torch.Size([100, 2048]) from checkpoint, the shape in current model is torch.Size([10, 2048]).
        size mismatch for linear.bias: copying a param with shape torch.Size([100]) from checkpoint, the shape in current model is torch.Size([10]).

类似的也可以有:

RuntimeError: Error(s) in loading state_dict for ResNet:
        size mismatch for fc.weight: copying a param with shape torch.Size([100, 2048]) from checkpoint, the shape in current model is torch.Size([10, 2048]).
        size mismatch for fc.bias: copying a param with shape torch.Size([100]) from checkpoint, the shape in current model is torch.Size([10]).

导入权重的核心代码为:

    model = model_dict[opt.model](num_classes=10)
    model_path = "./save/models/ResNet50_vanilla/ckpt_epoch_240.pth"
    pre_weights = torch.load(model_path)['model']
    model.load_state_dict(pre_weights, strict=False)

这里的pre_weighets后面还加的有['model']是因为在保存文件的时候出了保存权重,还保存有epoch, acc等等。

2.问题原因

根本原因在于预训练权重的某一层参数和模型需要的参数对应不上,这里比如就是model.linear层,其实就是相当于全连接层fc, 可以直接model文件中去查看,最后一层的命名。比如进入定义好的ResNet文件中查看最后一层名字为linear。

 具体的描述:预训练权重中最后一层的输出类别为100, 而现在我们的目标类别是10,所以导致linear层的参数对应不上,进而报错。(更常见的是在imagenet数据上训练分类类别为1000, 目标类别为10,也会是相同的错误)

3.问题解决

3.1思路1——忽视最后一层权重

查阅相关解决办法后,可以使用pop()函数弹出最后一层的参数,这样相当于导入的时候,只有前面网络层参数,就不会报最后一层参数不匹配的问题。所以把权重文件弹出pre_weights.pop('linear.weight')
pre_weights.pop('linear.bias')

核心代码:

    model = model_dict[opt.model](num_classes=10)
    model_path = "./save/models/ResNet50_vanilla/ckpt_epoch_240.pth"
    pre_weights = torch.load(model_path)['model']
    pre_weights.pop('linear.weight')
    pre_weights.pop('linear.bias')
    model.load_state_dict(pre_weights, strict=False)

额外说明:假如载入权重不写strict=False, 直接是model.load_state_dict(pre_weights, strict=False),会报错找不到key?

RuntimeError: Error(s) in loading state_dict for ResNet:
        Missing key(s) in state_dict: "linear.weight", "linear.bias".

解决办法是:加上strict=False,这个语句就是指忽略掉模型和参数文件中不匹配的参数

3.2思路2——更改最后一层参数

因为这里仅仅是最后一层参数不匹配,所以可以获取导入权重的最后一层,然后更改最终的分类类别数目

核心代码:

    model.load_state_dict(pre_weights, strict=False)
    in_channel = model.linear.in_features
    model.linear = nn.Linear(in_channel, n_cls)

代码意思就是:

1.先按照正常的载入模型,如果原来的model文件默认的类别数目和载入权重默认的类别数目一致的话,那么就直接使用上述核心代码就行。

2.获取最后一层的输入特征维度,这里的model.linear.in_features, 是因为在model文件中自定义最后一层为self.linear,要根据实际名称更改

3.更新最后一层的输出特征维度,这里也要使用model.linear

额外说明:假如原有的model默认类别数 和 载入权重类别数不一致,代码如何更改?

举例子:比如model文件中默认分类类别数是10,如下图所示

但是载入权重文件的的分类类别数是100,如下图所示(这个权重文件训练的数据集就是100个分类类别)

此时,我想要在一个数据集只有7个类别的数据集上进行迁移学习,载入权重的话,就应该这样写:

    # model
    model = model_dict[opt.model](num_classes=100)
    # print(model)
    model_path = "./save/models/ResNet50_vanilla/ckpt_epoch_240.pth"
    pre_weights = torch.load(model_path)['model']
    # pre_weights.pop('linear.weight')
    # pre_weights.pop('linear.bias')
    # model.load_state_dict(pre_weights, strict=False)
    # # # 更改最后的全连接层
    model.load_state_dict(pre_weights, strict=False)
    in_channel = model.linear.in_features
    model.linear = nn.Linear(in_channel, n_cls)

核心的想法是:实例化模型的时候,需要更改模型的分类类别数100 和 权重文件的类别100数保持一致,也就是如下

model = model_dict[opt.model](num_classes=100)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/55632.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Ajax学习:解决跨域_JSONP

JSONP:非官方的跨域解决方案,纯粹依靠程序员的聪明才智,只支持get请求 JSONP是怎么工作的:再页面中有一些标签天生具有跨域能力,就像是link,img,iframe,script JSONP就是利用script标签的跨域能力来发送请求 如下所示&#xff1a…

[附源码]计算机毕业设计毕业生就业管理系统Springboot程序

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

【Docker】Docker常用命令(包含Dockerfile指令)

目录一.Docker常用命令1.帮助命令2.镜像命令3.容器命令4.其他常用命令5.更多更详细命令二.Dockerfile常用指令1.常用指令2.CMD 和 ENTRYPOINT 的区别一.Docker常用命令 1.帮助命令 # docker version //查看docker版本号 # docker info //查看docker的系统信息…

外贸员的日常工作分享

外贸人米贸搜的日常工作流程为你整理如下。希望能帮到你: 外贸业务员的职责 一、业务人员在向国外买家询价前,要了解客户的基本信息,包括是否是终端客户,年采购量,消费区域,产品的用途,规格,质…

(附源码)SSM座位管理系统 毕业设计 250858

基于SSM的座位管理系统 摘 要 21世纪时信息化的时代,几乎任何一个行业都离不开计算机,将计算机运用于学校的各种信息管理也是十分常见的。过去使用手工的管理方式对高校教室座位进行管理,造成了管理繁琐、难以维护等问题,如今使用…

python的opencv使用总结

作为最容易上手之一的语言,python拥有着大量的第三方库,这些第三方库的存在使得很多人可以专注于业务逻辑、数学逻辑而忽略繁琐的代码操作,python的opencv第三方库就是其中之一。 一、第三方库的安装和简单使用 安装 简单的pip安装就可以了…

【雷达波位编排】基于matlab相控阵雷达的波位编排仿真【含Matlab源码 2251期】

⛄一、相控阵雷达最优波位编排策略仿真算法 1 波位编排的最优化 相控阵雷达的扫描空域一般在修正球坐标系下进行指定,它的坐标原点为雷达站,雷达阵面法线在水平面的投影作为方位角的零度,顺时针为正,逆时针为负,有效取值范围为[-π/2,π/2],以水平面作为俯仰角的零度,向上为正…

分享5款2023年不容错过的宝藏软件

今天带来五款宝藏软件,身为宝藏男孩和宝藏女孩的你们,不试一下吗? 1.EPUB阅读器——Starrea Starrea 是一款Windows平台的EPUB电子书阅读器,它虽然只支持一个平台,但是提供了很多额外的功能,其中包括 文…

MySQL学习笔记(十)crash-safe 和两段提交

crash-safe CrashSafe指MySQL服务器宕机重启后,能够保证: 所有已经提交的事务的数据仍然存在。所有没有提交的事务的数据自动回滚。 如果MySQL宕机了,重启后,就需要检查redolog 日志文件里面,系统会自动定位到上次c…

Jmeter插件duang duang duang 学会模拟各种场景

为什么要使用jmeter线程组插件呢? jmeter自带的线程组插件模拟的压测场景非常有限,当需要模拟复杂压测场景的时候, 推荐大家使用jmeter线程组插件。 如何下载jmeter线程组插件呢? 早期版本的jmeter可以针对我们需要的扩展功能&…

解决远程连接 docker中mysql 失败

在docker安装好mysql之后 端口也设置好,同时云服务器的端口3306也打开。但是使用navicat 连接不上。 其实是因为Navicat只支持旧版本的加密,需要更改mysql的加密规则 进入MySQL容器,登陆MySQL docker exec -it mysql /bin/bash 登陆mysql** mysql -u root -p 输入密…

炫龙T6-E7A2电脑如何U盘重装系统解决系统故障教学

炫龙T6-E7A2电脑如何U盘重装系统解决系统故障教学。对于一些比较严重的系统故障问题来说,普通用户很难去进行问题的解决。比如系统故障导致无法开机,普通用户无法自己解决问题,这个时候可以通过U盘重装系统的方法来解决,这个方法还…

深入URP之Shader篇2: 目录结构和Unlit Shader分析[上]

Unity和URP版本 我使用的Unity版本为2020.3.33f1,对应的URP和SRP Core版本为10.8.1。阅读URP源码建议把package从Library/PackageCache中拷贝到Packages目录,也就是自定义package的方式,然后推荐使用VS code打开工程,这样可以很方…

Golang【Web 入门】 08 集成 Gorilla Mux

阅读目录集成 Gorilla Mux为什么不选择 HttpRouter?安装 gorilla/mux使用 gorilla/mux迁移到 Gorilla Mux1. 新增 homeHandler2. 指定 Methods () 来区分请求方法3. 请求路径参数和正则匹配4. 命名路由与链接生成集成 Gorilla Mux 我们将选用 gorilla/mux 来作为 g…

CSS页面布局(超详解)

目录 1 CSS页面布局概述 1.1 概述 1.2 网页栏目划分 1.3 元素类型转化 1.3.1 块元素 1.3.2 行内元素 1.3.2 块元素和行内元素的转换 1.4 定位 1.4.1 静态定位 1.4.2 相对定位 1.4.3 绝对定位 1.4.4 固定定位 1.4.5 定位元索的层叠次序 1.5 浮动 1.5.1 概述 1.5…

JAVA中如何精确取到时间

文章目录0 写在前面1 使用方法2 举例3 写在最后0 写在前面 做业务的时候,总要统计数据,几月份到几月份的全部数据。这个时候就要找到起始月份的具体时间和终止月份的具体时间。 此时我们用原始的Date类去处理就比较麻烦,可以自己写一个工具类…

Web3中文|什么是以太坊虚拟机(EVM),它是如何工作的?

来源 | cointelegraph 编译 | DaliiNFTnews.com 以太坊已成为仅次于比特币的第二重要区块链。以太坊能发展得这么好,它的原生Solidity编程语言和以太坊虚拟机(EVM)发挥了重要的作用。 以太坊区块链凭借自身拥有的灵活性、大量可用的开发工…

MySQL高级SQL语句

一.准备 mysql -uroot -p123123create database train_ticket; #创建库use train_ticket; create table REGION(region varchar(10),site varchar(20)); create table FARE(site varchar(20),money int(10),date varchar(15)); #创建表desc REGION; desc FARE; #查看表结构ins…

[附源码]计算机毕业设计云南美食管理系统Springboot程序

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

浅析linux内核网络协议栈--linux bridge(二)

6. 网桥数据转发 6.1 网桥数据包入口 网桥是一种2层网络互连设备,而不是一种网络协议。它在协议结构上并没有占有一席之地,因此不能通过向协议栈注册协议的方式来申请网桥数据包的处理。相 反,网桥接口(如上述的eth1&#xff09…