Pytorch从零开始实战16

news2025/1/10 12:15:12

Pytorch从零开始实战——ResNeXt-50算法的思考

本系列来源于365天深度学习训练营

原作者K同学

对于上次ResNeXt-50算法,我们同样有基于TensorFlow的实现。具体代码如下。

引入头文件

import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Dense, Dropout, Conv2D, MaxPool2D, Flatten, GlobalAvgPool2D, concatenate, \
BatchNormalization, Activation, Add, ZeroPadding2D, Lambda
from tensorflow.keras.layers import ReLU
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.models import Model

分组卷积模块

# 定义分组卷积
def grouped_convolution_block(init_x, strides, groups, g_channels):
    group_list = []
    # 分组进行卷积
    for c in range(groups):
        # 分组取出数据
        x = Lambda(lambda x: x[:, :, :, c * g_channels:(c + 1) * g_channels])(init_x)
        # 分组进行卷积
        x = Conv2D(filters=g_channels, kernel_size=(3, 3),strides=strides, padding='same', use_bias=False)(x)
        # 存入list
        group_list.append(x)
    # 合并list中的数据
    group_merage = concatenate(group_list, axis=3)
    x = BatchNormalization(epsilon=1.001e-5)(group_merage)
    x = ReLU()(x)
    return x

残差单元

