Pytorch入门实战 P08-YOLOv5里面的C3模块实现

news2024/11/24 12:57:00

目录

1、YOLOv5骨干网络模型图:

2、C3模块介绍:

3、C3模块的主要代码:

4、完整的code

5、运行结果展示:

(1)使用SGD优化器

(2)使用Adam优化器


  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

今天这篇博客的主要内容是,实现Yolov5里面的C3模块。

1、YOLOv5骨干网络模型图:

更多详细的v5模型介绍:【目标检测】yolov5模型详解-CSDN博客

从上述YOLOv5的模型图中,我们可以看到,C3模块 主要出现在Backbone模块和Neck模块。

这篇文章,我们主要先来看下YOLOv5里面的C3模块

2、C3模块介绍:

YOLOv5中的C3模块是目标检测算法中的一个关键组件,主要用于特征提取和融合。该模块在YOLOv5的骨干网络中扮演着重要角色,帮助算法更好地理解和分析图像。

具体来说,C3模块的结构较为复杂,它包含了多个Conv模块和一个Bottleneck模块

Conv模块主要负责对输入的特征图进行卷积操作,以提取更高级别的特征。这种卷积操作可以通过任意的卷积核进行,但根据设计,采用1*1的卷积核可以起到降维或升维的作用,对于提取特征有重要意义。

Bottleneck模块是C3模块的另一个重要组成部分,其设计有利于增加网络的感受野,同时减少计算量。感受野的增加可以让网络更加关注物体的全局信息,从而提高特征提取的效果。具体来说,Bottleneck模块包含两个部分:一个(1,1)的卷积,用于将输入特征图的通道数减半;以及一个(3,3)的卷积,用于将通道数翻倍。

此外,C3模块还引入了一些创新性的技术,如在模块中引入自注意力机制,以加强对图像中重要区域的关注。这种机制在处理具有复杂背景和遮挡情况的图像时,能够提升模型对关键特征的提取能力。

总的来说,YOLOv5中的C3模块是一个设计精良的特征提取模块,它通过复杂的结构和先进的技术,提高了目标检测算法的性能和准确性。

我们先来看下C3模块的模型图(如下)。

这次的模型搭建就是C3模块。

3、C3模块的主要代码:

class C3(nn.Module):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)
        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
class Conv(nn.Module):
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self,x):
        return self.act(self.bn(self.conv(x)))
def autopad(k, p=None):  # kernel, padding
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
    return p
class Bottleneck(nn.Module):
    # Standard bottleneck
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
class model_K(nn.Module):
    def __init__(self):
        super(model_K, self).__init__()
        # 卷积模块
        self.Conv = Conv(3, 32, 3, 2)
        # C3模块
        self.C3_1 = C3(32, 64, 3, 2)
        # 全连接网络层,用于分类
        self.classifier = nn.Sequential(
            nn.Linear(in_features=802816, out_features=100),
            nn.ReLU(),
            nn.Linear(in_features=100, out_features=4)
        )

    def forward(self,x):
        x = self.Conv(x)
        x = self.C3_1(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x


model = model_K().to(device)
print(model)

4、完整的code

将一般的网络结构改成C3网络结构,完整代码如下:

import copy
import pathlib
import warnings

import matplotlib.pyplot as plt
import torch
from torch import nn
from torchvision import datasets
from torchvision.transforms import transforms
import matplotlib as mpl
mpl.use('Agg')  # 在服务器上运行的时候,打开注释

'''
    利用v5里面的C3来模块搭建网络
'''

# 1、设备检查
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# 2、导入数据
data_dir = './data'
data_dir = pathlib.Path(data_dir)

data_paths = list(data_dir.glob('*'))
classNames = [str(path).split('/')[1] for path in data_paths]
print(classNames)       # ['cat', 'dog']

# 3、图像预处理
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229,0.224,0.225]
    )
])

test_transforms = transforms.Compose([
    transforms.Resize([224,224]),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229,0.224,0.225]
    )
])

