GoogLeNet(pytorch)

news2024/12/23 22:14:44

亮点与创新:

1. 引入Inception基础结构

2. 引入PW维度变换卷积,启迪后续参数量的优化

3. 丢弃全连接层,使用平均池化层(大大减少模型参数)

4. 添加两个辅助分类器帮助训练(避免梯度消失,用于向前传导梯度,也有一定的正则化效果,防止过拟合)

 model.py

import torch.nn as nn
import torch
import torch.nn.functional as F


class GoogLeNet(nn.Module):
    def __init__(self, num_classes=1000, aux_logits=True, init_weights=False):
        super(GoogLeNet, self).__init__()
        self.aux_logits = aux_logits

        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        if self.aux_logits:
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        # N x 3 x 224 x 224
        x = self.conv1(x)
        # N x 64 x 112 x 112
        x = self.maxpool1(x)
        # N x 64 x 56 x 56
        x = self.conv2(x)
        # N x 64 x 56 x 56
        x = self.conv3(x)
        # N x 192 x 56 x 56
        x = self.maxpool2(x)

        # N x 192 x 28 x 28
        x = self.inception3a(x)
        # N x 256 x 28 x 28
        x = self.inception3b(x)
        # N x 480 x 28 x 28
        x = self.maxpool3(x)
        # N x 480 x 14 x 14
        x = self.inception4a(x)
        # N x 512 x 14 x 14
        if self.training and self.aux_logits:    # eval model lose this layer
            aux1 = self.aux1(x)

        x = self.inception4b(x)
        # N x 512 x 14 x 14
        x = self.inception4c(x)
        # N x 512 x 14 x 14
        x = self.inception4d(x)
        # N x 528 x 14 x 14
        if self.training and self.aux_logits:    # eval model lose this layer
            aux2 = self.aux2(x)

        x = self.inception4e(x)
        # N x 832 x 14 x 14
        x = self.maxpool4(x)
        # N x 832 x 7 x 7
        x = self.inception5a(x)
        # N x 832 x 7 x 7
        x = self.inception5b(x)
        # N x 1024 x 7 x 7

        x = self.avgpool(x)
        # N x 1024 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 1024
        x = self.dropout(x)
        x = self.fc(x)
        # N x 1000 (num_classes)
        if self.training and self.aux_logits:   # eval model lose this layer
            return x, aux2, aux1
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


class Inception(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(Inception, self).__init__()

        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)   # 保证输出大小等于输入大小
        )

        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            # 在官方的实现中,其实是3x3的kernel并不是5x5,这里我也懒得改了,具体可以参考下面的issue
            # Please see https://github.com/pytorch/vision/issues/906 for details.
            BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)   # 保证输出大小等于输入大小
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1)
        )

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)

        outputs = [branch1, branch2, branch3, branch4]
        return torch.cat(outputs, 1)


class InceptionAux(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)  # output[batch, 128, 4, 4]

        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        x = self.averagePool(x)
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        x = self.conv(x)
        # N x 128 x 4 x 4
        x = torch.flatten(x, 1)
        x = F.dropout(x, 0.5, training=self.training)
        # N x 2048
        x = F.relu(self.fc1(x), inplace=True)
        x = F.dropout(x, 0.5, training=self.training)
        # N x 1024
        x = self.fc2(x)
        # N x num_classes
        return x


class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x

train.py

模型有三个返回结果,一个预测,两个辅助分类器

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm

from model import GoogLeNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()

    net = GoogLeNet(num_classes=5, aux_logits=True, init_weights=True)
    # 如果要使用官方的预训练权重,注意是将权重载入官方的模型,不是我们自己实现的模型
    # 官方的模型中使用了bn层以及改了一些参数,不能混用
    # import torchvision
    # net = torchvision.models.googlenet(num_classes=5)
    # model_dict = net.state_dict()
    # # 预训练权重下载地址: https://download.pytorch.org/models/googlenet-1378be20.pth
    # pretrain_model = torch.load("googlenet.pth")
    # del_list = ["aux1.fc2.weight", "aux1.fc2.bias",
    #             "aux2.fc2.weight", "aux2.fc2.bias",
    #             "fc.weight", "fc.bias"]
    # pretrain_dict = {k: v for k, v in pretrain_model.items() if k not in del_list}
    # model_dict.update(pretrain_dict)
    # net.load_state_dict(model_dict)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0003)

    epochs = 30
    best_acc = 0.0
    save_path = './googleNet.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            logits, aux_logits2, aux_logits1 = net(images.to(device))
            loss0 = loss_function(logits, labels.to(device))
            loss1 = loss_function(aux_logits1, labels.to(device))
            loss2 = loss_function(aux_logits2, labels.to(device))
            loss = loss0 + loss1 * 0.3 + loss2 * 0.3
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))  # eval model only have last output layer
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

