GoogLeNet,代码示例,辅助分类器,Inception

news2025/1/11 19:43:56

亮点:

Ø 引入了Inception结构(融合不同尺度的特征信息)
Ø 使用1x1的卷积核进行降维以及映射处理
Ø 添加两个辅助分类器帮助训练
Ø 丢弃全连接层,使用平均池化层(大大减少模型 参数)

GoogLeNet的网络连接图:

卷积层,最大池化层,LocalResponseNorm层
显示一系列inception结构,
辅助分类器1,
inception结构,
辅助分类器2,
inception结构,
平均池化下采样操作
然后与我们的输出节点进行一个全连接再通过softmax函数得到输出。

Inception结构:

每个分支所得的特征矩阵高和宽必须相同;

上上图的参数对应我们Inceptian卷积层中卷积核的个数

#3*3reduce对应Inceptian中分支2上卷积核大小为1*1的卷积核个数

辅助分类器(Auxiliary Classifier):

第一层平均池化下采样操作:池化核大小是5*5,步长为3。得到(4a)4*4*521特征矩阵和(4dd)4*4*528特征矩阵【因为有两个辅助分类器】

矩阵计算公式:out=(in-F+2P)/S +1

使用128个1*1的卷积层进行卷积降低维度,使用RULE激活函数

采用节点个数为1024的全连接层,使用RULE激活函数

然后使用dropout函数以70%比例随机失活神经元。

输出层:节点个数对应类别个数。通过softmax激活函数得到概率分布。

代码示例:

model.py

import torch.nn as nn
import torch
import torch.nn.functional as F


class GoogLeNet(nn.Module):
    # 类别个数,是否使用辅助分类器(布尔变量)
    def __init__(self, num_classes=1000, aux_logits=True, init_weights=False):
        super(GoogLeNet, self).__init__()
        self.aux_logits = aux_logits

        # 输入原因:RGB图像
        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        #  ceil_mode=True向上取整
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        # 特征矩阵
        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        if self.aux_logits:
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)

        #自适应平均池化下采样操作 输出特征矩阵需要的高和宽
        # 无论输入的特征矩阵高和宽是多少,输出都是1*1
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        # N x 3 x 224 x 224
        x = self.conv1(x)
        # N x 64 x 112 x 112
        x = self.maxpool1(x)
        # N x 64 x 56 x 56
        x = self.conv2(x)
        # N x 64 x 56 x 56
        x = self.conv3(x)
        # N x 192 x 56 x 56
        x = self.maxpool2(x)

        # N x 192 x 28 x 28
        x = self.inception3a(x)
        # N x 256 x 28 x 28
        x = self.inception3b(x)
        # N x 480 x 28 x 28
        x = self.maxpool3(x)
        # N x 480 x 14 x 14
        x = self.inception4a(x)
        # N x 512 x 14 x 14
        # 判断是否使用辅助分类器self.training当前模型处于什么模式
        # 当处于训练模式并使用分类器则
        if self.training and self.aux_logits:    # eval model lose this layer
            # 把4a的输出输入到我们的辅助分类器1中
            aux1 = self.aux1(x)

        x = self.inception4b(x)
        # N x 512 x 14 x 14
        x = self.inception4c(x)
        # N x 512 x 14 x 14
        x = self.inception4d(x)
        # N x 528 x 14 x 14
        if self.training and self.aux_logits:    # eval model lose this layer
            aux2 = self.aux2(x)

        x = self.inception4e(x)
        # N x 832 x 14 x 14
        x = self.maxpool4(x)
        # N x 832 x 7 x 7
        x = self.inception5a(x)
        # N x 832 x 7 x 7
        x = self.inception5b(x)
        # N x 1024 x 7 x 7

        x = self.avgpool(x)
        # N x 1024 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 1024
        x = self.dropout(x)
        x = self.fc(x)
        # N x 1000 (num_classes)
        if self.training and self.aux_logits:   # eval model lose this layer
            return x, aux2, aux1
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


class Inception(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(Inception, self).__init__()

        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)   # 保证输出大小等于输入大小
        )

        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            # 在官方的实现中,其实是3x3的kernel并不是5x5,,具体可以参考下面的issue
            # Please see https://github.com/pytorch/vision/issues/906 for details.
            BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)   # 保证输出大小等于输入大小
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1)
        )

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)

        # cat 相当于异构的网络,把outputs串联相连在深度方向拼接
        # troch.cat(inputs, dimension=0,out=None)、dim:沿着此维连接张量序列
        outputs = [branch1, branch2, branch3, branch4]
        return torch.cat(outputs, 1)