total_data = datasets.ImageFolder('./data', transform=train_transforms)

# 4、划分数据集
train_size = int(0.8*len(total_data))
test_size = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data,[train_size,test_size])
print(train_size, test_size)   # 2720 680

# 5、加载数据
batch_size = 4
train_dl = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

for X,y in test_dl:
    print('Shape of X [N,C,H,W]:', X.shape)   # torch.Size([4, 3, 224, 224])
    print('Shape of y:',y.shape, y.dtype)
    break

# (二)C3模块的模型搭建


def autopad(k, p=None):  # kernel, padding
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
    return p


class Conv(nn.Module):
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self,x):
        return self.act(self.bn(self.conv(x)))


class Bottleneck(nn.Module):
    # Standard bottleneck
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))


class C3(nn.Module):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)
        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))


class model_K(nn.Module):
    def __init__(self):
        super(model_K, self).__init__()
        # 卷积模块
        self.Conv = Conv(3, 32, 3, 2)
        # C3模块
        self.C3_1 = C3(32, 64, 3, 2)
        # 全连接网络层,用于分类
        self.classifier = nn.Sequential(
            nn.Linear(in_features=802816, out_features=100),
            nn.ReLU(),
            nn.Linear(in_features=100, out_features=4)
        )

    def forward(self,x):
        x = self.Conv(x)
        x = self.C3_1(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x


model = model_K().to(device)
print(model)


# (三)、编写训练函数
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)

    train_loss, train_acc = 0, 0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # 计算预测误差
        pred = model(X)
        loss = loss_fn(pred, y)

        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 记录acc与loss
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_acc /= size
    train_loss /= num_batches

    return train_acc, train_loss


# 编写测试函数
# 测试函数和训练函数大致相同,但是由于不进行梯度下降对网络权重进行更新,所以不需要传入优化器
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, test_acc = 0, 0

    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for imgs, target  in dataloader:
            imgs, target = imgs.to(device), target.to(device)

            # 计算loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred, target)

            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc /= size
    test_loss /= num_batches

    return test_acc, test_loss

# 正式训练
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()  # 创建损失函数

epochs = 20
train_loss = []
train_acc = []
test_loss = []
test_acc = []

best_acc = 0  # 设置一个最佳准确率,作为最佳模型的判别指标
for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)

    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    # 保存最佳模型到best_model
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        best_model = copy.deepcopy(model)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    # 获取当前的学习率
    lr = optimizer.state_dict()['param_groups'][0]['lr']
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%, Test_loss:{:.3f}')
    print(template.format(epoch+1, epoch_train_acc, epoch_train_loss,epoch_test_acc,epoch_test_loss,lr))

# 保存最佳模型到文件中
PATH = './best_model.pth'  # 保存的参数文件名
torch.save(model.state_dict(), PATH)
print('Done')


# (四)、结果可视化
warnings.filterwarnings('ignore')
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100   # 分辨率

epochs_range = range(epochs)

