深度学习Week16-yolo.py文件解读(YOLOv5)

news2025/1/30 16:34:25

目录

 简介

 需要的基础包和配置

 二、主要组件介绍

2.1 parse_model

2.2Detect类

2.3DetectionModel类

三、实验


🍨 本文为[🔗365天深度学习训练营]内部限免文章(版权归 *K同学啊* 所有)
🍖 作者:[K同学啊]

 这周接着详细解析小白YOLOv5全流程-训练+实现数字识别_牛大了2022的博客-CSDN博客_yolov5识别数字,之前入门教大家下载配置环境,如果没有的话请参考这篇的文章深度学习Week11-调用官方权重进行检测(YOLOv5)_yolov5权重_牛大了2022的博客-CSDN博客

📌 本周任务:将YOLOv5s网络模型中的C3模块按照下图方式修改形成C2模块,并将C2模块插入第2层与第3层之间,且跑通YOLOv5s。
●💫 任务提示:
  ○提示1:需要修改common.yaml、yolo.py、yolov5s.yaml文件。
  ○提示2:C2模块与C3模块是非常相似的两个模块,我们要插入C2到模型当中,只需要找到哪里有C3模块,然后在其附近加上C2即可。

 简介

common.py文件位置在./models/yolo.py

 需要的基础包和配置

import argparse
import contextlib
import os
import platform
import sys
from copy import deepcopy
from pathlib import Path
 
FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
if platform.system() != 'Windows':
    ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
 
from models.common import *
from models.experimental import *
from utils.autoanchor import check_anchor_order
from utils.general import LOGGER, check_version, check_yaml, make_divisible, print_args
from utils.plots import feature_visualization
from utils.torch_utils import (fuse_conv_and_bn, initialize_weights, model_info, profile, scale_img, select_device,
                               time_sync)
 
try:
    import thop  # for FLOPs computation
except ImportError:
    thop = None

 二、主要组件介绍

2.1 parse_model

用于将模型的模块拼接起来,搭建完成的网络模型。不是固定的,改框架也要改这个函数。

def parse_model(d, ch):  # model_dict, input_channels(3)
    # Parse a YOLOv5 model.yaml dictionary
    LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10}  {'module':<40}{'arguments':<30}")
    anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
    if act:
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
        LOGGER.info(f"{colorstr('activation:')} {act}")  # print
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            with contextlib.suppress(NameError):
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings

        n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in {
                Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
                BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}:
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)

            args = [c1, c2, *args[1:]]
            if m in {BottleneckCSP, C3, C3TR, C3Ghost, C3x}:
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        # TODO: channel, gw, gd
        elif m in {Detect, Segment}:
            args.append([ch[x] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
            if m is Segment:
                args[3] = make_divisible(args[3] * gw, 8)
        elif m is Contract:
            c2 = ch[f] * args[0] ** 2
        elif m is Expand:
            c2 = ch[f] // args[0] ** 2
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
        LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f}  {t:<40}{str(args):<30}')  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)

2.2Detect类

构建Detect层的,将输入feature map 通过一个卷积操作和公式计算到我们想要的shape,为后面的计算损失或者NMS作准备。

class Detect(nn.Module):
    # YOLOv5 Detect head for detection models
    stride = None  # strides computed during build
    dynamic = False  # force grid reconstruction
    export = False  # export mode
 
    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
        super().__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.empty(0) for _ in range(self.nl)]  # init grid
        self.anchor_grid = [torch.empty(0) for _ in range(self.nl)]  # init anchor grid
        self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2))  # shape(nl,na,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
        self.inplace = inplace  # use inplace ops (e.g. slice assignment)
 
    def forward(self, x):
        z = []  # inference output
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
 
            if not self.training:  # inference
                if self.dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
 
                if isinstance(self, Segment):  # (boxes + masks)
                    xy, wh, conf, mask = x[i].split((2, 2, self.nc + 1, self.no - self.nc - 5), 4)
                    xy = (xy.sigmoid() * 2 + self.grid[i]) * self.stride[i]  # xy
                    wh = (wh.sigmoid() * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, conf.sigmoid(), mask), 4)
                else:  # Detect (boxes only)
                    xy, wh, conf = x[i].sigmoid().split((2, 2, self.nc + 1), 4)
                    xy = (xy * 2 + self.grid[i]) * self.stride[i]  # xy
                    wh = (wh * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, conf), 4)
                z.append(y.view(bs, self.na * nx * ny, self.no))
 
        return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)
 
    def _make_grid(self, nx=20, ny=20, i=0, torch_1_10=check_version(torch.__version__, '1.10.0')):
        d = self.anchors[i].device
        t = self.anchors[i].dtype
        shape = 1, self.na, ny, nx, 2  # grid shape
        y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
        yv, xv = torch.meshgrid(y, x, indexing='ij') if torch_1_10 else torch.meshgrid(y, x)  # torch>=0.7 compatibility
        grid = torch.stack((xv, yv), 2).expand(shape) - 0.5  # add grid offset, i.e. y = 2.0 * x - 0.5
        anchor_grid = (self.anchors[i] * self.stride[i]).view((1, self.na, 1, 1, 2)).expand(shape)
        return grid, anchor_grid

