RestNet详解及在pytorh下基于CIFAR10数据集的实现

news2024/11/22 22:32:21

1 RestNet介绍

        RestNet是2015年由微软团队提出的,在当时获得分类任务,目标检测,图像分割第一名。该论文的四位作者何恺明、张祥雨、任少卿和孙剑如今在人工智能领域里都是响当当的名字,当时他们都是微软亚研的一员。实验结果显示,残差网络更容易优化,并且加深网络层数有助于提高正确率。在ImageNet上使用152层的残差网络(VGG net的8倍深度,但残差网络复杂度更低)。对这些网络使用集成方法实现了3.75%的错误率。获得了ILSVRC 2015竞赛的第一名。

        论文地址:https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/He_Deep_Residual_Learning_CVPR_2016_paper.pdf        这是一篇计算机视觉领域的经典论文。李沐曾经说过,假设你在使用卷积神经网络,有一半的可能性就是在使用 ResNet 或它的变种。https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/He_Deep_Residual_Learning_CVPR_2016_paper.pdfResNet 论文被引用数量突破了 10 万+。

2 RestNet网络结构

        ResNet的经典网络结构有:ResNet-18、ResNet-34、ResNet-50、ResNet-101、ResNet-152几种,其中,ResNet-18和ResNet-34的基本结构相同,属于相对浅层的网络,后面3种属于更深层的网络,其中RestNet50最为常用。

        残差网络是为了解决深度神经网络(DNN)隐藏层过多时的网络退化问题而提出。退化(degradation)问题是指:当网络隐藏层变多时,网络的准确度达到饱和然后急剧退化,而且这个退化不是由于过拟合引起的。

        假设一个网络 A,训练误差为 x。在 A 的顶部添加几个层构建网络 B,这些层的参数对于 A 的输出没有影响,我们称这些层为 C。这意味着新网络 B 的训练误差也是 x。网络 B 的训练误差不应高于 A,如果出现 B 的训练误差高于 A 的情况,则使用添加的层 C 学习恒等映射(对输入没有影响)并不是一个平凡问题。

        为了解决这个问题,上图中的模块在输入和输出之间添加了一个直连路径,以直接执行映射。这时,C 只需要学习已有的输入特征就可以了。由于 C 只学习残差,该模块叫作残差模块。

        此外,和当年几乎同时推出的 GoogLeNet 类似,它也在分类层之后连接了一个全局平均池化层。通过这些变化,ResNet 可以学习 152 个层的深层网络。它可以获得比 VGGNet 和 GoogLeNet 更高的准确率,同时计算效率比 VGGNet 更高。ResNet-152 可以取得 95.51% 的 top-5 准确率。

         RestNet18和RestNet50网络结构如下:

 3 基于pytorch在CIFAR10数据下的RestNet50的实现

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, utils
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from torchvision.transforms import transforms
import torch.nn.functional as F
import datetime


transform = transforms.Compose([ToTensor(),
                                transforms.Normalize(
                                    mean=[0.5, 0.5, 0.5],
                                    std=[0.5, 0.5, 0.5]
                                ),
                                transforms.Resize((224, 224))
                                ])

training_data = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=transform,
)

testing_data = datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=transform,
)


class Bottleneck(nn.Module):
    def __init__(self, in_channels, out_channels, stride=[1, 1, 1], padding=[0, 1, 0], first=False) -> None:
        super(Bottleneck, self).__init__()
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride[0], padding=padding[0], bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),  # 原地替换 节省内存开销
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride[1], padding=padding[1], bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),  # 原地替换 节省内存开销
            nn.Conv2d(out_channels, out_channels * 4, kernel_size=1, stride=stride[2], padding=padding[2], bias=False),
            nn.BatchNorm2d(out_channels * 4)
        )

        # shortcut 部分
        # 由于存在维度不一致的情况 所以分情况
        self.shortcut = nn.Sequential()
        if first:
            self.shortcut = nn.Sequential(
                # 卷积核为1 进行升降维
                # 注意跳变时 都是stride==2的时候 也就是每次输出信道升维的时候
                nn.Conv2d(in_channels, out_channels * 4, kernel_size=1, stride=stride[1], bias=False),
                nn.BatchNorm2d(out_channels * 4)
            )

    def forward(self, x):
        out = self.bottleneck(x)
        out += self.shortcut(x)
        out = F.relu(out)
        return out


