pytorch复现3_GoogLenet

news2024/9/22 5:37:44

背景:
GoogLeNeta是2014年提出的一种全新的深度学习结构,在这之前的AlexNet、VGG等结构都是通过增大网络的深度(层数)来获得更好的训练效果,但层数的增加会带来很多负作用,比如overfit、梯度消失、梯度爆炸等。GoogLeNet通过引入inception从另一种角度来提升训练结果:能更高效的利用计算资源,在相同的计算量下能提取到更多的特征,从而提升训练结果。
因此GoogLeNet在专注于加深网络结构的同时,引入了新的基本结构——Inception模块,以增加网络的宽度。

网络结构图:
在这里插入图片描述
1、Inception模块
Inception就是把多个卷积或池化操作,放在一起组装成一个网络模块
在这里插入图片描述
实际中需要什么样的Inception
  我们在上面提供了一种Inception的结构,但是这个结构存在很多问题,是不能够直接使用的。首要问题就是参数太多,导致特征图厚度太大。为了解决这个问题,作者在其中加入了1X1的卷积核,改进后的Inception结构如下图:

  在这里插入图片描述
代码:
定义BasicConv2d

class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x

定义Inception

class Inception(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(Inception, self).__init__()

        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)   # 保证输出大小等于输入大小
        )

        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)   # 保证输出大小等于输入大小
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1)
        )

定义分类器

class InceptionAux(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)  # output[batch, 128, 4, 4]

        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        x = self.averagePool(x)
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        x = self.conv(x)
        # N x 128 x 4 x 4
        x = torch.flatten(x, 1)
        x = F.dropout(x, 0.5, training=self.training)
        # N x 2048
        x = F.relu(self.fc1(x), inplace=True)
        x = F.dropout(x, 0.5, training=self.training)
        # N x 1024
        x = self.fc2(x)
        # N x num_classes
        return x

model完整代码:

import torch.nn as nn
import torch
import torch.nn.functional as F


class GoogLeNet(nn.Module):
    def __init__(self, num_classes=1000, aux_logits=True, init_weights=False):
        super(GoogLeNet, self).__init__()
        self.aux_logits = aux_logits

        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        if self.aux_logits:
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        # N x 3 x 224 x 224
        x = self.conv1(x)
        # N x 64 x 112 x 112
        x = self.maxpool1(x)
        # N x 64 x 56 x 56
        x = self.conv2(x)
        # N x 64 x 56 x 56
        x = self.conv3(x)
        # N x 192 x 56 x 56
        x = self.maxpool2(x)

        # N x 192 x 28 x 28
        x = self.inception3a(x)
        # N x 256 x 28 x 28
        x = self.inception3b(x)
        # N x 480 x 28 x 28
        x = self.maxpool3(x)
        # N x 480 x 14 x 14
        x = self.inception4a(x)
        # N x 512 x 14 x 14
        if self.training and self.aux_logits:    # eval model lose this layer
            aux1 = self.aux1(x)

        x = self.inception4b(x)
        # N x 512 x 14 x 14
        x = self.inception4c(x)
        # N x 512 x 14 x 14
        x = self.inception4d(x)
        # N x 528 x 14 x 14
        if self.training and self.aux_logits:    # eval model lose this layer
            aux2 = self.aux2(x)

        x = self.inception4e(x)
        # N x 832 x 14 x 14
        x = self.maxpool4(x)
        # N x 832 x 7 x 7
        x = self.inception5a(x)
        # N x 832 x 7 x 7
        x = self.inception5b(x)
        # N x 1024 x 7 x 7

        x = self.avgpool(x)
        # N x 1024 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 1024
        x = self.dropout(x)
        x = self.fc(x)
        # N x 1000 (num_classes)
        if self.training and self.aux_logits:   # eval model lose this layer
            return x, aux2, aux1
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


class Inception(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(Inception, self).__init__()

        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)   # 保证输出大小等于输入大小
        )

        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)   # 保证输出大小等于输入大小
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1)
        )

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)

        outputs = [branch1, branch2, branch3, branch4]
        return torch.cat(outputs, 1)


