yolov5网络初始化问题

news2024/9/23 11:20:06

当你打印detect层的三个特征层时,发现有三种不同的长和宽,如下图所示:
我提出三个问题:
为什么不一样呢,输入有什么含义吗?
为什么网络初始化四次(forward)?
下面来逐个击破

在这里插入图片描述

1. torch.Size([1, 3, 32, 32, 8]) (这个数据为detect层输出的最大宽度特征层)

第一层调用:train.py

model = Model(cfg or ckpt["model"].yaml, ch=3, nc=nc, anchors=hyp.get("anchors")).to(device)  # create

第二层调用:

在yolo/DetectionModel里面定义的,是一个固定的输入,为[1,3,256,256]卷积完之后就如上。

使用256这个参数主要是因为①最大stride的倍数(8,16 ,32,64…),②这个数降采样之后的值真好,不会造成资源的浪费。

主要是用来网络初始化的,创建网络的

        if isinstance(m, (Detect, Segment)):
            s = 256  # 2x min stride   256
            m.inplace = self.inplace
            car_detect=[0,0,0,0]
            forward = lambda x: self.forward(x)[0] if isinstance(m, Segment) else self.forward(x)
            _,rs=forward(torch.zeros(1, ch, s, s))  #forward
            m.stride = torch.tensor([s / x.shape[-2] for x in rs[0]])  # forward torch.Size([1, 3, 32, 32, 8])
            # if m.stride==torch.tensor([]):
            #     m.stride = torch.tensor([8, 16, 32])
            check_anchor_order(m)
            m.anchors /= m.stride.view(-1, 1, 1)
            self.stride = m.stride
            self._initialize_biases()  # only run once

2. torch.Size([1, 3, 4, 4, 8])

第一层调用:train.py

model = Model(cfg or ckpt["model"].yaml, ch=3, nc=nc, anchors=hyp.get("anchors")).to(device)  # create

第二层:

还是在yolo/DetectionModel里面实现的。

        # Init weights, biases
        initialize_weights(self)
        self.info()  # 第二遍 计算层数,参数,梯度等 YOLOv5s summary: 245 layers, 8091510 parameters, 8091510 gradients, 16.8 GFLOPs 
        LOGGER.info("")

主要是self.info()这个函数。

其中im是输入,是[1,3,32,32]卷积出来第一个卷积层也是上面的。

为什么是32呢,这个是因为①是最大stride,降采样使能成功 ②为什么不使用其他32的倍数,因为这个是最小计算量,确保网络能够正确处理图像的前提。

主要是来计算网络的参数的,如层数,参数,计算量等。

这个flops计算量是这个模型的最快执行时间。

        p = next(model.parameters())  #  获取第一个模型的参数:32,3,6,6
        stride = max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32  # max stride  压缩程度
        # torch.empty创建任意数据类型的张量  torch.tensor() 只创建torch.FloatTensor类型的张量
        # 使用32是因为①是最大stride,降采样使能成功 ②为什么不使用其他32的倍数,因为这个是最小计算量,确保网络能够正确处理图像的前提
        im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
        # 浮点运算次数,可以用来衡量算法/模型复杂度 1GFLOPs = 10^9 FLOPs
        # 计算量(时间复杂度,flops) 与输入参数有关系 网络执行时间的长短
        # 参数量(空间复杂度,params)占用显存的大小 只与网络有关系
        # 这个地方除以2 是因为加法(偏置)可能没有算进去,所以初一二让他接近真实值,flops值越大越好
        flops = thop.profile(deepcopy(model), inputs=(im,), verbose=False)[0] / 1e9 * 2  # stride GFLOPs thop.profile计算flops,verbose是日志显示
        imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz]  # expand if int/float
        fs = f", {flops * imgsz[0] / stride * imgsz[1] / stride:.1f} GFLOPs"  # 640x640 GFLOPs  计算真实图片的flops,使用最大stride就是为了简化计算,作为一个标准,

3. torch.Size([1, 3, 80, 60, 8])

第一步调用:

是在train中调用的,想要统计是否使用AMP(自动混合精度)

amp = check_amp(model)  # check AMP  第三次  计算是否使用amp自动混合精度(torch16和torch32)

第二步调用:

下面会调用Autoshape,im就是引用的data/imges/bus.jpg的一张yolo自带的图,进行初始化的。im进行resize后的shape是[1,3,640,480]。

主要是想用一张图片,然后用两种方式FP32 inferenceAMP inference进行推理,然后计算相似度,大于阈值,就是用AMP。

