YOLOV3 Pytorch版本代码解读

news2024/9/22 9:44:55

YOLOV3 Pytorch版本代码解读

代码与coco数据集关注wx公众号JokerTong回复yolov3即可获取
参考视频 YOLO系列算法

文章目录

  • YOLOV3 Pytorch版本代码解读
    • 数据集准备与关键文件说明
    • 前提准备
      • 代码大致流程
      • 需要自行修改代码的部分
    • 项目代码解读
      • 一 数据与标签的读取
      • 二 模型构造
        • convolutional层的构建
        • rout层与shortcut层的构建
        • yolo层的构建
      • 三 前向传播
      • 四 计算损失


数据集准备与关键文件说明

使用经典的coco2014数据集,下载地址点击此处进入官网下载(也可以自行去网上搜索)
下载之后解压到项目对应的文件夹, 如下

下载的数据集imagelabel的版本需要一一对应
在这里插入图片描述
trainvalno5k.txt文件
在这里插入图片描述
5k.txt文件 : 验证集的数据的位置
在这里插入图片描述
PyTorch-YOLOv3\config\yolov3.cfg
网络配置文件
在这里插入图片描述

前提准备

代码大致流程

第一步: 加载配置参数
在这里插入图片描述
第二步: 构造模型
在这里插入图片描述
第三步: 读取数据
在这里插入图片描述

Tips: 以前在小数据集上进行训练的时候我们可以将数据集全部加载到内存中,但是由于coco数据集太大了, 内存放不下, 因此我们使用generator来提供数据, 在训练的过程中才读取数据, 根据模型的需要一个batch一个batch的为其提供数据
在这里插入图片描述

需要自行修改代码的部分

添加训练参数

--data_config config/coco.data  
--pretrained_weights weights/darknet53.conv.74

在这里插入图片描述
coco.data描述了训练数据集所需要的所有信息: 类别, 训练数据, 验证数据
在这里插入图片描述
pretrained_weights 迁移学习的思想, 加载一个预训练模型

修改数据集路径
在这里插入图片描述
在这里插入图片描述

项目代码解读

一 数据与标签的读取

109行打上断点, 观察数据的读取过程
在这里插入图片描述
点击Step Into My Code进入项目中的datasets.py文件, 可以看出, dataset通过getitem一张一张的读取图片
在这里插入图片描述
使用Image.open实际读取图片,统一通道为RGB,并且转换为tensor格式
在这里插入图片描述
使用padding的思想, 把原本的长方形图片padding为正方形
在这里插入图片描述
与读取图像类似, 读取对应的标签, 里面包括了类别, box的信息
在这里插入图片描述
这里要注意 图片的编号与标签的编号应该是对应的, 不然数据与标签不匹配训练结果啥也不是
在这里插入图片描述
读取label文件中的物体框信息
在这里插入图片描述

原始图像经过了padding处理, 因此标签中框的坐标也需要进行padding的操作, 最后转化为网络中需要的x,y,w,h格式
在这里插入图片描述
这一系列操作之后, 得到一个imgtarget, 后面会反复进行batch次, 然后返回模型一个batch的数据
在这里插入图片描述

二 模型构造

重新把断点打在66
在这里插入图片描述

进入models.py, init函数中根据之前的yolov3.cfg文件对网络整体架构进行定义, forward函数规定了网络前向传播的整个过程

在这里插入图片描述
加载配置文件的信息到self.module_defs变量中
在这里插入图片描述

self.hyperparams, self.module_list = create_modules(self.module_defs)  # 逐层定义好网络结构

调用create_modules来进行模型的构建
在这里插入图片描述
循环增加网络层

在这里插入图片描述

convolutional层的构建

注意这里的convolutional其实可能包括了卷积, BN, 以及Relu三种操作, 先将里面的参数信息读取出来

在这里插入图片描述
**根据参数信息 添加Conv2d 也就是卷积层 **在这里插入图片描述
如果有bn层以及激活函数, 将它们加入到网络结构当中
在这里插入图片描述
这里在控制台打印一下当前构造的modules, 其中包含了我们想要的三种层
在这里插入图片描述
也就是将之前构造完成的第一个模块加入到大的module_list当中, 并且记录下filters的输出个数
在这里插入图片描述

rout层与shortcut层的构建

这里只是创建了个空的层, 具体的操作执行在前向传播中
在这里插入图片描述
rout层主要起到拼接的作用, 起到了特征融合的作用
在这里插入图片描述

shortcut层主要是加法的作用(resnet的思想), 残差连接的功能已经很熟悉了, 这里就不介绍了
配置文件中的-3表示跟
在这里插入图片描述

yolo层的构建

YOLOV3中有三种YOLO层, 它们分别可以检测大中小三种物体

YOLOV3中有三种scale的先验框, 对应着这三种YOLO层, 感受野比较大的YOLO层对应使用大的先验框

这里通过anchor的编号获取实际先验框的大小

在这里插入图片描述
调用YOLOLayer来构建出YOLO层, 具体的前向传播在后面介绍
在这里插入图片描述

