Pytorch搭建GoogleNet神经网络

news2025/1/24 17:31:01

一、创建卷积模板文件

因为每次使用卷积层都需要调用Con2d和relu激活函数,每次都调用非常麻烦,就将他们打包在一起写成一个类。

in_channels:输入矩阵深度作为参数输入

out_channels: 输出矩阵深度作为参数输入

经过卷积层和relu激活函数之后通过正向传播得到输出。

class BasicConv2d(nn.Module):
    def __init__(self, in_chaannels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_chaannels, out_channels, **kwargs)
        self.conv = nn.ReLU(inplace=True)
     # 定义正向传播  
    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x

二、定义Inception结构 

 def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):

in_channels:输入特征矩阵的参数

ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj:表格中对应需要的参数。

branch1、2、3、4 分别对应inception结构中的四个分支,第一层只有一个分支,第二、三、四层都有1x1卷积核起到降维的作用。

所有定义的参数都来自于GoogleNet神经网络的参数表格。

详见GoogleNet神经网络介绍-CSDN博客

class Inception(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(Inception, self).__init__()
        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)
        
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)  # padding=1 保证输出大小等于输入大小
        )

        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1)
        )
        
    def forward(self, x):  # 将输出的特征矩阵分别输出到branch1234中
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)
        # 将输出放入列表中
        outputs = [branch1, branch2, branch3, branch4]
        # 对四个输出特征矩阵在channel维度进行合并。“1”指需要合并的维度
        return torch.cat(outputs, 1)

三、辅助分类器

 def __init__(self, in_channels, num_classes):表示输入特征矩阵的个数和要分类的类别数。

辅助分类器1对应参数:N x 512 x 14 x 14 ,辅助分类器2对应参数:N x 528 x 14 x 14

在经过平均池化下采样后高度和宽度变成了4 x 4

辅助分类器1对应参数:N x 512 x 4 x 4 ,辅助分类器2对应参数:N x 528 x 4 x 4。

x = F.dropout(x, 0.5, training=self.training)

输入特征矩阵按照50%的比例随机失活。

training=self.training

当实例化一个模型model时,可以通过model.train() 和 model.eval()来控制模型的状态,在model.train()模式下,self.training=True, 在model.eval()模式下,self.training=False。

# 辅助分类器
class InceptionAux(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)  # 平均池化下采样层
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)   # 1 x 1 的卷积层

        # 全连接层
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)


    # 定义正向传播
    def forward(self, x):
        # aux1:N x 512 x 14 x 14 aux2:N x 528 x 14 x 14
        x = self.averagePool(x)
        # aux1:N x 512 x 4 x 4   aux2:N x 528 x 4 x 4。
        x = self.conv(x)
        x = torch.flatten(x, 1)    # "1"表示按channel维度展平
        x = F.dropout(x, 0.5, training=self.training)
        x = self.fc2(x)
        return x

四、定义层结构

 def __init__(self, num_classes=1000, aux_logits=True, init_weights=False):

num_classes:分类类别个数

aux_logits: 是否使用辅助分类器

init_weight: 是否为权重进行初始化

self.aux_logits = aux_logits:将变量传入变为类变量

conv1、conv2、conv3 都是使用之前定义的卷积模板文件来定义的。

inception结构使用定义的 Inception结构。

self.avgpool = nn.AdaptiveAvgPool2d((1, 1)):自适应平均池化下采样

括号里输入所需要的输出特征矩阵的高和宽。

使用自适应平均池化下采样的好处:无论输入特征矩阵的高和宽时一个什么样的大小,都能得到所指定的高和宽,这样就可以不用限定输入图像的尺寸了。

class GoogleNet(nn.Module):
    def __init__(self, num_classes=1000, aux_logits=True, init_weights=False):
        super(GoogleNet, self).__init__()
        self.aux_logits = aux_logits

        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        if self.aux_logits:  # 如果使用辅助分类器
            self.aux1 = InceptionAux(512, num_classes)   # 创建辅助分类器1
            self.aux2 = InceptionAux(528, num_classes)   # 创建辅助分类器2

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)  # 40%比例随机失活
        self.fc = nn.Linear(1024, num_classes)
        if init_weights:
            self._initialize_weights()  # 如果需要初始化,将会进入模型权重初始化函数

