霹雳吧啦Wz《pytorch图像分类》-p4GoogLeNet网络

news2024/12/31 0:55:55

《pytorch图像分类》p4GoogLeNet网络详解

  • 一、GoogLeNet网络中的亮点
    • 1.inception结构
    • 2.使用1×1的卷积核进行降维及映射处理
    • 3.GoogLeNet辅助分类器
    • 4.模型参数
  • 二、模块代码
    • 1.BasicConv2d
    • 2.Inception
  • 三、课程代码
    • 1.module.py
    • 2.train.py
    • 3.predict.py

一、GoogLeNet网络中的亮点

论文链接:Going Deeper with Convolutions
代码链接:霹雳吧啦Wzdeep-learning-for-image-processing

1.inception结构

Inception结构是由Google团队提出的一种卷积神经网络架构。
Inception结构的主要目标是解决卷积神经网络中平衡计算量和表示能力之间的权衡问题。它通过并行执行多个不同尺寸的卷积操作和池化操作,从而充分利用不同大小的感受野(receptive field)来捕捉图像的特征。

ps.英语单词inception n.开端,创始
在这里插入图片描述)

2.使用1×1的卷积核进行降维及映射处理

使用1×1的卷积网络进行降维,可以在神经网络中减少特征图的维度,提高神经网络的效率和性能。
在这里插入图片描述

3.GoogLeNet辅助分类器

主要目的是通过这些辅助分类器来引入额外的目标函数来帮助训练深层神经网络。(太抽象了)
简单来说就是,当训练非常深的神经网络时,梯度消失是一个普遍的问题,导致较低层的权重很难更新。为了解决这个问题,GoogLeNet引入了辅助分类器。
我记得Sigmoid函数在饱和时,容易出现梯度消失现象。
可见我的博客:pytorch机器学习各种激活函数总结
在这里插入图片描述

4.模型参数

GoogLeNet使用了Inception结构,在减少参数量的同时保持性能。
VGG网络是一个非常深的卷积神经网络,以VGG16为例,由13个卷积层和3个全连接层组成。由于网络深度的增加,它具有非常多的模型参数。
但是VGG网络搭建方便,GoogLeNet辅助分类器修改麻烦,所以现实中使用vgg网络的更多。
在这里插入图片描述

二、模块代码

1.BasicConv2d

BasicConv2d类里面自定义包含了一个卷积层+Relu激活函数

class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_chennels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_chennels, **kwargs)
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self,x):
        x = self.conv(x)
        x = self.relu(x)
        return x

2.Inception

在这里插入图片描述

class Inception(nn.Module):
    def __init__(self,in_channels,ch1x1,ch3x3red,ch3x3,ch5x5red,ch5x5,pool_proj):

在这里插入图片描述
添加padding参数时,要保证输出大小等于输入大小

三、课程代码

1.module.py

import torch.nn as nn
import torch
import torch.nn.functional as F


class GoogLeNet(nn.Module):
    def __init__(self, num_classes=1000, aux_logits=True, init_weights=False):
        super(GoogLeNet, self).__init__()
        self.aux_logits = aux_logits

        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        if self.aux_logits:
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        # N x 3 x 224 x 224
        x = self.conv1(x)
        # N x 64 x 112 x 112
        x = self.maxpool1(x)
        # N x 64 x 56 x 56
        x = self.conv2(x)
        # N x 64 x 56 x 56
        x = self.conv3(x)
        # N x 192 x 56 x 56
        x = self.maxpool2(x)

        # N x 192 x 28 x 28
        x = self.inception3a(x)
        # N x 256 x 28 x 28
        x = self.inception3b(x)
        # N x 480 x 28 x 28
        x = self.maxpool3(x)
        # N x 480 x 14 x 14
        x = self.inception4a(x)
        # N x 512 x 14 x 14
        if self.training and self.aux_logits:  # eval model lose this layer
            aux1 = self.aux1(x)

        x = self.inception4b(x)
        # N x 512 x 14 x 14
        x = self.inception4c(x)
        # N x 512 x 14 x 14
        x = self.inception4d(x)
        # N x 528 x 14 x 14
        if self.training and self.aux_logits:  # eval model lose this layer
            aux2 = self.aux2(x)

        x = self.inception4e(x)
        # N x 832 x 14 x 14
        x = self.maxpool4(x)
        # N x 832 x 7 x 7
        x = self.inception5a(x)
        # N x 832 x 7 x 7
        x = self.inception5b(x)
        # N x 1024 x 7 x 7

        x = self.avgpool(x)
        # N x 1024 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 1024
        x = self.dropout(x)
        x = self.fc(x)
        # N x 1000 (num_classes)
        if self.training and self.aux_logits:  # eval model lose this layer
            return x, aux2, aux1
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