# 采用bn的网络中,卷积层的输出并不加偏置
class ResNet50(nn.Module):
    def __init__(self, Bottleneck, num_classes=10) -> None:
        super(ResNet50, self).__init__()
        self.in_channels = 64
        # 第一层作为单独的 因为没有残差快
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        # conv2
        self.conv2 = self._make_layer(Bottleneck, 64, [[1, 1, 1]] * 3, [[0, 1, 0]] * 3)

        # conv3
        self.conv3 = self._make_layer(Bottleneck, 128, [[1, 2, 1]] + [[1, 1, 1]] * 3, [[0, 1, 0]] * 4)

        # conv4
        self.conv4 = self._make_layer(Bottleneck, 256, [[1, 2, 1]] + [[1, 1, 1]] * 5, [[0, 1, 0]] * 6)

        # conv5
        self.conv5 = self._make_layer(Bottleneck, 512, [[1, 2, 1]] + [[1, 1, 1]] * 2, [[0, 1, 0]] * 3)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(2048, num_classes)

    def _make_layer(self, block, out_channels, strides, paddings):
        layers = []
        # 用来判断是否为每个block层的第一层
        flag = True
        for i in range(0, len(strides)):
            layers.append(block(self.in_channels, out_channels, strides[i], paddings[i], first=flag))
            flag = False
            self.in_channels = out_channels * 4

        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)

        out = self.avgpool(out)
        out = out.reshape(x.shape[0], -1)
        out = self.fc(out)
        return out


if __name__ == "__main__":
    res50 = ResNet50(Bottleneck)

    batch_size = 64
    train_data = DataLoader(dataset=training_data, batch_size=batch_size, shuffle=True, drop_last=True)
    test_data = DataLoader(dataset=testing_data, batch_size=batch_size, shuffle=True, drop_last=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = res50.to(device)
    cost = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters())

    epochs = 10
    for epoch in range(epochs):
        running_loss = 0.0
        running_correct = 0.0
        model.train()
        print(f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}, Epoch {epoch}/{epochs}")
        for X_train, y_train in train_data:
            X_train, y_train = X_train.to(device), y_train.to(device)
            outputs = model(X_train)
            _, pred = torch.max(outputs.data, 1)
            optimizer.zero_grad()
            loss = cost(outputs, y_train)

            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            running_correct += torch.sum(pred == y_train.data)

        testing_correct = 0
        test_loss = 0
        model.eval()
        for X_test, y_test in test_data:
            X_test, y_test = X_test.to(device), y_test.to(device)
            outputs = model(X_test)
            loss = cost(outputs, y_test)
            _, pred = torch.max(outputs.data, 1)
            testing_correct += torch.sum(pred == y_test.data)
            test_loss += loss.item()

        print("{}, Train Loss is:{:.4f}, Train Accuracy is:{:.4f}%, Test Loss is::{:.4f} Test Accuracy is:{:.4f}%".format(
            datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            running_loss / len(training_data),
            100 * running_correct / len(training_data),
            test_loss / len(testing_data),
            100 * testing_correct / len(testing_data)
        ))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/540715.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

STM32单片机蓝牙APP自动伸缩遮阳棚雨伞雨滴角度温度光强控制

实践制作DIY- GC0130-蓝牙APP自动伸缩遮阳棚 一、功能说明: 基于STM32单片机设计-蓝牙APP自动伸缩遮阳棚 二、功能介绍: 基于STM32F103C系列,LCD1602显示器,光敏电阻采集光强,雨滴传感器,ULN2003控制步进…

MySQL:5.6同步到5.7 GTID报错

问题描述和处理 同步到的版本为5.7.35,按理说在5.7种还是一个比较新的版本了,报错大概如下: 2023-05-14T05:09:47.427031Z 12 [Note] Multi-threaded slave statistics for channel : seconds elapsed 163; events assigned 67585; worke…

GD32 系列FLASH锁死解决.

1.背景描述 使用keil开发工具JLINK调试过程中偶尔出现找不到目标版,或存在目标版但keil调试烧录出现如下界面: 2.问题查询步骤 2.1检查jlink连接线是否异常; 2.2确定boot0和boot1设置是否正确; 2.3确定是否是flash读写保护 2.3.1…

K8s进阶2——二进制搭建K8s高可用集群

文章目录 一、单master资源清单二、系统初始化三、部署etcd集群3.1 生成etcd证书3.2 部署流程3.2.1 准备二进制安装文件3.2.2 创建工作目录3.2.3 创建etcd配置文件3.2.4 设置成systemd服务3.2.5 添加etcd-2和etcd-3节点3.2.6 所有节点启动etcd并设置开机启动 四、安装容器引擎&…

webpack基础

1. 当面试官问Webpack的时候他想知道什么 前言 在前端工程化日趋复杂的今天,模块打包工具在我们的开发中起到了越来越重要的作用,其中webpack就是最热门的打包工具之一。 说到webpack,可能很多小伙伴会觉得既熟悉又陌生,熟悉是…

java汽车4S店管理系统myeclipse定制开发oracle数据库网页模式java编程jdbc

一、源码特点 java汽车4S店管理系统 是一套完善的web设计系统,对理解JSP java编程开发语言有帮助 oracle数据库,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。 java汽车4S店管理系统myeclipse定制开发orac 二、功能介绍 此次系统…

漏扫工具-xray 1.9.10(文末附下载)

一、工具介绍 一款功能强大的安全评估工具 二、使用说明 1.使用基础爬虫爬取并对爬虫爬取的链接进行漏洞扫描 xray webscan --basic-crawler http://example.com --html-output vuln.html 2.使用 HTTP 代理进行被动扫描 xray webscan --listen 127.0.0.1:7777 --html-outp…