三 前向传播

在各个模型的forward处打上断点
首先进入的是Darknetforward
在这里插入图片描述
这里的卷积, 上采样, 和最大池化层的前向传播都非常简单, 根据pytorch内置的函数, 直接x = module(x)就可以了

            if module_def["type"] in ["convolutional", "upsample", "maxpool"]:
                x = module(x)
            elif module_def["type"] == "route":
                x = torch.cat([layer_outputs[int(layer_i)] for layer_i in module_def["layers"].split(",")], 1)
            elif module_def["type"] == "shortcut":
                layer_i = int(module_def["from"])
                x = layer_outputs[-1] + layer_outputs[layer_i]
  • route层的效果是按某个维度做拼接
  • shortcut层的效果是对两层做加法

最重要的是YOLO层

在这里插入图片描述
这里的img_dim可能有多种大小 但都能被32整除, 因为YOLOV3网络会随机的选择一个大小的图片进行训练

对于预测的维度进行调整, 4 * 3 * 15 * 15 * 85

  • 4 表示 batch_size
  • 3 表示 先验框的有三种
  • 15 * 15 为特征图的大小
  • 85 = 80 + 5 表示80个类别 与 x,y,w,h 与 置信度
    在这里插入图片描述
    取出其中的x,y,w,h,c与每个类别的预测值
    在这里插入图片描述
    通过相对位置得到对应的绝对位置
    在这里插入图片描述
    得到特征图(如当前为15*15)中各个坐标的实际位置
    在这里插入图片描述
    因为标签中的框是在原始图像中的, 所以output要把预测框也放大相应的倍数还原到原始图像中
    在这里插入图片描述

四 计算损失

通过build_targets函数将标签值进行转换, 转换成和预测值相同的格式, 这里可以点进去自己看一下build_targets的内容
在这里插入图片描述
计算loss

  • 对于x,y,w,h来说, 并不是所有的位置都要算一遍, 我们只计算有物体的位置处的损失, 所以这里用了obj_mask作为index
  • loss_conf_obj, loss_conf_noobj计算的是前景和背景的损失, 也就是当前位置是不是物体. 因为这里的预测值和真实值只有01, 所以使用bce_loss即可计算
    他们两个相加, 乘上相应的权重参数就得到了置信度损失loss_conf_noobj
  • loss_cls分类损失的原理也类似

最终将所有的损失相加就得到了总损失
在这里插入图片描述
这里附上一张损失函数的计算图像, 可以看出和代码中的一样, 很好理解
请添加图片描述
后面的操作基本上是通用的pytorch训练模式

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/177209.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据库工具类的编写