class Inception(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(Inception, self).__init__()

        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)  # 保证输出大小等于输入大小
        )

        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            # 在官方的实现中,其实是3x3的kernel并不是5x5,这里我也懒得改了,具体可以参考下面的issue
            # Please see https://github.com/pytorch/vision/issues/906 for details.
            BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)  # 保证输出大小等于输入大小
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1)
        )

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)

        outputs = [branch1, branch2, branch3, branch4]
        return torch.cat(outputs, 1)


class InceptionAux(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)  # output[batch, 128, 4, 4]

        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        x = self.averagePool(x)
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        x = self.conv(x)
        # N x 128 x 4 x 4
        x = torch.flatten(x, 1)
        x = F.dropout(x, 0.5, training=self.training)
        # N x 2048
        x = F.relu(self.fc1(x), inplace=True)
        x = F.dropout(x, 0.5, training=self.training)
        # N x 1024
        x = self.fc2(x)
        # N x num_classes
        return x


class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x

2.train.py

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm

from model import GoogLeNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 2
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()

    net = GoogLeNet(num_classes=5, aux_logits=True, init_weights=True)
    # 如果要使用官方的预训练权重,注意是将权重载入官方的模型,不是我们自己实现的模型
    # 官方的模型中使用了bn层以及改了一些参数,不能混用
    # import torchvision
    # net = torchvision.models.googlenet(num_classes=5)
    # model_dict = net.state_dict()
    # # 预训练权重下载地址: https://download.pytorch.org/models/googlenet-1378be20.pth
    # pretrain_model = torch.load("googlenet.pth")
    # del_list = ["aux1.fc2.weight", "aux1.fc2.bias",
    #             "aux2.fc2.weight", "aux2.fc2.bias",
    #             "fc.weight", "fc.bias"]
    # pretrain_dict = {k: v for k, v in pretrain_model.items() if k not in del_list}
    # model_dict.update(pretrain_dict)
    # net.load_state_dict(model_dict)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0003)

    epochs = 30
    best_acc = 0.0
    save_path = './googleNet.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            logits, aux_logits2, aux_logits1 = net(images.to(device))
            loss0 = loss_function(logits, labels.to(device))
            loss1 = loss_function(aux_logits1, labels.to(device))
            loss2 = loss_function(aux_logits2, labels.to(device))
            loss = loss0 + loss1 * 0.3 + loss2 * 0.3
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))  # eval model only have last output layer
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

我的电脑带不动,batch_size=32我改成了batch_size=2,也就运行了一个多小时吧(麻了)
在这里插入图片描述

3.predict.py

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import GoogLeNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "2.jpg"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = GoogLeNet(num_classes=5, aux_logits=False).to(device)

    # load model weights
    weights_path = "./googleNet.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    missing_keys, unexpected_keys = model.load_state_dict(torch.load(weights_path, map_location=device),
                                                          strict=False)

    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()

我去,预测这么准的吗?
1.jpg 蒲公英
在这里插入图片描述
2.jpg 郁金香
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1351178.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Springcloud 微服务实战笔记 Ribbon