为什么使用AutoShape类,首先这个对输入包容性很大,无论是file还是uri或者numpy,torch等其他类型都可以进行统一预测,输出结果。

            n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims])  # number, list of images
            shape0, shape1, files = [], [], []  # image and inference shapes, filenames
            for i, im in enumerate(ims):
                f = f"image{i}"  # filename
                if isinstance(im, (str, Path)):  # filename or uri
                    im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith("http") else im), im
                    im = np.asarray(exif_transpose(im))
                elif isinstance(im, Image.Image):  # PIL Image
                    im, f = np.asarray(exif_transpose(im)), getattr(im, "filename", f) or f
                files.append(Path(f).with_suffix(".jpg").name)
                if im.shape[0] < 5:  # image in CHW
                    im = im.transpose((1, 2, 0))  # reverse dataloader .transpose(2, 0, 1)
                im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)  # enforce 3ch input
                s = im.shape[:2]  # HWC
                shape0.append(s)  # image shape
                g = max(size) / max(s)  # gain
                shape1.append([int(y * g) for y in s])
                ims[i] = im if im.data.contiguous else np.ascontiguousarray(im)  # update
            shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)]  # inf shape  640,480
            x = [letterbox(im, shape1, auto=False)[0] for im in ims]  # pad
            x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2)))  # stack and BHWC to BCHW
            x = torch.from_numpy(x).to(p.device).type_as(p) / 255  # uint8 to fp16/32

        with amp.autocast(autocast):
            # Inference
            with dt[1]:
                y = self.model(x, augment=augment)  # forward

总结

第几次调用forward输入尺寸作用
第一次调用torch.Size([1, 3, 256, 256])主要用于创建网络,计算stride的值
第二次调用torch.Size([1, 3,32, 32 ])主要用于计算网络参数的,如层数,参数,计算量等
第三次调用torch.Size([1, 3, 640, 480])主要是确认是否使用amp

注:此处的数据建立在stride的最大值为32的

专栏指路:

YOLOv5评价指标:yolov5 评价指标_yolov5评价指标-CSDN博客

YOLOv5网络结构:yolov5 网络结构_yolov5头部网络-CSDN博客

YOLOv5主要流程:yolov5 主要流程_yolov5网络流程-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2050732.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LeetCode - LCR 146- 螺旋遍历二维数组

LCR 146题 题目描述&#xff1a; 给定一个二维数组 array&#xff0c;请返回「螺旋遍历」该数组的结果。 螺旋遍历&#xff1a;从左上角开始&#xff0c;按照 向右、向下、向左、向上 的顺序 依次 提取元素&#xff0c;然后再进入内部一层重复相同的步骤&#xff0c;直到提取完…

模型训练坎坷路--逐步提升模型准确率从40%到90%+

文章目录 〇、前言一、更改学习率1.原理&#xff1a;欠拟合需要减小学习率2.效果-->有用&#xff01; 二、更改训练批次batch_size1.原理&#xff1a;更大的批量大小时&#xff0c;梯度估计更加精确2.效果-->有点用 三、更改数据预处理方式1.原理&#xff1a;数据可能没有…

【微服务】springboot 整合表达式计算引擎 Aviator 使用详解

目录 一、前言 二、表达式计算框架概述 2.1 规则引擎 2.1.1 什么是规则引擎 2.1.2 规则引擎用途 2.1.3 规则引擎使用场景 2.2 表达式计算框架 2.2.1 表达式计算框架定义 2.2.2 表达式计算框架特点 2.2.3 表达式计算框架应用场景 2.3 表达式计算框架与规则引擎异同点 …

二叉树练习习题集一(Java)

1. 思路&#xff1a; 就是让左孩子和右孩子进行交换&#xff0c;这里需要一个中间变量用来记录&#xff0c;然后完成交换。如果进行优化则添加当左孩子和右孩子都为null时直接返回。 class Solution {public TreeNode invertTree(TreeNode root) {TreeNode tmpnull;//用来进行…

C++适配windows和linux下网络编程TCP简单案例

C网络编程 网络协议是计算机网络中通信双方必须遵循的一套规则和约定&#xff0c;用于实现数据的传输、处理和控制。这些规则包括了数据格式、数据交换顺序、数据处理方式、错误检测和纠正等。网络协议是使不同类型的计算机和网络设备能够相互通信的基础&#xff0c;是网络通信…

PDF转markdown工具:magic-pdf

1. magic-pdf 环境安装 conda create -n MinerU python3.10 conda activate MinerU pip install boto3>1.28.43 -i https://pypi.tuna.tsinghua.edu.cn/simple/ pip install magic-pdf[full]0.7.0b1 --extra-index-url https://wheels.myhloli.com -i https://pypi.tuna.t…

SSA-SVM多变量回归预测|樽海鞘群优化算法-支持向量机|Matalb

目录 一、程序及算法内容介绍&#xff1a; 基本内容&#xff1a; 亮点与优势&#xff1a; 二、实际运行效果&#xff1a; 三、算法介绍&#xff1a; 四、完整程序下载&#xff1a; 一、程序及算法内容介绍&#xff1a; 基本内容&#xff1a; 本代码基于Matlab平台编译&a…

Chrome浏览器更改默认User-Agent

一、业务需求 二、插件下载 三、插件使用 原创文章&#xff0c;请勿转载&#xff01; 详细教程教你如何更改默认浏览器的User-Agent&#xff0c;几分钟足以&#xff01; 一、业务需求 当我们遇到一些特定的UA才能访问的网址时&#xff0c;我们就可以通过一些手段来修改我们浏…

Python之字符串练习题(下)

