YOLO学习记录之模型修改

news2024/11/16 17:29:58

我们在做实验时,不免需要对模型结构进行修改来检测自己的改进性能,对于一般模型而言,我们只需要简单的在代码中添加网络层即可,但对于一些预训练好的模型,我们则需要进行较为复杂的修改。以我们的YOLOV7模型为例,yolo_v7.pth为预训练模型,里面已经根据image_Net训练好了大量的权值,是具有通用性的,如果我们不选择该模型而选择自己重新训练的话,无疑会增大计算成本,同时也可能无法取到满意的效果。
今天主要是尝试为YOLO模型添加简单的网络层,为之后模型的修改完善打下基础。

基础知识

首先我们先来熟悉一下模型文件,.pt,.pth,.pkl的PyTorch模型文件。

它们并不存在格式上的区别,只是后缀名不同而已。在用torch.save()函数保存模型文件的时候,有些人喜欢用.pt后缀,有些人喜欢用.pth或 .pkl,用相同的 torch.save()语句保存出来的模型文件没有什么不同。在PyTorch官方的文档里,有用.pt的,也有用.pth的。

据某些文章的说法,一般惯例是使用 .pth,但是官方文档里貌似.pt居多,而且官方也不是很在意固定地用某一种。

简单测试

我们来简单测试一下模型文件的生成,保存和读取:
我们自定义一个模型,然后将其保存并读取:

import torch
from torch import nn
class Qu(nn.Module):
    def __init__(self):
        super(Qu, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return
def test_model():
    qu = Qu()
    torch.save(qu, "qu_method1.pth")
    model = torch.load("qu_method1.pth")
    print(model)
test_model()

在这里插入图片描述
然后我们读取一下yolov7_weights.pth文件:

def load_model():
    import torch
    weights = '../model_data/yolov7_weights.pth'
    net = torch.load(weights)
    print(type(net), len(net))
    for k, v in net.items():
        print(k, type(v), v.size())

load_model()

在这里插入图片描述
可以看到其读取出的文件为权重值,偏置项的配置,包含初始化值,这便是预训练权重,是在image_net上训练得到的,具有通用性,而我们便是在修改了自己的网络模型上,在预训练权值的基础上进行微调,从而得到符合我们数据集的权重,进而完成我们的实验。

微调(Fine Tune)

什么是模型微调呢,比如我们已知一个网络模型:

Y=Wx  这里我们没有设置偏置项  

我们想要找到W,使X=2时,Y=1,即W=0.5
那么我们就要对W进行初始化,其初始化值符合均值为0,方差为1的分布,假设我们开始初始值为0.1,当我们的X=2时,Y=0.2,此时Y的实际值与理想值误差为0.8,相差较大,0.8的误差值去反向传播更新W,假设此时更新为0.2,那么依旧有0.6的误差,可能经过十几次乃至几十次的反向传播,最后我们得到了理想的权重值。
而如果一开始时,有人告诉我们说我们的权重值在0.48附近,那么我们我们第一次的误差值便只有0.04了,那么我们肯能只需要几次反向传播便可以得到理想的结果,我么是在一个已有范围的基础上稍微调整,即称为微调。

这个告诉我么的初始权值范围便相当于一个预训练模型,而我么之后的训练便是微调的过程。

我们选择的预训练模型一般都是在image_net,VOC,COCO等这种大型数据集上训练得到的,具有公信力和通用性。而如果我么自己从头训练的话,若是数据集数量过少,而我们的权值参数数量很多,那么就可以存在过拟合线性,泛化性能不佳。

何时可用微调?

1.数据集很相似,个人数据集与预训练数据集很相似
2.数据集很相似,但数量太少,不能满足训练要求
3.计算资源匮乏,如果计算力差,那么使用预训练模型无疑是一个好的选择。
4.自己搭建的模型准确性太差

通过对我们拥有的较小数据集进行训练(即反向传播),对现有网络进行微调,这些网络是在像ImageNet这样的大型数据集上进行训练的,以达到快速训练模型的效果。假设我们的数据集与原始数据集(例如ImageNet)的上下文没有很大不同,预先训练的模型将已经学习了与我们自己的分类问题相关的特征。

我们也可以冻结网络中的层数来进行训练。

最后我们调用一下YOLO模型:

def yolo_model():
    import nets.yolo as yolo
    import torch
    YOLO=yolo.YoloBody()
    torch.save(YOLO, "YOLO.pth")
    model = torch.load("YOLO.pth")
    print(model)
yolo_model()

在这里插入图片描述

可以看到其网络输出通道数以及特征图大小与我们的模型图一致。

在这里插入图片描述

模型修改

终于到我们的重头戏了,首先我们先要定义一下我们的模型结构,博主定义了一个SE模块,这是一个通道注意力模型。

import torch
import torch.nn as nn
class SELayer(nn.Module):
    def __init__(self, c1, r=16):
        super(SELayer, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.l1 = nn.Linear(c1, c1 // r, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.l2 = nn.Linear(c1 // r, c1, bias=False)
        self.sig = nn.Sigmoid()
    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avgpool(x).view(b, c)
        y = self.l1(y)
        y = self.relu(y)
        y = self.l2(y)
        y = self.sig(y)
        y = y.view(b, c, 1, 1)
        return x * y.expand_as(x)

在这里插入图片描述
demo:
在这里插入图片描述

模型结构如图所示,然后我们需要确定我们想要将模型结构所添加的位置,我们选择一个容易添加的位置,比如在yolo的head头的最后的部分。如下图所示

在这里插入图片描述

那么确定了要添加位置后就在网络结构中进行定义;

在这里插入图片描述

然后再前向传播中引入:

在这里插入图片描述

完成后我们开始训练,此时我们使用的依然是yolov7_weights.pth这个预训练模型。为了方便实验,博主只进行了一次迭代。

在这里插入图片描述
保存好我们训练的模型后,此时的pth里面是包含我们刚刚训练好的参数的。
我们计算mAP值来看看加入SENet后的效果,原mAP为90.07%:
在这里插入图片描述
将模型替换为刚刚训练好的模型文件:将yolo文件中的模型地址替换:
在这里插入图片描述
计算mAP可以看到,效果很差,理论上计算效果差些也不该直接没有结果的,这说明我们的改进肯定出问题了:
在这里插入图片描述
呜呜呜,当然也可能是训练次数太少导致的,正在找原因。。。。后面找到原因后会更新的

常用模块

下面介绍几种即插即用的注意力机制模块

CBAM模型

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
 
        self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu = nn.ReLU()
        self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        # 写法二,亦可使用顺序容器
        # self.sharedMLP = nn.Sequential(
        # nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False), nn.ReLU(),
        # nn.Conv2d(in_planes // rotio, in_planes, 1, bias=False))
 
        self.sigmoid = nn.Sigmoid()
 
    def forward(self, x):
        avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
        max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
        out = self.sigmoid(avg_out + max_out)
        return out
 
 
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
 
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
 
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()
 
    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv(x)
        return self.sigmoid(x)
 
 
class CBAM(nn.Module):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, ratio=16, kernel_size=7):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(CBAM, self).__init__()
        # c_ = int(c2 * e)  # hidden channels
        # self.cv1 = Conv(c1, c_, 1, 1)
        # self.cv2 = Conv(c1, c_, 1, 1)
        # self.cv3 = Conv(2 * c_, c2, 1)
        # self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
        self.channel_attention = ChannelAttention(c1, ratio)
        self.spatial_attention = SpatialAttention(kernel_size)
 
        # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])
 
    def forward(self, x):
        out = self.channel_attention(x) * x
        # print('outchannels:{}'.format(out.shape))
        out = self.spatial_attention(out) * out
        return out

ECA模块

class eca_layer(nn.Module):
    """Constructs a ECA module.
    Args:
        channel: Number of channels of the input feature map
        k_size: Adaptive selection of kernel size
    """
    def __init__(self, channel, k_size=3):
        super(eca_layer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # feature descriptor on the global spatial information
        y = self.avg_pool(x)

        # Two different branches of ECA module
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)

        # Multi-scale information fusion
        y = self.sigmoid(y)
        x=x*y.expand_as(x)

        return x * y.expand_as(x)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/134192.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[Linux]----守护进程

文章目录前言一、什么是守护进程?二、会话和进程组会话进程组三、守护进程的编程流程总结前言 这节课我来给大家讲解在Linux下如何让进程守护化,运行在后台,处理我们的任务. 正文开始! 一、什么是守护进程? 守护进程也称为精灵进程(Daemon),是运行在后台的一种特殊进程.它…

Mybatis-Plus快速使用相关知识点1

Mybatis-Plus的mapper、service 基本CURD BaseMapper BaseMapper是MyBatis-Plus提供的模板mapper,其中包含了基本的CRUD方法,泛型为操作的实体类型,Mapper 继承该接口后,无需编写 mapper.xml 文件,即可获得CRUD功能…

JavaScript刷LeetCode拿offer-链表篇

一、链表 链表(Linked List)是一种常见的基础数据结构,也是线性表的一种。 一个线性表是 n 个具有相同特性的数据元素的有限序列,线性表的存储结构分为两类:顺序表(数组)和链表。 链表相比较顺…

站得高,望得远

1、站得高,望的远 计算机科学领域的任何问题都可以通过增加一个间接的中间层来解决。 这句话几乎概括了计算机系统软件体系结构的设计要点 ,整个体系结构从上到下都是按照严格的层次结构设计的。不仅是计算机系统软件整个体系是这样的,体系里…

884. 两句话中的不常见单词 map与stringstream

目录 力扣884. 两句话中的不常见单词 【解法一】:最后写出了一坨屎,虽然它是一坨屎,但是它能动,虽然它是一坨屎,但起码这是我自己拉的 【大佬解法】 stringstream的使用 以及 map的使用 884. 两句话中的不常见单词 句…

python实现bib文件中参考文献的题目每个单词首字母大写

文章目录前言实现思路前言 由于毕业论文格式要求英文参考文献的题目的每个单词(除了介词)的首字母都要大写,如果一条条地自己修改费时费力,这里就想着简单地用python操作字符串的方式实现。 实现思路 观察bib参考文献格式&#x…

20230102单独编译Toybrick的TB-RK3588X开发板的Android12的内核

20230102单独编译Toybrick的TB-RK3588X开发板的Android12的内核 2023/1/2 17:40 《RK3588_Android12_SDK_Developer_Guide_CN.pdf》 原厂的开发板rk3588-evb1-lp4-v10单独编译内核的方式: cd kernel-5.10 export PATH../prebuilts/clang/host/linux-x86/clang-r4161…

【数据结构】C语言实现链表(单链表部分)

目录 前言 链表 链表的分类 1.单向或者双向 2.带头或者不带头 3.循环或者非循环 单链表实现 定义节点 接口函数实现 创建节点 打印链表 尾插节点 尾删节点 头插节点 头删节点 单链表查找 删除指定位置后的节点 指定位置后插入节点 删除指定位置 指定位置插入节点…

Linux-7 文本编辑vivim

Linux-7 文本编辑vi/vim vim介绍 什么是vim? vi和vim是Linux下的一个文本编辑工具。(可以李姐为Windows的记事本或word文档) 为什么要使用vim? 因为Linux系统一切皆为文件,而我们工作最多的就是修改某个服务的配置&a…

一名七年老安卓的 2022 总结

大家好,我是 shixin。一转眼到了 2022 的最后一天,今年发生了很多事,这篇文章来总结一下。长短期目标达成情况和去年一样,我的长期目标是成为具备创业能力的人,包括商业思维和全栈技术能力。总的来说,今年是…

STM32MP157驱动开发——USB设备驱动

STM32MP157驱动开发——USB设备驱动一、简介1.电气属性2.USB OTG3.STM32MP1 USB 接口简介4.Type-C 电气属性二、USB HOST 驱动开发1.USB HOST 驱动编写2.配置 PHY 控制器3.配置usbh_ehci三、USB HOST 测试1.鼠标键盘驱动使能2.U盘驱动四、USB OTG驱动开发1.USB OTG 控制器节点信…

系统设计实战一

文章目录前言一、服务幂等1.防止订单重复下单1.1 场景如下:当用户在提交订单的时候1.2 重复下单解决方案1.3案例一幂等性总结2 防止订单ABA问题2.1 场景如下:当在修改订单用户信息的时候发生服务器或者网络问题导致的重试2.2 ABA问题解决方案2.3 业务ABA…

Mac本地安装Mysql并配置

文章目录一、安装Mysql二、配置Mysql三、启动mysql四、SQL语法初步了解1.创建数据库2.建表3.查看表一、安装Mysql 笔者推荐采用安装包的方法安装Mysql,比较简单,适合新手。 首先在网上搜安装包: baidu按关键字搜即可:mysql mac安…

多兴趣向量重构用户向量

Re4: Learning to Re-contrast, Re-attend, Re-construct for Multi-interest Recommendation 论文地址:https://arxiv.org/pdf/2208.08011.pdf 一般的多兴趣建模过程是对用户序列进行编码,抽取出用户的多个兴趣向量,然后利用这些用户兴趣向…

【Vue中使用Echarts】echarts初体验

文章目录一、echarts简介二、初次体验echarts1.下载2.在vue中引入echarts①全局引入(代码)② 局部引入一、echarts简介 在大数据盛行的今天,数据可视化变得越来越广泛。而在前端工作中,数据可视化用得最多的,可能就是…

Usaco Training 刷怪旅 第三层 第四题 :Combination Lock

一个六年级博主写文章不容易,给个关注呗 (点赞也行啊) 本蒟蒻的bilibili账号 注:这种题当你看不懂的时候是可以把题目复制去洛谷看中文版的 Farmer Johns cows keep escaping from his farm and causing mischief. To try and pre…

如何通过 Python 与 ChatGPT 对话

文章目录简介安装 OpenAI API实例1预备条件: 1. 科学上网; 2. 注册 OpenAI 账号。 简介 ChatGPT 是 GPT-3 语言模型的变体,专为会话语言生成而设计。要在 Python 中使用 ChatGPT,您需要安装 OpenAI API 客户端并获取 API 密钥。当前提你需要…

前端工程师leetcode算法面试必备-二分搜索算法(中)

一、前言 二分搜索算法本身并不是特别复杂,核心点主要集中在: 有序数组:指的是一个递增或者递减的区间(特殊情况如:【852. 山脉数组的峰顶索引】); 中间数:用来确定搜索目标落在左…

Pytorch学习笔记①——anaconda和jupyter环境的安装(小白教程)

一、安装Pytorch 1、首先找到anaconda命令端并点击进入。 2、输入如下命令创建子空间(博主的命名是pytorch1.4.0,使用python3.6版本) conda create -n pytorch1.4.0 python3.6对于下载速度慢的话,首先需要进行换源,换…

FastJson不出网rce

BCEL ClassLoader去哪了 0x01 BCEL从哪里来 首先,BCEL究竟是什么?它为什么会出现在JDK中? BCEL的全名应该是Apache Commons BCEL,属于Apache Commons项目下的一个子项目。Apache Commons大家应该不陌生,反序列化最著…