使用 Configurationpublic class CustomConfiguration {BeanLoadBalanced // 开启负载均衡能力public RestTemplate restTemplate() {return new RestTemplate();}}可看到使用Ribbon,非常简单,只需将LoadBalanced注解加在RestTemplate的Bean上&#xff0…

Django 分页(表单)

目录 一、手动分页二、分页器分页 一、手动分页 1、概念 页码:很容易理解,就是一本书的页码每页数量:就是一本书中某一页中的内容(数据量,比如第二页有15行内容),这 15 就是该页的数据量 每一…

将 Python 和 Rust 融合在一起,为 pyQuil® 4.0 带来和谐

文章目录 前言设定方向从 Rust 库构建 Python 软件包改装 pyQuil异步困境回报:功能和性能结论 前言 pyQuil 一直是在 Rigetti 量子处理单元(QPUs)上构建和运行量子程序的基石,通过我们的 Quantum Cloud Services(QCS™…

EBU7140 Security and Authentication(三)密钥管理;IP 层安全

B3 密钥管理 密钥分类: 按时长: short term:短期密钥,用于一次加密。long term:长期密钥,用于加密或者授权。 按服务类型: Authentication keys:公钥长期,私钥短期…

【HarmonyOS开发】共享包HAR和HSP的创建和使用以及三方库的发布

OpenHarmony提供了两种共享包,HAR(Harmony Archive)静态共享包,和HSP(Harmony Shared Package)动态共享包。 HAR与HSP都是为了实现代码和资源的共享,都可以包含代码、C库、资源和配置文件&…

C语言——操作符

一、算数操作符 1、(加操作符) 用于将两个数相加,例:3 3结果为6 2、-(减操作符) 用于将两个数相减,例:3 - 3结果为0 3、*(乘操作符) 用于将两个数相乘,例:3 * 3结果为9 4、/(除操作符) 用于将两个…

fastadmin学习02-修改后台管理员账号密码

问题 如果是别人部署好的fastadmin网站不知道后台登录地址和账号密码怎么办 后台登录地址 public目录下有一个很奇怪的php就是后台登录地址啦 忘记账号密码 找到fa_admin,fa_是前缀,肯能每个项目不太一样 UPDATE fa_admin set password1d020dee8ec…

计算机毕业论文内容参考|基于智能搜索引擎的图书管理系统的设计与实现

文章目录 摘要前言绪论课题背景国内外现状与趋势课题内容相关技术与方法介绍系统分析系统设计系统实现系统测试总结与展望摘要 本文介绍了基于智能搜索引擎的图书管理系统的设计与实现。该系统旨在提供一个高效、智能化的图书管理平台,帮助用户更快、更准确地找到所需的图书资…

使用Apache Commons SCXML实现状态机管理

第1章:引言 大家好,我是小黑,咱们程序员在开发过程中,经常会遇到需要管理不同状态和状态之间转换的场景。比如,一个在线购物的订单,它可能有“新建订单”、“已支付”、“配送中”、“已完成”等状态。在这…

Linux/Unix/国产化操作系统常用命令(二)

目录 后CentOS时代国产化操作系统国产化操作系统有哪些常用Linux命令关于Linux的LOGO 后CentOS时代 在CentOS 8发布后,就有了一些变化和趋势,可以说是进入了"后CentOS时代"。这个时代主要表现在以下几个方面: CentOS Stream的引入…

Spring技术内幕笔记之SpringMvc

WebApplicationContext接口的类继承关系 org.springframework.web.context.ContextLoader#initWebApplicationContext 对IOC容器的初始化 SpringMvc如何设计 DispatcherServlet类继承关系 MVC处理流程图如下: DispatcherServlet的工作大致可以分为两个部分&#xf…

Spring Boot 如何使用 Maven 实现多环境配置管理

Spring Boot 如何使用 Maven 实现多环境配置管理 实现多环境配置有以下几个重要原因: 适应不同的部署环境:在实际部署应用程序时,通常会有多个不同的部署环境,如开发环境、测试环境、生产环境等。每个环境可能需要不同的配置&…

Linux 命令tail

命令作用 tail 命令用于显示文件的末尾内容,默认显示文件的最后 10 行。通常情况下,tail 命令用于实时查看动态日志文件,可以使用 -f 参数跟踪文件内容的变化。 语法 tail [选项] [文件名] 参数 以 log.txt 为例演示参数效果 -n -linesK…

解决报错:找不到显卡

今天做实验碰到一个问题:torch找不到显卡: 打开任务管理器,独显直接没了,一度以为是要去修电脑了,突然想到上次做实验爆显存,屏蔽了gpu用cpu训练: import os os.environ["CUDA_DEVICE_OR…

【华为数据之道学习笔记】9-3构建以元数据为基础的安全隐私保护框架

以元数据为基础的安全隐私治理 有决策权的公司高层已经意识到安全隐私的重要性,在变革指导委员会以及各个高层会议纪要中都明确指明安全隐私是变革优先级非常高的主题,安全是一切业务的保障。 基于这个大前提,我们构建了以元数据为基础的安全…

【Spark精讲】记一个SparkSQL引擎层面的优化:SortMergeJoinExec

SparkSQL的Join执行流程 如下图所示,在分析不同类型的Join具体执行之前,先介绍Join执行的基本框架,框架中的一些概念和定义是在不同的SQL场景中使用的。 在Spark SQL中Join的实现都基于一个基本的流程,根据角色的不同&#xff0…

zookeeper未授权访问漏洞增加用户认证修复

linux机器中使用root命令行cd到zookeeper的bin文件夹下 启动zookeeper ./zkCli.sh # 启动zookeeper如果此时有未授权漏洞,可通过以下命令验证。 getAcl / #可以看到默认是world:anyone 就相当于无权限访问验证结果显示没有用户认证也可执行一些命令 增添用户认…

[GKCTF 2020]ez三剑客-eztypecho

[GKCTF 2020]ez三剑客-eztypecho 考点:Typecho反序列化漏洞 打开题目,发现是typecho的CMS 尝试跟着创建数据库发现不行,那么就搜搜此版本的相关信息发现存在反序列化漏洞 参考文章 跟着该文章分析来,首先找到install.php&#xf…

LeetCode第32题 : 最长有效括号

题目介绍 给你一个只包含 ( 和 ) 的字符串,找出最长有效(格式正确且连续)括号子串的长度。 示例 1: 输入:s "(()" 输出:2 解释:最长有效括号子串是 "()" 示例 2&#xf…

DALLE·3的图片一致性火了!

DALLE3的图片一致性火了! DALLE3 利用种子值&GenID的一致性作图 防止大家不知道,这里拓展一下种子值,DALL每次生成的图像都会有一个新的、随机的种子值。我们可以利用种子值进行角色的变换 每次生成的图像都会有一个新的、随机的种子值…