CV05_深度学习模块之间的缝合教学(1)

news2024/11/13 9:23:26

1.1 在哪里缝

测试文件?(×)

训练文件?(×)

模型文件?(√)

1.2 骨干网络与模块缝合

以Vision Transformer为例,模型文件里有很多类,我们只在最后集大成的那个类里添加模块。

之后后,我们准备好我们要缝合的模块,比如SE Net模块,我们先建立一个测试文件测试能否跑通

import numpy as np
import torch
from torch import nn
from torch.nn import init

class SEAttention(nn.Module):
    # 初始化SE模块,channel为通道数,reduction为降维比率
    def __init__(self, channel=512, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # 自适应平均池化层,将特征图的空间维度压缩为1x1
        self.fc = nn.Sequential(  # 定义两个全连接层作为激励操作,通过降维和升维调整通道重要性
            nn.Linear(channel, channel // reduction, bias=False),  # 降维,减少参数数量和计算量
            nn.ReLU(inplace=True),  # ReLU激活函数,引入非线性
            nn.Linear(channel // reduction, channel, bias=False),  # 升维,恢复到原始通道数
            nn.Sigmoid()  # Sigmoid激活函数,输出每个通道的重要性系数
        )

    # 权重初始化方法
    def init_weights(self):
        for m in self.modules():  # 遍历模块中的所有子模块
            if isinstance(m, nn.Conv2d):  # 对于卷积层
                init.kaiming_normal_(m.weight, mode='fan_out')  # 使用Kaiming初始化方法初始化权重
                if m.bias is not None:
                    init.constant_(m.bias, 0)  # 如果有偏置项,则初始化为0
            elif isinstance(m, nn.BatchNorm2d):  # 对于批归一化层
                init.constant_(m.weight, 1)  # 权重初始化为1
                init.constant_(m.bias, 0)  # 偏置初始化为0
            elif isinstance(m, nn.Linear):  # 对于全连接层
                init.normal_(m.weight, std=0.001)  # 权重使用正态分布初始化
                if m.bias is not None:
                    init.constant_(m.bias, 0)  # 偏置初始化为0

    # 前向传播方法
    def forward(self, x):
        b, c, _, _ = x.size()  # 获取输入x的批量大小b和通道数c
        y = self.avg_pool(x).view(b, c)  # 通过自适应平均池化层后,调整形状以匹配全连接层的输入
        y = self.fc(y).view(b, c, 1, 1)  # 通过全连接层计算通道重要性,调整形状以匹配原始特征图的形状
        return x * y.expand_as(x)  # 将通道重要性系数应用到原始特征图上,进行特征重新校准

# 示例使用
if __name__ == '__main__':
    input = torch.randn(50, 512, 7, 7)  # 随机生成一个输入特征图
    se = SEAttention(channel=512, reduction=8)  # 实例化SE模块,设置降维比率为8
    output = se(input)  # 将输入特征图通过SE模块进行处理
    print(output.shape)  # 打印处理后的特征图形状,验证SE模块的作用

打印处理后的形状,我们这里要注意,缝合模块时只需要注意第一维,也就是这个channel,要和骨干网络保持一致,只要你把输入输出的通道数对齐,那么这个通道数就可以缝合成功。

把模块复制进骨干网络中:

然后进行缝合,在缝合之前要先测试通道是否匹配,不然肯定报错。

如何验证通道数

我们找到骨干网络前向传播的部分,在你想加入这个模块地方print(x.shape)即可。运行训练文件:

放在最前面:

通道数为3(8为batch size)。

将模块添加进骨干网络

在骨干网络的init函数下添加:(ctrl+p可查看参数)通道数与之前查的对齐。

在前向传播中添加:

看看是否正常运行:

正常运行,说明模块缝合成功!

打印缝合后的模型结构

该操作在模型文件中进行。

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (2): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (3): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (4): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (5): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (6): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (7): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (8): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (9): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (10): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (11): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (pre_logits): Sequential(
    (fc): Linear(in_features=768, out_features=768, bias=True)
    (act): Tanh()
  )
  (head): Linear(in_features=768, out_features=21843, bias=True)
  (se): SEAttention(
    (avg_pool): AdaptiveAvgPool2d(output_size=1)
    (fc): Sequential(
      (0): Linear(in_features=3, out_features=0, bias=False)
      (1): ReLU(inplace=True)
      (2): Linear(in_features=0, out_features=3, bias=False)
      (3): Sigmoid()
    )
  )
)

我们可以看到多了一个SEAttention,说明模块缝合进去了!

1.3 模块之间缝合

以SENet和ECA模块为例。

串联模块

方式1

同1.2。照猫画虎。(注意通道数保持一致)

打印模型结构:

ECAAttention(
  (gap): AdaptiveAvgPool2d(output_size=1)
  (conv): Conv1d(1, 1, kernel_size=(3,), stride=(1,), padding=(1,))
  (sigmoid): Sigmoid()
  (se): SEAttention(
    (avg_pool): AdaptiveAvgPool2d(output_size=1)
    (fc): Sequential(
      (0): Linear(in_features=64, out_features=4, bias=False)
      (1): ReLU(inplace=True)
      (2): Linear(in_features=4, out_features=64, bias=False)
      (3): Sigmoid()
   )))

 方式2

我们定义一个串联函数,将模块之间串联起来:

实例化查看一下模型结构

输出结果:

torch.Size([1, 63, 64, 64]) torch.Size([1, 63, 64, 64])
Cascade(
  (se): SEAttention(
    (avg_pool): AdaptiveAvgPool2d(output_size=1)
    (fc): Sequential(
      (0): Linear(in_features=63, out_features=3, bias=False)
      (1): ReLU(inplace=True)
      (2): Linear(in_features=3, out_features=63, bias=False)
      (3): Sigmoid()
    )
  )
  (eca): ECAAttention(
    (gap): AdaptiveAvgPool2d(output_size=1)
    (conv): Conv1d(1, 1, kernel_size=(63,), stride=(1,), padding=(31,))
    (sigmoid): Sigmoid()
  )
)

并联模块

对于并联模块,方法有很多种,两个两个模块输出的张量可以:

(1)逐元素相加(2)逐元素相乘(3)concat拼接(4)等等

输出结果:

torch.Size([1, 63, 64, 64]) torch.Size([1, 126, 64, 64])
Cascade(
  (se): SEAttention(
    (avg_pool): AdaptiveAvgPool2d(output_size=1)
    (fc): Sequential(
      (0): Linear(in_features=63, out_features=3, bias=False)
      (1): ReLU(inplace=True)
      (2): Linear(in_features=3, out_features=63, bias=False)
      (3): Sigmoid()
    )
  )
  (eca): ECAAttention(
    (gap): AdaptiveAvgPool2d(output_size=1)
    (conv): Conv1d(1, 1, kernel_size=(63,), stride=(1,), padding=(31,))
    (sigmoid): Sigmoid()
  )
)

1.4 思考 

我们不要拘泥于只串联获并联,可以将二者结合,多个模块中,部分模块并联后又与其他模块串联,等等。。这种排列组合之后,总会有一个你想要的模型!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1917909.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Pytorch】Conda环境下载慢换源/删源/恢复默认源

文章目录 背景临时换源永久换源打开conda配置condarc换源执行配置 命令行修改源添加源查看源 删源恢复默认源使用示范 背景 随着实验增多,需要分割创建环境的情况时有出现,在此情况下使用conda create --name xx python3.10 pytorch torchvision pytorc…

maven——(重要)手动创建,构建项目

创建项目 手动按照maven层级建好文件夹,并写上java,测试代码和pom文件 构建项目 在dos窗口中执行如下命令 compile编译 当前maven仓库中什么都没有。 在pom所在层级下,执行: mvn compile 就开始显示下面这些,…

整数的英语表示

题目链接 整数的英语表示 题目描述 注意点 0 < num < 2^31 - 1 解答思路 每三个数字形成一组&#xff08;高位不足的部分可以用0填充&#xff09;&#xff0c;使用StringBuilder拼接每组的数字和单位关键是三个数字的英语表示&#xff0c;包含个位、十位、百位&…

java入门-告别C进入java世界

目标 java体系 java开发环境 helloworld java语法 java体系 java开发环境 安装JDK JDK&#xff1a; Java Developement Kit 配置jdk 为什么需要配置 操作系统找不到此程序 操作系统PATH PATH C:\Users\49354>echo %PATH% C:\Program Files (x86)\VMware\VMware Works…

Java---SpringBoot详解一

人性本善亦本恶&#xff0c; 喜怒哀乐显真情。 寒冬暖夏皆有道&#xff0c; 善恶终归一念间。 善念慈悲天下广&#xff0c; 恶行自缚梦难安。 人心如镜自省照&#xff0c; 善恶分明照乾坤。 目录 一&#xff0c;入门程序 ①&#xff0c;创建springboot工程&#…

安泰电压放大器的选型方案是什么

电压放大器是一种常见的电路元件&#xff0c;广泛应用于各种电子设备中。在选择电压放大器的时候&#xff0c;我们需要考虑一系列因素&#xff0c;以确保选型方案能够满足实际需求。下面安泰电子将详细介绍电压放大器选型的主要考虑因素&#xff0c;包括应用需求、技术性能、成…

CSDN回顾与前行:我的创作纪念日——2048天的技术成长与感悟

CSDN回顾与前行&#xff1a;我的创作纪念日——2048天的技术成长与感悟 &#x1f496;The Begin&#x1f496;点点关注&#xff0c;收藏不迷路&#x1f496; 前言 时光荏苒&#xff0c;岁月如梭。转眼间&#xff0c;从我在CSDN上写下第一篇技术博客《2-6 带头结点的链式表操作…

《WebGIS快速开发教程》第7版发布

老规矩先看封面&#xff1a; 可以看到我们在封面上加了“classic”的字样&#xff0c;这意味着第7版将会是经典版本&#xff0c;或者说具有里程碑意义的一个版本。 拿到新书我们可以看到第7版的整体风格是以“业务场景”为核心&#xff0c;所有讲解的知识点和案例都是围绕着业…

一文清晰了解CSS——简单实例

首先一个小技巧&#xff1a; 一定要学会的vsCode格式化整理代码的快捷键&#xff0c;再也不用手动调格式了-腾讯云开发者社区-腾讯云 (tencent.com) CSS选择器用于选择要应用样式的HTML元素。常见的选择器包括&#xff1a; 类选择器&#xff1a;以.开头&#xff0c;用于选择具…

STM32CubeMX 下载及安装教程

1. 什么是 STM32CubeMX? STM32CubeMX 是一款图形化的配置工具&#xff0c;用于配置 STM32 系列微控制器的硬件外设、时钟系统以及中间件组件。它提供了一种可视化的方式来设置硬件功能&#xff0c;并生成相应的初始化代码以帮助开发者快速启动项目。 2. 主要功能 2.1 图形化…

ISO三体系认证:助力企业迈向卓越管理

在激烈的市场竞争中&#xff0c;企业不仅需要优质的产品和服务&#xff0c;还需要科学高效的管理体系。ISO三体系认证&#xff0c;包括ISO 9001质量管理体系认证、ISO 14001环境管理体系认证和ISO 45001职业健康与安全管理体系认证&#xff0c;为企业提供了系统化的管理框架。这…

docker-compose安装rocketmq

创建挂载目录 mkdir -p /home/docker/rocketmq/rocketmq_server/logs mkdir -p /home/docker/rocketmq/rocketmq_broker/logs mkdir -p /home/docker/rocketmq/rocketmq_broker/store mkdir -p /home/docker/rocketmq/rocketmq_broker/conf#创建配置文件broker.conf cd /home/…

StarRocks 集群管理又添“猛将“ ,随配随用随时修改

前言 在存储业务数据时&#xff0c;StarRocks 存算分离支持使用各种外部独立存储系统。 在早期 3.0 版本中&#xff0c;用户需要在 fe.conf 中配置存储相关信息&#xff08;如 endpoint 等&#xff09;&#xff0c;这种静态配置模式也给用户使用带来了很多的不便性。 为此&a…

echarts中tooltip添加点击事件代码示例

echarts中tooltip添加点击事件代码示例_javascript技巧_脚本之家 点击事件无法使用this 或者 this无法使用&#xff1a;

【python】随机森林预测汽车销售

目录 引言 1. 数据收集与预处理 2. 划分数据集 3. 构建随机森林模型 4. 模型训练 5. 模型评估 6. 模型调优 数据集 代码及结果 独热编码 随机森林模型训练 特征重要性图 混淆矩阵 ROC曲线 引言 随机森林&#xff08;Random Forest&#xff09;是一种集成学习方法…

springboot餐饮管理系统-计算机毕业设计源码43667

摘 要 在信息化、数字化的时代背景下&#xff0c;餐饮行业面临着前所未有的挑战与机遇。为了提高运营效率、优化顾客体验&#xff0c;餐饮企业亟需一套高效、稳定且灵活的管理系统来支撑其日常运营。基于Spring Boot的餐饮管理系统应运而生&#xff0c;成为餐饮行业数字化转型的…

高仿imtoken钱包源码/获取助记词/获取私钥/自动归集

简介&#xff1a; 高仿imtoken钱包/获取助记词/获取私钥/自动归集 带双端&#xff0c;无纯源码 下载源码

企业网站源码系统 自主快速搭建响应式网站 海量模版随心选择 带完整的源代码包以及搭建教程

系统概述 企业网站源码系统&#xff0c;是一款专为中小企业量身定制的网站建设解决方案。该系统基于先进的Web开发技术&#xff0c;融合了模块化设计理念和用户友好的操作界面&#xff0c;旨在帮助企业用户无需编程基础&#xff0c;即可轻松搭建出符合自身需求的响应式网站。通…

太恐怖了,30秒录音,就能复刻你的声音

最近出的这一款AI文本转语音工具&#xff0c;太恐怖了&#xff01; 只需要有你一段录音&#xff0c;就能直接复刻你的声音。 下边是我复刻的李云龙的声音 这个工具复刻声音非常简单 使用步骤&#xff1a; 打开网站后选择构建声音 上传封面 填写名字和描述 上传音频或录制…

本周六!上海场新能源汽车数据基础设施专场 Meetup 来了

本周六下午 14:30 新能源汽车数据基础设施专场 Meetup 在上海&#xff0c;点击链接报名 &#x1f381; 到场有机会获得 Greptime 和 AutoMQ 的精美文创周边哦&#xff5e; &#x1f52e; 会后还有观众问答 & 抽奖环节等你来把神秘礼物带回家&#xff5e; &#x1f9c1; 更…