2.3DetectionModel类

模型的构建工场,指定模型的yaml文件,以及一系列的训练参数

class DetectionModel(BaseModel):
    # YOLOv5 detection model
    def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None):  # model, input channels, number of classes
        super().__init__()
        if isinstance(cfg, dict):
            self.yaml = cfg  # model dict
        else:  # is *.yaml
            import yaml  # for torch hub
            self.yaml_file = Path(cfg).name
            with open(cfg, encoding='ascii', errors='ignore') as f:
                self.yaml = yaml.safe_load(f)  # model dict
 
        # Define model
        ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels
        if nc and nc != self.yaml['nc']:
            LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
            self.yaml['nc'] = nc  # override yaml value
        if anchors:
            LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}')
            self.yaml['anchors'] = round(anchors)  # override yaml value
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist
        self.names = [str(i) for i in range(self.yaml['nc'])]  # default names
        self.inplace = self.yaml.get('inplace', True)
 
        # Build strides, anchors
        m = self.model[-1]  # Detect()
        if isinstance(m, (Detect, Segment)):
            s = 256  # 2x min stride
            m.inplace = self.inplace
            forward = lambda x: self.forward(x)[0] if isinstance(m, Segment) else self.forward(x)
            m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))])  # forward
            check_anchor_order(m)
            m.anchors /= m.stride.view(-1, 1, 1)
            self.stride = m.stride
            self._initialize_biases()  # only run once
 
        # Init weights, biases
        initialize_weights(self)
        self.info()
        LOGGER.info('')
 
    def forward(self, x, augment=False, profile=False, visualize=False):
        if augment:
            return self._forward_augment(x)  # augmented inference, None
        return self._forward_once(x, profile, visualize)  # single-scale inference, train
 
    def _forward_augment(self, x):
        img_size = x.shape[-2:]  # height, width
        s = [1, 0.83, 0.67]  # scales
        f = [None, 3, None]  # flips (2-ud, 3-lr)
        y = []  # outputs
        for si, fi in zip(s, f):
            xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
            yi = self._forward_once(xi)[0]  # forward
            # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
            yi = self._descale_pred(yi, fi, si, img_size)
            y.append(yi)
        y = self._clip_augmented(y)  # clip augmented tails
        return torch.cat(y, 1), None  # augmented inference, train
 
    def _descale_pred(self, p, flips, scale, img_size):
        # de-scale predictions following augmented inference (inverse operation)
        if self.inplace:
            p[..., :4] /= scale  # de-scale
            if flips == 2:
                p[..., 1] = img_size[0] - p[..., 1]  # de-flip ud
            elif flips == 3:
                p[..., 0] = img_size[1] - p[..., 0]  # de-flip lr
        else:
            x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale  # de-scale
            if flips == 2:
                y = img_size[0] - y  # de-flip ud
            elif flips == 3:
                x = img_size[1] - x  # de-flip lr
            p = torch.cat((x, y, wh, p[..., 4:]), -1)
        return p
 
    def _clip_augmented(self, y):
        # Clip YOLOv5 augmented inference tails
        nl = self.model[-1].nl  # number of detection layers (P3-P5)
        g = sum(4 ** x for x in range(nl))  # grid points
        e = 1  # exclude layer count
        i = (y[0].shape[1] // g) * sum(4 ** x for x in range(e))  # indices
        y[0] = y[0][:, :-i]  # large
        i = (y[-1].shape[1] // g) * sum(4 ** (nl - 1 - x) for x in range(e))  # indices
        y[-1] = y[-1][:, i:]  # small
        return y
 
    def _initialize_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
        # https://arxiv.org/abs/1708.02002 section 3.3
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
        m = self.model[-1]  # Detect() module
        for mi, s in zip(m.m, m.stride):  # from
            b = mi.bias.view(m.na, -1)  # conv.bias(255) to (3,85)
            b.data[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
            b.data[:, 5:5 + m.nc] += math.log(0.6 / (m.nc - 0.99999)) if cf is None else torch.log(cf / cf.sum())  # cls
            mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

三、实验

C2 模块结构图

 任务:如图所示,将C3模块改为C2模块,并将C2模块插入到第二层与第三层之间,并跑通代码。

common.py中 

 yolov5s.yaml

 yolo.py

注意添加修改会涉及多处地方。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/363282.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JavaEE简单示例——动态SQL的<trim>属性

简单介绍&#xff1a; 在之前我们介绍过使用<where>和<set>可以帮我们动态的添加和删除一些关键字&#xff0c;但是这些只能操作特定的关键字&#xff0c;比如where和set&#xff0c;但是有一些时候我们需要操作的关键字并不是这些常见的关键字&#xff0c;而是一…

基于SSM的婴幼儿商城

基于SSM的婴幼儿商城 ✌全网粉丝20W,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取项目下载方式&#x1f345; 一、项目背景介绍&#xff1a; …

软件测试3年经验就能拿30K?

1.软件测试如何实现涨薪 首先涨薪并不是从8000涨到9000这种涨薪&#xff0c;而是从8000涨到15K加到25K的涨薪。基本上三年之内就可以实现。 如果我们只是普通的有应届毕业生或者是普通本科那我们就只能从小公司开始慢慢往上走。 有些同学想去做测试&#xff0c;是希望能够日…

springboot+vue员工宿舍报修系统 uniapp微信小程序开发的

目 录 目 录 III 第一章 概述 1 1.1 研究背景 1 1.2 开发意义 1 1.3 研究现状 1 1.4 研究内容 2 1.5 论文结构 2 第二章 开发技术介绍 1 2.2 微信开发者工具 1 2.3 mysql数据库介绍 1 2.4 MySQL环境配置 2 2.5 B/S架构 2 第三章 系统分析 1 3.1 可行性分析 1 3.1.1 技术可行性…

这次,我的CentOS又ping不通www.baidu.com了(gateway配置)

当我们保证了宿主机与虚拟机的ip地址在同一网段&#xff0c;并且我们使用虚拟机ping宿主机&#xff0c;与宿主机ping虚拟机都可以互相ping通的情况下虚拟机却ping不通外网了&#xff0c;由于涉及到了跨越网络访问&#xff0c;所以我们应该把问题聚焦在网关的配置上&#xff01;…

手工布署 java 项目

新建一个java springboot项目 maven 这是一个非常简易的 springBoot 的项目 使用 maven 的 package 工具进行打包 把包上传到 linux 的机器上&#xff0c; 确保 linux 机器上安装了 java jdk工具&#xff0c; 并且配置好了 JAVA_HOME 注意&#xff0c;helloworld 默认的是要使…

Rocky 9.1操作系统实现zabbix6.0的安装部署实战

文章目录前言一. 实验环境二. 安装zabbix过程2.1. 安装zabbix源2.2 安装zabbix相关的软件2.3 安装数据库并启动2.4 开始初始化数据库&#xff1a;2.5 创建数据库实例及对应的用户2.6 导入官网提供的数据2.7 配置zabbix 服务的配置文件2.8. 启动服务2.9 从网页进行安装2.10 登陆…

H5盲盒抽奖系统源码

盲盒抽奖系统4.0&#xff0c;带推广二维码防洪炮灰功能和教程。 支持微信无限回调登录 标价就是源码价格&#xff0c;vuetp5框架编写&#xff0c;H5网页&#xff0c;前后端分离 此源码为正规开发&#xff0c;正版产品已申请软著。 开源无加密无授权&#xff0c;可以二开使用…

网络工程师必备知识点

作为网络工程师&#xff0c;您将负责设计、部署和维护计算机网络系统。这包括构建、配置和管理网络设备&#xff0c;如交换机、路由器、防火墙等&#xff0c;并确保网络系统能够高效地运行。您需要了解计算机网络的各个层次、协议、标准和技术&#xff0c;包括TCP/IP、DNS、HTT…

东京大学最新研究成果!一种可实现陆空两栖的新型四足机器人SPIDAR,具备多模态运动能力

原创/文 BFT机器人 现实中&#xff0c;蜘蛛可以凭借飘荡的蛛丝在空中漂浮&#xff0c;让它们能够穿越复杂地形。普通蜘蛛长度只有几毫米&#xff0c;重量只有几十克&#xff0c;如何让比蜘蛛重数百倍的机器人实现多模态运动&#xff0c;是众多学者研究的热点。 具有多模态运动…

分布式链路追踪-skywalking

一、分布式调用链随着业务的高速发展&#xff0c;服务之间的调用关系愈加复杂线上每一个请求会经过多个业务系统&#xff0c;并产生对各种缓存或者DB 的访问&#xff0c;业务流会经过很多个微服务的处理和传递。问题&#xff1a;• —次请求的流量从哪个服务而来&#xff1f;最…

ChatGPT这是要抢走我的饭碗?我10年硬件设计都有点慌了

前 言 呃……问个事儿&#xff0c;听说ChatGPT能写电路设计方案了&#xff0c;能取代初级工程师了&#xff1f;那我这工程师的岗位还保得住么&#xff1f;心慌的不行&#xff0c;于是赶紧打开ChatGPT问问它。 嘿&#xff0c;还整的挺客气&#xff0c;快来看看我的职业生涯是否…

图扑孪生工厂流水线组态图可视化

前言 2018 年&#xff0c;世界经济论坛(WEF)携手麦肯锡公司共同倡议并正式启动了全球“灯塔工厂网络项目”(Lighthouse Network)&#xff0c;共同遴选率先应用工业革命 4.0 技术实现企业盈利和持续发展的创新者与示范者。这就使得工厂系统需要对各流水线及生产运行成本方面进行…

非关系型数据库(mongodb)简单使用介绍

关系型数据库与非关系型数据库 关系型数据库有mysql、oracle、db2、sql server等&#xff1b; 关系型数据库特点&#xff1a;关系紧密&#xff0c;由表组成&#xff1b; 优点&#xff1a; 易于维护&#xff0c;都是使用表结构&#xff0c;格式一致&#xff1b; sql语法通用&a…

资源消耗降低 90%,速度提升 50%,解读 Apache Doris Compaction 最新优化与实现

背景LSM-Tree&#xff08; Log Structured-Merge Tree&#xff09;是数据库中最为常见的存储结构之一&#xff0c;其核心思想在于充分发挥磁盘连续读写的性能优势、以短时间的内存与 IO 的开销换取最大的写入性能&#xff0c;数据以 Append-only 的方式写入 Memtable、达到阈值…

Linux基础命令-lsof查看进程打开的文件

Linux基础命令-uptime查看系统负载 Linux基础命令-top实时显示系统状态 Linux基础命令-ps查看进程状态 文件目录 前言 一 命令的介绍 二 语法及参数 2.1 使用help查看命令的语法信息 2.2 常用参数 2.2.lsof命令-i参数的条件 三 命令显示内容的含义 3.1 FD 文件描述符的…

手工测试1年经验面试,张口要13K,我真是服了····

由于朋友临时有事&#xff0c; 所以今天我代替朋友进行一次面试&#xff0c;他需要应聘一个测试工程师&#xff0c; 我以很认真负责的态度完成这个过程&#xff0c; 大概近30分钟。 主要是技术面试&#xff0c; 在近30分钟内&#xff0c; 我与被面试者是以交流学习的方式进行的…

自动驾驶路径规划概况

文章目录前言介绍1. 路径规划在自动驾驶系统架构中的位置2. 全局路径规划的分类2.1 基础图搜索算法2.1.1 Dijkstra算法2.1.2 双向搜索算法2.1.3 Floyd算法2.2 启发式算法2.2.1 A*算法2.2.2 D*算法2.3 基于概率采样的算法2.3.1 概率路线图&#xff08;PRM&#xff09;2.3.2 快速…

docker启动容器报错No chain/target/match by that name.

一、问题描述 docker启动容器时提示: docker start xxx-search Error response from daemon: driver failed programming external connectivity on endpoint microblog-search (801478f2672887ee0fcf60eb7d7970703b4853f44f51b0b5b8622dafdb9580fb): (iptables failed: iptab…

【C ++】C++入门知识(二)

C入门&#xff08;二&#xff09; 作者&#xff1a;小卢 专栏&#xff1a;《C》 喜欢的话&#xff1a;世间因为少年的挺身而出&#xff0c;而更加瑰丽。 ——《人民日报》 1.引用 1.1.引用的概念及应用 引用&#xff08;&&#xff09; 引用不是新定义一个变量&#xff0…