定义正向传播 

将整个层结构输出出来。

if self.training and self.aux_logits:
            aux2 = self.aux2(x)

判断模型是处于训练模式还是验证模式?并判断是否要使用辅助分类器。

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.maxpool2(x)
        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.maxpool3(x)
        x = self.inception4a(x)
        if self.training and self.aux_logits:
            aux1 = self.aux1(x)

        x = self.inception4a(x)
        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        if self.training and self.aux_logits:
            aux2 = self.aux2(x)

        x = self.inception4e(x)

        x = self.maxpool4(x)
        x = self.inception5a(x)
        x = self.inception5b(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)

        if self.training and self.aux_logits:  # 如果处于训练模式并使用了辅助分类器
            return x, aux2, aux1  # 将返回 主分类器、辅助分类器2、辅助分类器1 这样3个参数
        return x  # 否则只返回主分类器这一个函数

初始化权重函数

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):  # 卷积层
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):  # 全连接层
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

五、训练模型 

没有报错,训练成功。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1600712.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Qt对象池,单例模式,对象池可以存储其他类的对象指针

代码描述: 写了一个类,命名为对象池(ObjectPool ),里面放个map容器。 3个功能:添加对象,删除对象,查找对象 该类只构建一次,故采用单例模式功能描述:对象池可…

04 MySQL --DQL 专题--Union、exists

1. UNION、UNION ALL UNION 关键字的作用? 合并两个或多个 SELECT 语句的结果。发挥的作用与 or 非常相似 UNION关键字生效的前提? 每个 SELECT 语句必须拥有相同数量的列。每个 SELECT 语句中的列的顺序必须相同。列必须拥有相似的数据类型。 SELEC…

WebRTC直播间搭建记录