21.nameStr“Albert Einstein"&#xff0c;如何使用字符串运算符“:”来提取 nameStr 中的名和姓? mingnameStr[:6] xingnameStr[7:]23.下面哪些语句在运行时不会出错? (a)var xyz ’ * 10.5 (b)var ‘xyz’ * ‘5 ©var‘’xyz’*5 (d)var‘xyz’*5.0 重复运算符…

HTML静态网页成品作业(HTML+CSS)——美食企业介绍设计制作(1个页面)

&#x1f389;不定期分享源码&#xff0c;关注不丢失哦 文章目录 一、作品介绍二、作品演示三、代码目录四、网站代码HTML部分代码 五、源码获取 一、作品介绍 &#x1f3f7;️本套采用HTMLCSS&#xff0c;未使用Javacsript代码&#xff0c;共有1个页面。 二、作品演示 三、代…

在亚马逊云科技上对Stable Diffusion模型提示词、输出图像内容进行安全审核

项目简介&#xff1a; 小李哥将继续每天介绍一个基于亚马逊云科技AWS云计算平台的全球前沿AI技术解决方案&#xff0c;帮助大家快速了解国际上最热门的云计算平台亚马逊云科技AWS AI最佳实践&#xff0c;并应用到自己的日常工作里。 本次介绍的是如何在亚马逊云科技机器学习托…

HighPoint SSD7749M2:128TB NVMe 存储卡实现28 GB/s高速传输

HighPoint Technologies推出了一款全新的SSD7749M2 RAID卡&#xff0c;能够在标准的桌面工作站中安装多达16个M.2 SSD&#xff0c;实现高达128TB的闪存存储。该卡通过PCIe Gen4 x16接口提供高达28 GB/s的顺序读写性能。这些令人瞩目的性能规格伴随着高昂的价格标签。 #### 技术…

ArcGIS Pro基础:设置快速访问工具栏

上图【红色框线】内显示就是快速访问工具栏&#xff0c;访问非常方便&#xff0c;不需要切换到选项卡了 上图显示&#xff0c;可以勾选或者取消进行设置&#xff0c;通过【更多命令】可以选择更多的工具 如上图所示&#xff0c;可以选择自己经常使用的命令&#xff0c;可以输入…

手撕线程池

1.手撕线程池原理图 2.代码实现 // 手撕线程池 public class Main {public static void main(String[] args) {ThreadPool threadPool new ThreadPool(1,1000,TimeUnit.MILLISECONDS,1,(queue, task) -> {queue.putByTime(task,1500,TimeUnit.MILLISECONDS);});for (int i…

LangChain 实战演练:借助 LangChain SQL Agent 与 GPT 实现文档智能分析及交互

LangChain实战&#xff1a;利用LangChain SQL Agent和GPT进行文档分析和交互 我最近接触到一个非常有趣的挑战&#xff0c;涉及到人工智能数字化大量文件的能力&#xff0c;并使用户可以在这些文件上提出复杂的与数据相关的问题&#xff0c;比如&#xff1a; 数据检索问题&…

【qt】基于tcp的消息发送

我们需要实现客户端发消息&#xff0c;服务端接收消息 服务端界面新增接收消息 实现客户端发送和清空 发送数据需要将发送栏的信息转化为QByteArray,然后使用socket的write发送过去 实现服务端的接收 效果演示 20240818_111603 代码展示 server Widget.h #ifndef WIDGET_H …

Java的File类与IO流

目录 1. java.io.File类的使用 1.1 概述 1.2 构造器 1.3 常用方法 1、获取文件和目录基本信息 2、列出目录的下一级 3、File类的重命名功能 4、判断功能的方法 5、创建、删除功能 1.4 练习 2. IO流原理及流的分类 2.1 Java IO原理 2.2 流的分类 2.3 流的API 3. …

如何在 Windows/Mac/在线/iPhone/Android 上将 PDF 转换为 Word

PDF&#xff08;便携式文档格式&#xff09;是一种流行的格式&#xff0c;广泛用于在数字电子设备中呈现文档。输出文件小且兼容性强&#xff0c;使 PDF 如此受欢迎。但是&#xff0c;编辑 PDF 文件并非免费。您无需购买 PDF 编辑器&#xff0c;而是可以将 PDF 转换为 Word 进行…

「OC」NSPredicate —— 使用谓词过滤元素

「OC」NSPredicate —— 使用谓词过滤元素 文章目录 「OC」NSPredicate —— 使用谓词过滤元素前言介绍常见用法**比较运算符****逻辑运算符****字符串比较运算符****聚合运算符****用于字典或者类当中****格式说明符&#xff08;占位符&#xff09;** 实际运用总结参考文章 前…

05创建型设计模式——原型模式

一、原型模式简介 原型模式&#xff08;Prototype Pattern&#xff09;模式是一种对象创建型模式&#xff0c;它采取复制原型对象的方法来创建对象的实例。使用原型模式创建的实例&#xff0c;具有与原型一样的数据。 1&#xff09;由原型对象自身创建目标对象。换句话说&…