package com.bjpowernode.jdbc.utils;import java.sql.*; import java.util.ResourceBundle;/*** 数据库工具类简化JDBC的代码编写。** 在同一个没有结束的程序中,DBUtil类只加载一次,加载一次以后,再次调用该类中的方法,本不会再…

基于Echarts构建大数据招聘岗位数据可视化大屏

🤵‍♂️ 个人主页:艾派森的个人主页 ✍🏻作者简介:Python学习者 🐋 希望大家多多支持,我们一起进步!😄 如果文章对你有帮助的话, 欢迎评论 💬点赞&#x1f4…

SpringBoot 3.0.x使用SpringDoc

为什么使用SpringDoc 在Springfox3.0停更的两年里,SpringBoot进入3.0时代, SpringFox出现越来越多的问题,最为明显的就是解析器的问题,已经在上文 中解释清楚,这里就不再赘述。 SpringDoc是Spring官方推荐的API&#x…

Spring笔记上(基于注解开发)

一、第三方资源配置管理 以DataSource连接池对象为例,进行第三方资源配置管理。 1. 管理DataSource连接池对象 spring整合Druid、C3P0数据库连接池 1.1 管理Druid连接池 1、准备数据 create database if not exists spring_db character set utf8; use spring_db; …

二、python基础语法篇(黑马程序猿-python学习记录)

黑马程序猿的python学习视频:https://www.bilibili.com/video/BV1qW4y1a7fU/ 目录 一 、print 1. end 2. \t对齐 二、字面量 1. 字面量的含义 2. 常见的字面量类型 3. 如何基于print语句完成各类字面量的输出 三、 注释的分类 1. 单行注释 2. 多行注释 3. 注释的…

MXNet的Faster R-CNN(基于区域提议网络的实时目标检测)《9》

MXNet的Faster R-CNN(基于区域提议网络的实时目标检测)《1》:论文源地址,克隆MXNet版本的源码,安装环境与测试,以及对下载的源码的每个目录做什么用的,做个解释。 MXNet的Faster R-CNN(基于区域提议网络的实时目标检测…

【手写 Vue2.x 源码】第四十一篇 - 组件部分 - 生成组件的真实节点

一,前言 上篇,介绍了组件部分-组件的生命周期,主要涉及以下几部分: 本篇,组件部分-生成组件的真实节点; 二,生成组件的真实节点 1,前文回顾 前篇,在 createElement 方…

【剧前爆米花--爪哇岛寻宝】Java中有关异常类的详细讲解

作者:困了电视剧 专栏:《JavaSE语法与底层详解》 文章分布:这是一篇关于Java中异常类的文章,在本篇文章中详细讲解了异常的使用逻辑和底层的执行过程,如有疏漏,欢迎大佬指正! 目录 异常的体系结…

<C++>哈希

文章目录1. unordered系列容器1.1 unordered_map1.1.1 unordered_map的文档介绍1.1.2 unordered_map的接口说明1.2 unordered_set2. 哈希概念3. 哈希冲突4. 哈希函数5. 哈希冲突解决5.1 闭散列5.1.1 线性探测5.1.2 二次探测5.2 开散列5.3 开散列与闭散列比较6. 模拟实现1. unor…

配置野火霸道V2环境

配置野火霸道V2环境野火霸道开发板学习笔记信息说明下载安装Keil5配置Keil以使用DAP下载器DAP下载器的使用使用串口下载程序安装USB转串口驱动CH340检查是否安装成功配置MCUISP软件野火霸道开发板学习笔记 信息说明 日期 : 2023-01-23开发板: 野火霸道V2芯片型号: STM32F103Z…

[Paper Reading] Towards Conversational Recommendation over Multi-Type Dialogs

[Paper Reading] Towards Conversational Recommendation over Multi-Type Dialogs 文章目录[Paper Reading] Towards Conversational Recommendation over Multi-Type Dialogs论文简介快速回顾论文(借助scispace)梳理一下文章内容(参考百度N…

自动化将Gitee的仓库导入Github

自动化将Gitee的仓库导入Github准备工作获取方式gitee的授权码github授权码工具源码用法下载gitee所有仓库到本地下载并更新到github(自动创建仓库)写在最后本方法能实现自动创建仓库 脚本及用法放在文章最后了,需要的自取 转跳到结尾 准备工…

高性能 Java 框架。Solon v1.12.3 发布(春节前兮的最后更)

一个更现代感的 Java "生态型"应用开发框架:更快、更小、更自由。不是 Spring,没有 Servlet,也无关 JavaEE;新兴独立的开放生态 (已有150来个生态插件) 。主框架仅 0.1 MB。 相对于 Spring Boot…

计算正整数的阶乘math.factorial()

【小白从小学Python、C、Java】【计算机等级考试500强双证书】【Python-数据分析】计算正整数的阶乘math.factorial()[太阳]选择题请问math.factorial(3)的输出结果是?import mathprint("【执行】math.factorial(3):",math.factorial(3))print("【执行】math.f…

带你玩转Jetson之Deepstream简明教程(二)Deepstream是什么?干什么?有什么优势?

1.Deepstream是什么? Deepstream是Nvidia公司推出的一套基于开源视频流框架Gstreamer的一套库。其本身由多个.lib.so和.h构成,其支持语言包括了Python和Cpp两种主流语言。你可以在任何Python或者Cpp编译器、开发环境中引用库的API构建属于你自己的推理流…

【c++之于c的优化】

目录:前言关键字一、命名空间1.什么是命名空间2.如何使用命名空间3.如何自己创建命名空间4.为什么要使用命名空间5.命名空间起别名6.匿名命名空间二、缺省参数定义缺省参数类型注意事项三、函数重载定义函数重载的三种方式操作系统的区分方式四、引用定义引用特性使…

【4-网络八股扩展】北京大学TensorFlow2.0

课程地址:【北京大学】Tensorflow2.0_哔哩哔哩_bilibiliPython3.7和TensorFlow2.1六讲:神经网络计算:神经网络的计算过程,搭建第一个神经网络模型神经网络优化:神经网络的优化方法,掌握学习率、激活函数、损…

【LeetCode每日一题】【2023/1/24】1828. 统计一个圆中点的数目

文章目录1828. 统计一个圆中点的数目方法1:枚举1828. 统计一个圆中点的数目 LeetCode: 1828. 统计一个圆中点的数目 中等\color{#FFB800}{中等}中等 给你一个数组 points ,其中 points[i] [x_i, y_i] ,表示第 i 个点在二维平面上的坐标。多…

【算法面试】队列算法笔试面试全解(金三银四面试专栏启动)

📫作者简介:小明java问道之路,专注于研究 Java/ Liunx内核/ C及汇编/计算机底层原理/源码,就职于大型金融公司后端高级工程师,擅长交易领域的高安全/可用/并发/性能的架构设计与演进、系统优化与稳定性建设。 &#x1…

02_gpio子系统

总结 驱动程序还想控制gpio 可以不用读写寄存器 直觉用gpio子系统开发的接口就能用了 轻松做输入输出 获取当前值 详细介绍 用设备树里的节点 gpio1 介绍 imx6ull.dtsi gpio1 记录了控制器相关的寄存器基地址 gpio1: gpio209c000 {compatible "fsl,imx6ul-gpio"…