Fiddler如何抓取微信小程序的包

1.简介 有些小伙伴或者是童鞋们说小程序抓不到包,该怎么办了???其实苹果手机如果按照宏哥前边的抓取APP包的设置方式设置好了,应该可以轻松就抓到包了。那么安卓手机小程序就比较困难,不是那么友好了。所以…

FMC子卡设计资料原理图450-基于ADRV9009的双收双发射频FMC子卡 数字信号处理卡 射频收发卡 基站应用 便携测试设备

FMCJ450-基于ADRV9009的双收双发射频FMC子卡 一、板卡概述 ADRV9009是一款高集成度射频(RF)、捷变收发器,提供双通道发射器和接收器、集成式频率合成器以及数字信号处理功能。这款IC具备多样化的高性能和低功耗组合,FMC子卡为2路输入,…

MySQL高级_第08章_索引的创建与设计原则

MySQL高级_第08章_索引的创建与设计原则 1. 索引的声明与使用 1.1 索引的分类 MySQL 的索引包括普通索引、唯一性索引、全文索引、单列索引、多列索引和空间索引等。 从 功能逻辑 上说,索引主要有 4 种,分别是普通索引、唯一索引、主键索引、全文索…

新手如何重装Win10系统 新手重装Win10系统的方法

电脑系统是电脑运行的核心,如果出现问题就需要重装系统。对于新手来说,重装电脑系统可能会显得比较困难和陌生。本文将介绍新手如何重装电脑系统Win10,让电脑新手也能轻松搞定。 新手重装Win10系统的方法 一、准备工作 1、下载极客狗电脑重…

canvas、svg的基本使用【数据可视化】

什么是数据可视化? 基本概念:是关于数据视觉表现形式的科学技术研究 这个概念向我们传达了两个信息: (1)数据可视化是一门学科 (2)数据可视化与数据和视觉有关 数据可视化简单理解,…

veth网卡的多队列及RPS

背景: 3.10内核下容器使用的veth网卡,默认开启的是一个队列,导致在某些单线程多TCP链接的应用场景下,出现某个CPU软中断高的情况。之前处理的方案一直是开启这个veth网卡的RPS,让其在多流场景下可以去分散到其它CPU上…

DSSM - 双塔经典模型(微软)

《Learning Deep Structured Semantic Models for Web Search using Clickthrough Data》论文由微软发表于 CIKM-2013。DSSM被广泛用于工业界的 召回/粗排 阶段。 模型结构 模型结构一目了然,非常简单,双塔结构:user侧一个塔,ite…

ChatGPT的兴起的时代,国内chatgpt产品大盘点

在人工智能技术的不断发展和应用下,自然语言处理技术成为了研究的热点之一。而其中最受关注的就是“聊天机器人”技术,而GPT(Generative Pre-trained Transformer)模型则是目前最流行的聊天机器人生成模型之一。 随着 ChatGPT 技…

蓝牙RSSI/BLE AOA/UWB室内定位技术哪个好?

蓝牙AOA定位技术的出现,弥补了蓝牙RSSI值定位精度不高的缺陷。从理论上来说,可以对目前的蓝牙RSSI定位方案进行一定程度的替代。当然了,在高精度定位应用领域中,UWB定位已经在批量的成熟商用了。蓝牙AOA也具有很高的定位精度&…

单位网站被黑被下达整改进行行政处罚

最近这几年,由于信息系统安全等级保护法的普及,越来越多公司收到当地公安网监部门打来的电话,说你们公司网站有漏洞,需要限期在2-3内进行漏洞整改和加固,遇到这种情况,不要着急,下面来分享一下该…

JavaEE(系列8) -- 多线程案例(单例模式)

目录 1. 设计模式 2. 单例模式 -- 饿汉模式 3. 单例模式 -- 懒汉模式 4. 单例模式(懒汉模式-多线程) 1. 设计模式 什么是设计模式? 设计模式好比象棋中的 "棋谱". 红方当头炮, 黑方马来跳. 针对红方的一些走法, 黑方应招的时候有一些固定的套路. 按照套路…

【融合感知】激光雷达和相机融合感知-BEVFusion

BEVFusion有两篇文章,这里在一起分析下不同,分别是: 【1】BEVFusion: A Simple and Robust LiDAR-Camera Fusion Framework. 【2】BEVFusion: Multi-Task Multi-Sensor Fusion with Unified Bird’s-Eye View Representation 先说结论&…

品牌联名又出圈了!小红书数据揭示,引爆流量三部曲

这几天,你们的朋友圈是不是被喜茶FENDI黄刷屏啦?近日,茶饮品牌牵手意大利奢侈品牌联名上新,一跃成为各平台热门。 品牌新联名,这次又出圈了! 喜茶可谓是联名界的老玩家了,曾与藤原浩、《梦华录》…