# 辅助分类器
class InceptionAux(nn.Module):
    # 深度,类别
    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)  # output[batch, 128, 4, 4]

        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        # 平均池化下采样操作
        x = self.averagePool(x)
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        x = self.conv(x)
        # N x 128 x 4 x 4
        x = torch.flatten(x, 1)
        # x输入的特征矩阵
        # 当我们实例化一个模型model后,可以通过model.train()和model.eval()来控制模型的状态
        # 随着我们训练和测试(不需要辅助分类器的结果)的不同而变化的。
        x = F.dropout(x, 0.5, training=self.training)
        # N x 2048
        x = F.relu(self.fc1(x), inplace=True)
        x = F.dropout(x, 0.5, training=self.training)
        # N x 1024
        x = self.fc2(x)
        # N x num_classes
        return x

# 卷积和激活函数连接
class BasicConv2d(nn.Module):
    # 输入矩阵的深度,输出矩阵的深度
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x

train.py

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm

from model import GoogLeNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()

    net = GoogLeNet(num_classes=5, aux_logits=True, init_weights=True)
    # 如果要使用官方的预训练权重,注意是将权重载入官方的模型,不是我们自己实现的模型
    # 官方的模型中使用了bn层以及改了一些参数,不能混用
    # import torchvision
    # net = torchvision.models.googlenet(num_classes=5)
    # model_dict = net.state_dict()
    # # 预训练权重下载地址: https://download.pytorch.org/models/googlenet-1378be20.pth
    # pretrain_model = torch.load("googlenet.pth")
    # del_list = ["aux1.fc2.weight", "aux1.fc2.bias",
    #             "aux2.fc2.weight", "aux2.fc2.bias",
    #             "fc.weight", "fc.bias"]
    # pretrain_dict = {k: v for k, v in pretrain_model.items() if k not in del_list}
    # model_dict.update(pretrain_dict)
    # net.load_state_dict(model_dict)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0003)

    epochs = 30
    best_acc = 0.0
    save_path = './googleNet.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            # 有三个输出
            logits, aux_logits2, aux_logits1 = net(images.to(device))
            loss0 = loss_function(logits, labels.to(device))
            loss1 = loss_function(aux_logits1, labels.to(device))
            loss2 = loss_function(aux_logits2, labels.to(device))
            loss = loss0 + loss1 * 0.3 + loss2 * 0.3
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))  # eval model only have last output layer
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

辅助分类器是训练过程中优化网络参数的,测试时参数都优化好了,只用考虑最后的输出结果就是了

predict.py

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import GoogLeNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "../tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = GoogLeNet(num_classes=5, aux_logits=False).to(device)

    # load model weights
    # strict=False是因为我们在预测是不使用辅助辅助分类器,模型权重缺少,调试中unexpected_keys辅助分类器一系列层
    weights_path = "./googleNet.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    missing_keys, unexpected_keys = model.load_state_dict(torch.load(weights_path, map_location=device),
                                                          strict=False)

    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2209205.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LInux学习FreeType编程

文章目录 使用freetype 显示一个文字使用 freetype 显示一行文字了解笛卡尔坐标系每个字符的大小可能不同怎么在指定位置显示一行文字freetype 的几个重要数据结构1、**FT_Library**结构体2、FT_Face结构体3、FT_GlyphSlot结构体4、FT_Glyph结构体5、FT_BBox结构体 读懂显示一行…

Linux运维_Apache更改默认网站目录

1.首先创建目录 并且在目录下新建测试文件 index.html mkdir -p /home/test/ap_web 直接wget 百度官网 wget www.baidu.com 2.编辑配置文件 /etc/apache2/sites-available/000-default.conf(找到 DocumentRoot)更改为刚刚创建的目录 接着在添加 最终文件: 3.给文件 添加属…

Nacos配置管理和Nacos集群配置

目录 Nacos作为配置中心实现配置管理 统一配置管理 如何在nocas添加配置文件 在微服务拉取nacos配置中心的配置 1)引入nacos-config依赖 2)添加bootstrap.yaml 3)测试,读取nacos配置中心中配置文件的内容 ​编辑 总结&…

在Spring Boot中具有多个实现的接口正确注入的六种方式

​ 博客主页: 南来_北往 系列专栏:Spring Boot实战 在Spring Boot中,当一个接口具有多个实现时,正确地将这些实现注入到需要使用它们的地方是一个常见的需求。以下是在Spring Boot中实现这一目标的六种方式: 1. 使用Autowir…

maven聚合ssm