class InceptionAux(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)  # output[batch, 128, 4, 4]

        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        x = self.averagePool(x)
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        x = self.conv(x)
        # N x 128 x 4 x 4
        x = torch.flatten(x, 1)
        x = F.dropout(x, 0.5, training=self.training)
        # N x 2048
        x = F.relu(self.fc1(x), inplace=True)
        x = F.dropout(x, 0.5, training=self.training)
        # N x 1024
        x = self.fc2(x)
        # N x num_classes
        return x


class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1153120.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

代码精简10倍,责任链模式yyds

1 推荐看的文章1 责任链设计——责任链验证推翻 if-else 炼狱 2 推荐看的文章2 代码精简10倍,责任链模式yyds

引入了mybatis-spring-boot-starter,还需要引入mysql-connector-java吗?

spring boot集成mybatis&#xff0c;是需要引入mybatis-spring-boot-starter&#xff0c;有文章说不需要引入mysql-connector-java&#xff0c;但实际用下来并不行&#xff0c;我看了里面的pom文件&#xff0c;终于知道怎么一回事。 <!--引入mybatis的依赖--><depende…

HTTP协议说明

1.用于HTTP协议交互的信息被称为HTTP报文。请求端&#xff08;客户端&#xff09;的HTTP报文叫做请求报文&#xff0c;响应端&#xff08;服务器端&#xff09;的叫做响应报文。HTTP 报文本身是由多行&#xff08;用 CRLF 作换行符&#xff09;数据构成的字符串文本。 HTTP报文…

[Linux C] signal 的使用

前言&#xff1a; signal 是一种通信机制&#xff0c;可以跨进程发送&#xff0c;可以同进程跨线程发送&#xff0c;可以不同进程向指定线程发送。 信号的创建有两套api&#xff0c;一个是signal&#xff0c;一个是sigaction&#xff0c;signal缺陷很多&#xff0c;比如没有提…

亚马逊美国站衣物收纳商品合规标准是什么?如何办理?

随着秋季的来临&#xff0c;不少人翻箱倒柜地寻找换季用品。相信现在很多人都和小编一样&#xff0c;出门时打算找个外套穿上&#xff0c;但想到要去柜子里翻半天&#xff0c;就立刻打消了想要出门的念头。 但当翻箱倒柜地找到了换季用品&#xff0c;却又要一件一件地把翻出来…

Variations-of-SFANet-for-Crowd-Counting可视化代码

前文对Variations-of-SFANet-for-Crowd-Counting做了一点基础梳理&#xff0c;链接如下&#xff1a;Variations-of-SFANet-for-Crowd-Counting记录-CSDN博客 本次对其中两个可视化代码进行梳理 1.Visualization_ShanghaiTech.ipynb 不太习惯用jupyter notebook, 这里改成了p…

spring解决后端显示时区的问题

spring解决后端显示时区的问题 出现的问题&#xff1a; 数据库中的数据&#xff1a; 解决方法 spring:jackson:date-format: yyyy-MM-dd HH:mm:sstime-zone: Asia/Shanghai

vscode前端必备插件

安装插件的位置如下&#xff1a; 1、Chinese (Simplified) Language Pack 中文简体插件 2、Vetur Vue官方钦定插件&#xff0c;包括&#xff1a;语法高亮&#xff0c;智能提示&#xff0c;错误提示&#xff0c;格式化&#xff0c;自动补全等等 3、ESLint 语法检查工具&#…

客户端性能测试基础知识

目录 1、客户端性能 1.1、客户端性能基础知识 2、客户端性能工具介绍与环境搭建 2.1.1、perfdog的使用 2.1.2、renderdoc的使用 1、客户端性能 1.1、客户端性能基础知识 客户端性能知识这里对2D和3D类游戏进行展开进行&#xff0c;讲述的有内存、CPU、GPU、帧率这几个模块…

云栖大会十五年:开放创新,未来愿景

时光荏苒&#xff0c;转眼间云栖大会已经走过了十五个年头&#xff0c;这一场中国云计算行业的盛会已经成为业内不可或缺的一部分。在这个特殊的时刻&#xff0c;我想分享一些对未来云栖大会的期待与建议&#xff0c;希望这个盛会能够继续推动云计算领域的创新和发展。 云栖大会…

数据库深入浅出,数据库介绍,SQL介绍,DDL、DML、DQL、TCL介绍