# 定义残差单元
def block(x, filters, strides=1, groups=32, conv_shortcut=True):

    if conv_shortcut:
        shortcut = Conv2D(filters * 2, kernel_size=(1, 1), strides=strides, padding='same', use_bias=False)(x)
        # epsilon为BN公式中防止分母为零的值
        shortcut = BatchNormalization(epsilon=1.001e-5)(shortcut)
    else:
        # identity_shortcut
        shortcut = x
        
    # 三层卷积层
    x = Conv2D(filters=filters, kernel_size=(1, 1), strides=1, padding='same', use_bias=False)(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = ReLU()(x)
    # 计算每组的通道数
    g_channels = int(filters / groups)
    # 进行分组卷积
    x = grouped_convolution_block(x, strides, groups, g_channels)

    x = Conv2D(filters=filters * 2, kernel_size=(1, 1), strides=1, padding='same', use_bias=False)(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = Add()([x, shortcut])
    x = ReLU()(x)
    return x

堆叠残差单元

# 堆叠残差单元
def stack(x, filters, blocks, strides, groups=32):
    # 每个stack的第一个block的残差连接都需要使用1*1卷积升维
    x = block(x, filters, strides=strides, groups=groups)
    for i in range(blocks):
        x = block(x, filters, groups=groups, conv_shortcut=False)
    return x

网络搭建

# 定义ResNext50(32*4d)网络
def ResNext50(input_shape, num_classes):
    inputs = Input(shape=input_shape)
    # 填充3圈0,[224,224,3]->[230,230,3]
    x = ZeroPadding2D((3, 3))(inputs)
    x = Conv2D(filters=64, kernel_size=(7, 7), strides=2, padding='valid')(x)
    x = BatchNormalization(epsilon=1.001e-5)(x)
    x = ReLU()(x)
    # 填充1圈0
    x = ZeroPadding2D((1, 1))(x)
    x = MaxPool2D(pool_size=(3, 3), strides=2, padding='valid')(x)
    # 堆叠残差结构
    x = stack(x, filters=128, blocks=2, strides=1)
    x = stack(x, filters=256, blocks=3, strides=2)
    x = stack(x, filters=512, blocks=5, strides=2)
    x = stack(x, filters=1024, blocks=2, strides=2)
    # 根据特征图大小进行全局平均池化
    x = GlobalAvgPool2D()(x)
    x = Dense(num_classes, activation='softmax')(x)
    # 定义模型
    model = Model(inputs=inputs, outputs=x)
    return model

对于残差单元中的代码,提出一个问题:当conv_shortcut=False的时候,在执行Add操作时,理论上通道数不一致,为什么代码不报错?
在这里插入图片描述
答:这主要是跟下面堆叠残差单元的代码有关系,每个stack第一轮总会令conv_shortcut为True,使得x通道数进行扩展,而后面循环的时候传入的filters还是这个函数的实参,没有发生变化,但由于conv_shortcut为False,此时shortcut的通道数是与上面的x一致,所以在Add的时候,代码不会报错。

def stack(x, filters, blocks, strides, groups=32):
    # 每个stack的第一个block的残差连接都需要使用1*1卷积升维
    x = block(x, filters, strides=strides, groups=groups)
    for i in range(blocks):
        x = block(x, filters, groups=groups, conv_shortcut=False)
    return x

本文只是对ResNeXt-50算法的部分代码进行思考,学习过程中需要积极思考与探索,以提高能力和解决问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1372039.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

传统 VC 机构,是否还能在 Fair launch 的散户牛市中胜出?

LaunchPad 是代币面向市场的重要一环,将代币推向市场,加密项目将能够通过代币的销售从市场上募集资金,同时生态也开始进入全新的发展阶段。而对于投资者来说,早期打新市场同样充满着机会,参与 LaunchPad 对于每一个投资…

通过iFrame嵌入Grafana页面或pannel

前言 在当前数据驱动的时代,有效地可视化和监控关键性能指标变得至关重要。Grafana,作为一个开源的监控解决方案,提供了强大的功能来呈现和分析数据,从而帮助用户及时洞察和响应各种情况。随着技术的不断发展,将这些信…

TOWE 65W智能快充插线板为家庭用电保驾护航

随着家用电器在人们生活中的普及,在给人们带来便利的同时,其安全性也日益引起人们的重视。延长线插座作为固定插座的补充,因其可移动、携带方便、灵活性强而成为家居和办公中不可或缺的用品。一般普通的塑料插排插座受额定电流和额定功率的影…

Unity 了解Input Manage下默认的输入轴

在Unity菜单Edit->Project Settings->Input Manager->Axes下有一些默认的输入轴,如 这些输入轴代表不同类型的输入,其中: Horizontal:水平移动输入轴。通常与键盘的左右箭头键、A和D键、游戏手柄的左摇杆水平轴等相关联…

经典目标检测YOLO系列(二)YOLOv2算法详解

经典目标检测YOLO系列(二)YOLOv2算法详解 YOLO-V1以完全端到端的模式实现达到实时水平的目标检测。但是,YOLO-V1为追求速度而牺牲了部分检测精度,在检测速度广受赞誉的同时,其检测精度也饱受诟病。正是由于这个原因,YOLO团队在20…

从零学Java 集合概述

Java 集合概述 文章目录 Java 集合概述1 什么是集合?2 Collection体系集合2.1 Collection父接口2.1.1 常用方法2.1.2 Iterator 接口 1 什么是集合? 概念:对象的容器,定义了对多个对象进行操作的常用方法;可实现数组的功能。 和数组区别&…

基于深度学习的果蔬检测识别系统(含UI界面、yolov5、Python代码、数据集)

项目介绍 项目中所用到的算法模型和数据集等信息如下: 算法模型:     yolov5 yolov5主要包含以下几种创新:         1. 添加注意力机制(SE、CBAM、CA等)         2. 修改可变形卷积(DySnake-主…

Python如何使用Excel文件

使用Python操作Office——EXCEL 首先介绍下office win32 com接口,这个是MS为自动化提供的操作接口,比如我们打开一个EXCEL文档,就可以在里面编辑VB脚本,实现我们自己的效果。对于这种一本万利的买卖,Python怎么能放过…

VLAN原理与配置

0x00 前言 本节主要记录VLAN相关的内容。 传统以太网的缺点 广播域越大,产生的网络安全问题,垃圾流量问题越严重。 什么是VLAN? Virtual Local Area NetWork 虚拟局域网技术。 VLAN的特点是什么 一个VLAN就是一个广播域,在…

web前端案例之抽奖

使用HTMLJavascript完成抽奖案例 <!DOCTYPE html> <html><head><meta charset"utf-8"><title></title><style>*{margin: 0;padding: 0;}</style></head><body><div id"container" onclic…

2023年全国职业院校技能大赛(高职组)“云计算应用”赛项赛卷5

某企业根据自身业务需求&#xff0c;实施数字化转型&#xff0c;规划和建设数字化平台&#xff0c;平台聚焦“DevOps开发运维一体化”和“数据驱动产品开发”&#xff0c;拟采用开源OpenStack搭建企业内部私有云平台&#xff0c;开源Kubernetes搭建云原生服务平台&#xff0c;选…

设计与实现基于Java+MySQL的模拟银行ATM操作系统

课题背景 随着现代经济的发展&#xff0c;电子支付和自动化银行服务已成为人们生活中不可或缺的一部分。自动取款机&#xff08;ATM&#xff09;作为一种常见的自助服务设备&#xff0c;使用户能够方便地进行资金的存取、查询余额、转账等操作&#xff0c;而无需到银行柜台。 …

动态内存管理的题目

数组串联 在leetcode上找的一题 &#xff1a; 给你一个长度为 n 的整数数组 nums 。请你构建一个长度为 2n 的答案数组 ans &#xff0c;数组下标 从 0 开始计数 &#xff0c;对于所有 0 < i < n 的 i &#xff0c;满足下述所有要求&#xff1a; ans[i] nums[i]ans[i…

微软Visual C++编程进阶——一维数组(画画版)

我是荔园微风&#xff0c;作为一名在IT界整整25年的老兵&#xff0c;看到不少初学者在学习编程语言的过程中如此的痛苦&#xff0c;我决定做点什么&#xff0c;我小时候喜欢看小人书&#xff08;连环画&#xff09;&#xff0c;在那个没有电视、没有手机的年代&#xff0c;这是…

使用request测试get请求 操作流程

第一步 谷歌f12或者其他抓包工具&#xff0c;抓包获取接口url&#xff1a; https://g-api.csdn.net/community/toolbar-api/v2/favorites-list 第二步 导包 import requests 第三步 调用请求并打印结果 url"https://g-api.csdn.net/community/toolbar-api/v2/favorites-l…

MVC+Layui 多选下拉框xmSelect

1、选择layui拓展第三方组件找到xmselect xmSelect下拉多选 xmSelect - Layui 第三方扩展组件平台 (layuion.com) 下载后放到项目文件中 2、项目引用js文件 <script src"~/Content/dist/xm-select.js"></script> 3、html添加表单设置id <div class…

2024年【高压电工】模拟考试及高压电工模拟考试题

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 2024年高压电工模拟考试为正在备考高压电工操作证的学员准备的理论考试专题&#xff0c;每个月更新的高压电工模拟考试题祝您顺利通过高压电工考试。 1、【单选题】 使用验电器验电前,除检查其外观、电压等级、试验合…

软件测试|MySQL中的GROUP BY分组查询,你会了吗?

MySQL中的GROUP BY分组查询&#xff1a;详解与示例 在MySQL数据库中&#xff0c;GROUP BY语句用于将数据按照指定的列进行分组&#xff0c;并对每个分组执行聚合函数操作。这就是的我们可以在查询中汇总数据并生成有意义的结果。本文将深入介绍MySQL中的GROUP BY语句&#xff…

Java中多线程二

抢占调度模型 概述&#xff1a;优先让优先级高的线程使用 CPU &#xff0c;如果线程的优先级相同&#xff0c;那么随机会选择一个&#xff0c;优先级高的线程获取的 CPU 时间片相对多一些 Thread 类中一些关于线程的方法 方法简述public final int getPriority()返回此线程的优…

【博士每天一篇文-算法】Graph Structure of Neural Networks

阅读时间&#xff1a;2023-11-12 1 介绍 年份&#xff1a;2020 作者&#xff1a;尤家轩 斯坦福大学 期刊&#xff1a; International Conference on Machine Learning. 引用量&#xff1a;130 论文探讨了神经网络的图结构与其预测性能之间的关系。作者提出了一种新的基于图的…