注意训练代码里的这段注释代码:

    import torchvision
    net = torchvision.models.googlenet(num_classes=5)

#如果用的是官方提供的torchvision.models.googlenet模型,下载下来官方训练的权重是在imagenet1000分类上训练到的权重,需要删除一些层的权重来训练,而且训练的模型直接就是官方的模型,得到自己适应训练集的权重,简单总结,官方模型,官方权重,自己数据集,稍微修改官方权重,使用稍微修改的官方权重训练官方模型下自己数据集的权重
#对于怎么冻结权重,解冻权重,由于这个模型学习时还在训练前面的VGG网络,到resnet进行学习

#若想用官方权重训练自己的模型得到适应数据集的权重,也应该是修改一下官方权重
#之后resnet细说


    model_dict = net.state_dict()
    # 预训练权重下载地址: https://download.pytorch.org/models/googlenet-1378be20.pth
    pretrain_model = torch.load("googlenet.pth")
    del_list = ["aux1.fc2.weight", "aux1.fc2.bias",
                "aux2.fc2.weight", "aux2.fc2.bias",
                "fc.weight", "fc.bias"]
    pretrain_dict = {k: v for k, v in pretrain_model.items() if k not in del_list}
    model_dict.update(pretrain_dict)
    net.load_state_dict(model_dict)

predict.py

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import GoogLeNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "../tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    #对单张图片添加维度,以适应模型
    #如果是批次预测,之后注意批次预测的代码怎么设计
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = GoogLeNet(num_classes=5, aux_logits=False).to(device)

    # load model weights
    weights_path = "./googleNet.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    missing_keys, unexpected_keys = model.load_state_dict(torch.load(weights_path, map_location=device),
                                                          strict=False)

    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1317110.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

智能电气柜环境监测系统

智能电气柜环境监控系统是一种基于传感器技术和物联网技术的智能化监控系统,用于对电气柜内的环境参数进行实时监测和管理。依托智慧电力运维工具-电易云,通过安装在电气柜内的多个传感器,实时采集电气柜内的温度、湿度、氧气浓度、烟雾等关键…

windows redis 允许远程访问配置

安装好windows版本的redis,会以服务方式启动,但是不能远程访问,这个时候需要修改配置。redis安装路径下会有2个配置文件,究竟需要怎么修改才能生效呢?看下图 这里的redis服务指定了是redis.windows-service.conf文件&…

java_web_电商项目

java_web_电商项目 1.登录界面2.注册界面3. 主界面4.分页界面5.商品详情界面6. 购物车界面7.确认订单界面8.个人中心界面9.收货地址界面10.用户信息界面11.用户余额充值界面12.后台首页13.后台商品增加14.后台用户增加15.用户管理16.源码分享1.登录页面的源码2.我们的主界面 1.…

Xml与Json格式在线转换器

具体请前往:在线Json转Form表单参数工具

计算机网络(四)