plt.figure(figsize=(12,3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label="Training Loss")
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')

plt.savefig("/data/jupyter/deep_demo/p08_v5-C3/resultImg.jpg")  # 保存图片在服务器的位置
plt.show()

# (五)、模型评估
best_model.eval()

epoch_test_acc, epoch_test_loss = test(test_dl, best_model, loss_fn)
print(epoch_test_acc, epoch_test_loss)

5、运行结果展示:

(1)使用SGD优化器

(2)使用Adam优化器

6、总结C3模块

C3模块主要由三个Conv模块一个Bottleneck模块组成。

Conv模块负责对输入的特征图进行卷积操作,提取图像中的特征。

Bottleneck模块则进一步对特征进行处理,增加网络的感受野并减少计算量。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1634899.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2024年必应bing广告推广开户有什么条件?

必应Bing作为全球领先的搜索引擎之一,其广告平台正为无数企业开辟着新的市场蓝海。如果您正寻求在必应Bing上投放广告,提升品牌影响力和市场份额,那么了解开户条件并找到一位可靠的合作伙伴至关重要。云衔科技,作为数字营销领域的…

Jetson Orin NX L4T35.5.0平台LT6911芯片 调试记录(2)vi discarding frame问题调试

基于上篇调试记录 Jetson Orin NX L4T35.5.0平台LT6911芯片 调试记录(1)MIPI问题调试-CSDN博客 1.前言 当通过gstreamer持续捕获视频设备时,帧数会下降,并且I输入越高,丢失的帧数越多。 当达到4k30hz时,它完全无法使用,系统会在几秒钟的收集后崩溃并重新启动 4k30hz …

使用yolo识别模型对比两张图片并标记不同(2)

上篇文章有漏洞,在这里补充下,比如要识别第二张图相对于第一张图的违建是否拆除了 第一步旋转对其后,图片会有黑色的掩码,如果旋转角度大的话,没识别出来的框可能不是已经拆除了,而是因为黑色掩码遮挡&…

Laravel 框架请求生命周期

Laravel 框架请求的生命周期 目录 请求图示 说明 ① ② ③ ④ ⑤ ⑥ ⑦ ⑧ 请求图示 说明 ① 所有的请求都是经Web 服务器(Apache/Nginx)配置引导到Laravel 应用的入口public/index.php文件。index.php 加载框架其它部分。 如下图&#xff…

基于FPGA的数字信号处理(3)--什么是浮点数?

科学计数法 你可能不了解「浮点数」&#xff0c;但你一定了解「科学记数法」。 10进制科学记数法把一个数表示成a与10的n次幂相乘的形式&#xff08;1≤|a|<10&#xff0c;a不为分数形式&#xff0c;n为整数&#xff09;&#xff0c;例如&#xff1a; 19970000000000 1.9…

关系(五)利用python绘制连接散点图

关系&#xff08;五&#xff09;利用python绘制连接散点图 连接散点图&#xff08;Connected Scatterplot&#xff09;简介 连接散点图&#xff08;点线图&#xff09;是折线图的一种&#xff0c;与散点图类似。但添加了按数据点出现顺序的连线&#xff0c;以此来表示两个变量…

币圈Cryptosquare论坛

Cryptosquare综合性资讯论坛汇集了币圈新闻、空投信息、社会热点以及与Web3相关的工作信息。让我们一起解锁加密世界的种种可能性&#xff0c;探索Cryptosquare论坛带来的精彩&#xff01; 币圈新闻板块&#xff1a; Cryptosquare论坛的币圈新闻板块是用户获取最新加密货币行业…

vite打包配置

目录 minify默认是esbuild&#xff0c;不能启动下面配置 使用&#xff1a; plugins: [viteMockServe({mockPath: mock})]根目录新建mock/index.ts. 有例子Mock file examples&#xff1a;https://www.npmjs.com/package/vite-plugin-mock-server 开发环境生产环境地址替换。根…

Matlab|含sop的33节点配电网优化

目录 1 主要内容 2 部分代码 3 程序结果 4 下载链接 1 主要内容 程序以IEEE33节点为例&#xff0c;分析含sop的配电网优化&#xff0c;包括sop有功约束、无功约束和容量约束&#xff0c;非线性部分通过转换为旋转锥约束进行编程&#xff0c;并且包括33节点配电网潮流及对应…

python自动化操作docx

使用Python自动化处理Word文档 在日常工作中&#xff0c;我们经常需要处理大量的Word文档&#xff0c;这时自动化脚本就显得尤为重要。本文将介绍如何使用Python中的python-docx库来创建和修改Word文档。 安装python-docx库 在开始之前&#xff0c;确保你已经安装了python-d…

基于JWT实现的Token认证方案

JSON Web Token是什么&#xff1f; JSON Web Token&#xff08;JWT&#xff09;是目前最流行的跨域身份验证解决方案。 JSON Web Token&#xff08;JWT&#xff09;是一个开放标准&#xff08;RFC 7519&#xff09;&#xff0c;它定义了一种紧凑且自包含的方式&#xff0c;用…

电脑文件误删除如何恢复?这5个策略亲测有效!

“求助&#xff01;在电脑上不小心删除了文件还有机会找回来吗&#xff1f;一不小心我就删除了一个重要的工作文件&#xff01;大家快帮帮我吧&#xff01;” 保存在电脑里的文件对电脑用户来说很多都是非常重要的&#xff0c;我们可能生活中、学习上以及工作上都需要使用这些文…

C++学习第七课:控制程序流程的学习和示例详解

C学习第七课&#xff1a;控制程序流程 在C中&#xff0c;控制程序流程是编程逻辑的核心部分&#xff0c;它决定了程序的执行顺序。本课我们将介绍C中的各种控制流程语句&#xff0c;包括条件语句、循环语句以及如何使用它们遍历多维数组和计算斐波那契数列。 控制流程语句 i…

哪个牌子的骨传导耳机好用?盘点五款高热度爆款骨传导耳机推荐!

近年来&#xff0c;骨传导耳机在潮流的推动下销量节节攀升&#xff0c;逐渐成为运动爱好者和音乐迷们的必备装备。但热度增长的同时也带来了一些品质上的忧患&#xff0c;目前市面上的部分产品&#xff0c;存在佩戴不舒适、音质不佳等问题&#xff0c;甚至可能对听力造成潜在损…

VSCode SSH连接远程主机失败,显示Server status check failed - waiting and retrying

vscode ssh连接远程主机突然连接不上了&#xff0c;终端中显示&#xff1a;Server status check failed - waiting and retrying 但是我用Xshell都可以连接成功&#xff0c;所以不是远程主机的问题&#xff0c;问题出在本地vscode&#xff1b; 现象一&#xff1a; 不停地输入…

Python俄罗斯方块

文章目录 游戏实现思路1. 游戏元素的定义2. 游戏区域和状态的定义3. 游戏逻辑的实现4. 游戏界面的绘制5. 游戏事件的处理6. 游戏循环7. 完整实现代码 游戏实现思路 这个游戏的实现思路主要分为以下几个步骤&#xff1a; 1. 游戏元素的定义 Brick类&#xff1a;表示游戏中的砖…

使用Tortoise 创建远程分支

1。首先创建本地分支branch1&#xff0c;右键tortoise git->创建分支&#xff0c;输入分支名称branch1&#xff0c;确定。 2。右键tortoise git->推送&#xff0c;按下图设置&#xff0c;确定&#xff0c;git会判断远程有没有分支branch1&#xff0c;如果没有会自动创建…

QT类之间主窗口子窗口传递*指针对象

1.新建CFile_Operation 类文件 2.主窗口头文件声明&#xff1a; CFile_Operation *cfile_operation; 按钮点击事件函数里面调用子窗口 dialog_debug new Dialog_Debug(this);connect(this,&MainWindow_oq::SendCfile_operation_Obj,dialog_debug,&Dialog_Debug::R…

【redis】初始redis和分布式系统的基本知识

˃͈꒵˂͈꒱ write in front ꒰˃͈꒵˂͈꒱ ʕ̯•͡˔•̯᷅ʔ大家好&#xff0c;我是xiaoxie.希望你看完之后,有不足之处请多多谅解&#xff0c;让我们一起共同进步૮₍❀ᴗ͈ . ᴗ͈ აxiaoxieʕ̯•͡˔•̯᷅ʔ—CSDN博客 本文由xiaoxieʕ̯•͡˔•̯᷅ʔ 原创 CSDN 如…

Linux中ssh登录协议

目录 一.ssh基础 1.ssh协议介绍 2.ssh协议的优点 3.ssh文件位置 二.ssh原理 1.公钥传输原理&#xff08;首次连接&#xff09; 2.ssh加密通讯原理 &#xff08;1&#xff09;对称加密 &#xff08;2&#xff09;非对称加密 3.远程登录 三.服务端的配置 常用的配置项…