考虑到后续增加平台直播的可能性,笔记记录一下WebRTC相关. 让我们分别分析两种情况下的WebRTC连接建立过程: 情况一:AB之间可以直接通信 1.信令交换: 设备A和设备B首先通过信令服务器交换SDP(Session Description Pr…

负载均衡集群——LVS

目录 1.LVS简介 2.LVS体系结构 3.LVS相关术语 4. LVS工作模式 5. LVS调度算法 6.LVS集群介绍 6.1 LVS-DR模式 6.2 LVS – NAT 模式 6.3 LVS – TUN 模式 7.LVS 集群构建 7.1 LVS/NAT 模式配置 实验操作步骤 步骤 1 Nginx1 和 Nginx2 配置 步骤 2 安装和配置 LVS …

R语言使用installr包对R包进行整体迁移

今天分享一个R语言的实用小技巧,如果咱们重新安装了电脑(我重装了电脑)或者因为需要卸载旧版本的R软件,安装新版本的R,那么必然会造成R包的库缺失,需要重新下载,有些还不是官方的R包&#xff0c…

如何从零开始创建React应用:简易指南

🌟 前言 欢迎来到我的技术小宇宙!🌌 这里不仅是我记录技术点滴的后花园,也是我分享学习心得和项目经验的乐园。📚 无论你是技术小白还是资深大牛,这里总有一些内容能触动你的好奇心。🔍 &#x…

认识异常(1)

❤️❤️前言~🥳🎉🎉🎉 hellohello~,大家好💕💕,这里是E绵绵呀✋✋ ,如果觉得这篇文章还不错的话还请点赞❤️❤️收藏💞 💞 关注💥&a…

Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单实战案例 之十三 简单去除图片水印效果

Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单实战案例 之十三 简单去除图片水印效果 目录 Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单实战案例 之十三 简单去除图片水印效果 一、简单介绍 二、简单去除图片水印效果实现原理 三、简单去除图片水印效果案例…

stm32实现hid键盘

前面的cubelmx项目配置参考 stm32实现hid鼠标-CSDN博客https://blog.csdn.net/anlog/article/details/137814494?spm1001.2014.3001.5502两个项目的配置完全相同。 代码 引用 键盘代码: 替换hid设备描述符 先屏蔽鼠标设备描述符 替换为键盘设备描述符 修改宏定…

Springboot框架——3.整合SpringMVC

1.修改端口号: 在application.properties中添加如下配置即可: server.port8088 2.静态资源访问: 首先打开ResourceProperties这个类的源码: 将静态资源放到类中默认位置即可实现访问: http://localhost:8088/erth.jp…

Docker安装xxl-job分布式任务调度平台

文章目录 Docker安装xxl-job分布式任务调度平台1.xxl-job介绍2. 初始化“调度数据库”3、docker挂载运行xxl-job容器3.1、在linux的opt目录下创建xxl_job文件夹,并在里面创建logs文件夹和application.properties文件3.2、配置application.properties文件&#xff0c…

springboot整合dubbo实现RPC服务远程调用

一、dubbo简介 1.什么是dubbo Apache Dubbo是一款微服务开发框架,他提供了RPC通信与微服务治理两大关键能力。有着远程发现与通信的能力,可以实现服务注册、负载均衡、流量调度等服务治理诉求。 2.dubbo基本工作原理 Contaniner:容器Provider&#xf…

安全中级-环境安装(手动nginx以及自动安装php,mysql)

为了方便大家跟bilibili课程,出了第一节环境 bilibili搜凌晨五点的星可以观看相关的教程 一、环境 ubentu 二、nginx手动安装 2.1第一步 wget https://nginx.org/download/nginx-1.24.0.tar.gz 2.2下载好安装包以后解压 tar -zxvf nginx-1.21.6.tar.gz2.3安…

Python零基础从小白打怪升级中~~~~~~~TCP网络编程

TCP网络编程 一、什么是TCP协议 TCP( Transmission control protocol )即传输控制协议,是一种面向连接、可靠的数据传输协议,它是为了在不可靠的互联网上提供可靠的端到端字节流而专门设计的一个传输协议。 面向连接 :数据传输之前客户端和…

文心一言用户数突破2亿 百度官宣三大AI开发神器

在日益激烈的竞争中,百度正在中国AI市场努力保持领导者地位,文心一言用户规模突破2亿,较去年年底翻了一番。 4月16日周二,以“创造未来”为主题的Create 2024百度AI开发者大会在深圳国际会展中心举办。百度CEO李彦宏在会议上指出…

如何用JAVA如何实现Word、Excel、PPT在线前端预览编辑的功能?

背景 随着信息化的发展,在线办公也日益成为了企业办公和个人学习不可或缺的一部分,作为微软Office的三大组成部分:Word、Excel和PPT也广泛应用于各种在线办公场景,但是由于浏览器限制及微软Office的不开源等特性,导致…

PLSQL中文乱码问题 + EZDML导入数据库模型乱码

PLSQL中文乱码问题 EZDML导入数据库模型乱码 查询数据库字符集 select userenv(language) from dual;查询本地字符集编码 select * from V$NLS_PARAMETERS;理论上 数据库字符集 跟 本地字符集编码 是一致的 本地字符集编码需要拼接字段值 NLS_LANGUAGE NLS_TERRITORY NLS…

PySpark预计算ClickHouse Bitmap实践

1. 背景 ClickHouse全称是Click Stream,Data WareHouse,是一款高性能的OLAP数据库,既使用了ROLAP模型,又拥有着比肩MOLAP的性能。我们可以用ClickHouse用来做分析平台快速出数。其中的bitmap结构方便我们对人群进行交并。Bitmap位…

【muzzik 分享】关于 MKFramework 的设计想法

MKFramework是我个人维护持续了几年的项目(虽然公开只有一年左右),最开始由于自己从事QP类游戏开发,我很喜欢MVVM,于是想把他做成 MVVM 框架,在论坛第一个 MVVM 框架出来的时候,我的框架已经快完…

uniapp--登录和注册页面-- login

目录 1.效果展示 2.源代码展示 测试登录 login.js 测试请求 request.js 测试首页index.js 1.效果展示 2.源代码展示 <template><view><f-navbar title"登录" navbarType"4"></f-navbar><view class"tips"><…