九、网络安全 (一)什么是网络安全? A、网络安全状况 分布式反射攻击逐渐成为拒绝攻击的重要形式 涉及重要行业和政府部门的高危漏洞事件增多。 基础应用和通用软硬件漏洞风险凸显(“心脏出血”,“破壳”等&#x…

springMVC-@RequestMapping

基本介绍 RequestMapping注解可以指定控制器/处理器的某个方法的请求的url, 示例 (结合springMVC基本原理理解) Controller public class UserHandler {RequestMapping(value "/login")public String login() {System.out.println("登…

JOSEF约瑟 静态双位置继电器 DPR-35 DC110V柜内固定安装,板前接线

系列型号: DPR-20双位置继电器;DPR-31双位置继电器; DPR-32双位置继电器;DPR-33双位置继电器; DPR-34双位置继电器;DPR-35双位置继电器; DPR-11双位置继电器;DPR-12双位置继电器…

【数据结构和算法】--队列的特殊结构-循环队列

目录 循环队列的结构循环队列的实现循环队列的创建循环队列为空判断循环队列为满判断入队出队返回循环队列首元素返回循环队列尾元素释放循环队列 循环队列的结构 循环队列是队列的一种特殊结构,它的长度是固定的k,同样是先进先出,理论结构是…

飞天使-docker知识点6-容器dockerfile各项名词解释

文章目录 docker的小技巧dockerfile容器为什么会出现启动了不暂停查看docker 网桥相关信息 docker 数据卷 docker的小技巧 [rootlight-test playbook-vars[]# docker inspect -f "{{.NetworkSettings.IPAddress}}" d3a9ae03ae5f 172.17.0.4docker d3a9ae03ae5f:/etc…

RK3399平台开发系列讲解(内核入门篇)什么是函数调用栈

🚀返回专栏总目录 文章目录 一、什么是函数调用栈二、函数调用栈解析三、什么是stack overflow沉淀、分享、成长,让自己和他人都能有所收获!😄 📢在开发软件的过程中我们经常会遇到错误,如果你用 Google 搜过出错信息,那你多少应该都访问过Stack Overflow这个网站。作…

三、JS逆向

一、JS逆向 解释:在我们爬虫的过程中经常会遇到参数被加密的情况,这样只有先在前端搞清楚加密参数是怎么生成的才能继续我们的爬虫,而且此时我们还需要用python去执行这个加密的过程。本文主要讲怎么在浏览器调试JS,以及Python执…

网络(七)路由协议以及相关配置

目录 一、路由器的工作原理 二、路由表的形成 2.1 直连网段 2.2 非直连网 2.3 路由表解析 2.3.1 查看路由表 2.3.2 解析 三、静态路由和默认路由 1. 静态路由 1.1 定义 1.2 特点 2. 默认路由 2.1 定义 2.2 特点 四、静态路由和默认路由的配置 1. 静态路由配置…

很抱歉,Midjourney,但Leonardo AI的图像指导暂时还无人能及…至少目前是这样

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

汽车充电协议OpenV2G的平替cbexigen!!

纵所周知,开源欧规协议 CCS 的 OpenV2G 协议不支持 ISO15118-20:2022 协议,并且软件维护者已经明确不在进行该软件的维护。 前几天在 Github 上冲浪发现了一个宝藏开源项目,完美的实现的 OpenV2G 的 Exidizer 工具的功能:cbexigen…

文章解读与仿真程序复现思路——电网技术EI\CSCD\北大核心《考虑灵活性资源传输精细化建模的配电网优化运行》

这个标题表达的是关于配电网优化运行的一个概念,其中考虑了灵活性资源传输的精细化建模。让我们逐个解读关键词: 考虑灵活性资源传输:这指的是在配电网优化运行中考虑到不同类型的灵活性资源的传输。灵活性资源包括可再生能源、储能系统、柔性…

2044回文字符串(C语言)

目录 一:题目 二:思路分析 1.什么是回文? 2.判断回文: 三:代码 一:题目 二:思路分析 1.什么是回文? 最简单的理解方式就是一个字符串正着写和倒着写一样 2.判断回文&#xff1…

【C语言(十)】

字符函数和字符串函数 一、字符分类函数 C语言中有⼀系列的函数是专门做字符分类的,也就是⼀个字符是属于什么类型的字符的。这些函数的使用都需要包含⼀个头文件是 ctype.h 这些函数的使用方法非常类似,我们就讲解⼀个函数的事情,其他的非…

代码随想录-刷题第二十九天

491. 递增子序列 题目连接:491. 递增子序列 思路:将问题抽象为树形结构,使用回溯法。本题求自增子序列,是不能对原数组进行排序的,排完序的数组都是自增子序列了。所以不能使用之前的去重逻辑! HashSet是…

【人工智能】实验二: 洗衣机模糊推理系统实验与基础知识

实验二: 洗衣机模糊推理系统实验 实验目的 理解模糊逻辑推理的原理及特点,熟练应用模糊推理。 实验内容 设计洗衣机洗涤时间的模糊控制。 实验要求 已知人的操作经验为: “污泥越多,油脂越多,洗涤时间越长”;“…

DDD挤水分和强行加异性为好友-UMLChina建模知识竞赛第4赛季第25轮

DDD领域驱动设计批评文集 做强化自测题获得“软件方法建模师”称号 《软件方法》各章合集 参考潘加宇在《软件方法》和UMLChina公众号文章中发表的内容作答。在本文下留言回答。 只要最先答对前3题,即可获得本轮优胜。第4题为附加题,对错不影响优胜者…