如果没有写过ssm项目请移步SSM后端框架搭建(有图有真相)-CSDN博客 数据库准备 create table user (id int (11),uid varchar (60),name varchar (60),age int (11),sex varchar (12) ); insert into user (id, uid, name, age, sex) values(10,202409…

小米电机与STM32——CAN通信

背景介绍:为了利用小米电机,搭建机械臂的关节,需要学习小米电机的使用方法。计划采用STM32驱动小米电机,实现指定运动,为此需要了解他们之间的通信方式,指令写入方法等。花了很多时间学习,但网络…

LINUX网络编程:cookie和session

目录 1.cookie 1.2.cookie原理 1.3.cookie的格式以及字段 字段介绍: 完整的cookie 1.4.cookie的安全问题 2.session 2.2session的原理 1.cookie 在大家在浏览b站的时候,都会发现一个问题,当我们登录过一次之后,下次点开b站…

2024年最新Stable Diffusion模型资源合集!附整合安装包!

(模型资源在ComfyUI、WebUI以及ForgeUI中都通用) 之前的Stable Diffusion笔记受到了不少小伙伴的关注,很感谢大家的建议和支持。有很多小伙伴私信我问我一些AI绘画的模型资源在哪来下载,一般来说有两个网站比较常用,分…

位操作解决数组的花样遍历

文章目录 题目 一、思路: 二、代码 总结 题目 leetcodeT289 https://leetcode.cn/problems/game-of-life/description/ 一、思路: 这题思路很简单,对每个位置按照题目所给规则进行遍历,判断周围网格的活细胞数即可。但是题目要求…

【LVGL快速入门】SquareLine Studio安装教程(LVGL官方工具)

一.简介与导航: SquareLine Studio是由LVGL官方开发的一款UI设计工具,采用图形化进行界面UI设计,轻易上手。 SquareLine Studio官方网址:https://squareline.io/SquareLine Studio官方文档:https://docs.squareline.io…

太阳能电池特性及其应用

中南民族大学-通信工程2024-大学物理下实验 目录 代码实现结果显示 🛠工具使用 MarsCode(插件,集成在PyCharm); python编程(豆包AI智能体) 💻编程改进 此处是用「Matplotlib」来作图…

Monkey测试工具大盘点!如何选怎么用全整明白了!

什么是Monkey测试? 以下是官方说法: Monkey 测试是通过向系统发送伪随机的用户事件流(如按键输入、触摸屏输入、手势输入等),实现对应用程序客户端的稳定性测试;这种随机性可以模拟真实用户的行为&#x…

理解Web3的互操作性:不同区块链的连接

随着Web3的迅速发展,互操作性成为区块链技术中的一个核心概念。互操作性指的是不同区块链之间能够无缝地交流和共享数据,从而实现更加高效和灵活的生态系统。本文将探讨Web3中互操作性的意义、面临的挑战以及未来的发展趋势。 1. 互操作性的意义 在Web…

优达学城 Generative AI 课程3:Computer Vision and Generative AI

文章目录 1 官方课程内容自述第 1 课:图像生成简介第 2 课:计算机视觉基础第 3 课:图像生成与生成对抗网络(GANs)第 4 课:基于 Transformer 的计算机视觉模型第 5 课:扩散模型第 6 课&#xff0…

利用AI大模型,增强你的DevOps!

前言 自从去年春天ChatGPT问世之后,互联网也掀起了拥抱AI的浪潮,不仅是各大头部大厂相继发布大模型产品,在开发者的Coding过程中也紧跟时代,一些热门插件也纷纷受到了开发者的青睐,比如GitHub Copilot的智能代码生成。…

数据结构编程实践20讲(Python版)—05二叉树

本文目录 写在前面:大“树”下好乘凉定义主要术语基本特征主要应用领域:05 二叉树Binary treeS1 说明S2 示例S3 二叉树类型(1)满二叉树(Perfect Binary Tree)(2)完全二叉树(Complete Binary Tree)(3)二叉搜索树(Binary Search Tree)(4)平衡二叉树(Balanced Bin…

Windows环境NodeJS下载配置安装运行

Windows环境NodeJS下载配置安装运行 (1)下载 Node.js — Run JavaScript Everywhere 安装文件。 一路傻瓜式安装。 如果安装正常,输入命令可显示版本号: (2)可以查询nodejs默认的后续依赖安装包位置及缓存…

稻盛和夫认为,一个领导是否值得追随,看这几点就够了

一、真正的领导者有自信,但不自大,有坚定的信念和价值观。 稻盛和夫认为,领导者的真正强大之处在于他们能够坚定地做正确的事情。领导者必须具备勇气和决心,以及坚定的信念,以便在面对挑战和困难时能够坚持自己的信念…

Vert.x,Web - 静态资源/模板

静态资源 Vert.x-Web带有开箱即用的处理器(StaticHandler),用于处理静态Web资源(.html, .css, .js, …), 因此可以非常轻松地编写静态Web服务器。 默认静态文件目录为类路径下的webroot目录,对于maven的项目,按规范放在src/main/…

BIO与NIO学习

BIO:同步阻塞IO,客户端一个连接请求(socket)对应一个线程。阻塞体现在: 程序在执行I/O操作时会阻塞当前线程,直到I/O操作完成。在线程空闲的时候也无法释放用于别的服务只能等当前绑定的客户端的消息。 BIO的代码实现 …