一、基础知识&#xff1a; 1.数据库基础知识 数据(Data)&#xff1a;文本信息(字母、数字、符号等)、音频、视频、图片等&#xff1b; 数据库(DataBase)&#xff1a;存储数据的仓库&#xff0c;本质文件&#xff0c;以文件的形式将数据保存到电脑磁盘中 数据库管理系统(DBMS)&…

LSF 概览——了解 LSF 是如何满足您的作业要求,并找到最佳资源来运行该作业的

LSF 概览 了解 LSF 是如何满足您的作业要求&#xff0c;并找到最佳资源来运行该作业的。 IBM Spectrum LSF ("LSF", load sharing facility 的简称) 软件是行业领先的企业级软件。LSF 将工作分散在现有的各种 IT 资源中&#xff0c;以创建共享的&#xff0c;可扩展…

国内内卷太严重,还不考虑一下在海外接单?那这几个平台你知道吗?

作为一个程序员&#xff0c;在平台上接单赚点外快是再正常不过的事情了&#xff0c;但是现今国内各个平台都内卷比较严重&#xff0c;你是否考虑过去“外面的世界”看看&#xff1f; 如果想过&#xff0c;那么这几个外国的接单平台你都知道吗&#xff1f; 接下来就和我一起来看…

vmWare虚拟机扩容及pip国内镜像源

扩展虚拟机容量 打开虚拟机.sudo apt-get install gparted pip镜像源 pip install numpy -i https://pypi.tuna.tsinghua.edu.cn/simple-i https://pypi.douban.com/simple-i https:// mirrors.aliyun.com/pypi/simple

Linux CentOS7 shell

学好linux&#xff0c;首先要深入理解shell。 shell俗称壳&#xff0c;它包裹在内核的外面&#xff0c;是用户命令的翻译官。 作用&#xff1a;接收用户的命令&#xff0c;翻译后(处理一下)交给Linux内核处理。 用户执行命令 -> shell -> 内核 -> CPU -> 内核 -…

C/C++笔试易错与高频题型图解知识点(三)——数据结构部分(持续更新中)

目录 1. 排序 1.1 冒泡排序的改进 2. 二叉树 2.1 二叉树的性质 3. 栈 & 队列 3.1 循环队列 3.2 链式队列 4. 平衡二叉搜索树——AVL树、红黑树 5 优先级队列&#xff08;堆&#xff09; 1. 排序 1.1 冒泡排序的改进 下面的排序方法中&#xff0c;关键字比较次数与记录的初…

LeetCode 996.正方形数组的数目

和上一道状压的区别在于我们要去重一下~ 思路都是和上一篇博客是一样的&#xff0c;感兴趣的同学可以看一下 const int N 15; int dp[1<<N][N]; int n; vector<int>nums1;bool check(int x){int tem sqrt(x);if(tem*temx)return 1;return 0; }int dfs(int u,in…

比较Excel中的两列目录编号是否一致

使用java代码比较excel中两列是否有包含关系&#xff0c;若有包含关系&#xff0c;核对编号是否一致。 excel数据样例如下&#xff1a; package com.itownet.hg;import org.apache.poi.xssf.usermodel.XSSFSheet; import org.apache.poi.xssf.usermodel.XSSFWorkbook;import j…

网站如何改成HTTPS访问

在今天的互联网环境中&#xff0c;将网站更改成HTTPS访问已经成为了一种标准做法。HTTPS不仅有助于提高网站的安全性&#xff0c;还可以提高搜索引擎排名&#xff0c;并增强用户信任。因此&#xff0c;转换为HTTPS是一个重要的举措&#xff0c;无论您拥有个人博客、电子商务网站…

如何将你的PC电脑数据迁移到Mac电脑?使用“迁移助理”从 PC 传输到 Mac的具体操作教程

有的小伙伴因为某一项工作或者其它原因由Windows电脑换成了Mac电脑&#xff0c;但是数据和文件都在原先的Windows电脑上&#xff0c;不知道怎么传输。接下来小编就为大家介绍使用“迁移助理”将你的通讯录、日历、电子邮件帐户等内容从 Windows PC 传输到 Mac 上的相应位置。 在…