第四章 ResNet网络详解

news2024/12/29 19:07:31

系列文章目录

第一章 AlexNet网络详解

第二章 VGG网络详解

第三章 GoogLeNet网络详解 

第四章 ResNet网络详解 

第五章 ResNeXt网络详解 

第六章 MobileNetv1网络详解 

第七章 MobileNetv2网络详解 

第八章 MobileNetv3网络详解 

第九章 ShuffleNetv1网络详解 

第十章 ShuffleNetv2网络详解 

第十一章 EfficientNetv1网络详解 

第十二章 EfficientNetv2网络详解 

第十三章 Transformer注意力机制

第十四章 Vision Transformer网络详解 

第十五章 Swin-Transformer网络详解 

第十六章 ConvNeXt网络详解 

第十七章 RepVGG网络详解 

第十八章 MobileViT网络详解 

 


文章目录

  • ResNet网络详解
  • 0. 前言
  • 1. 摘要
  • 2. ResNet网络详解网络架构
    • 1. ResNet_Model.py(pytorch实现)
    • 2.
  • 总结


0、前言


1、摘要

       更深的神经网络更难训练。我们提出了一个残差学习框架,以便训练比以前使用的网络更深。我们将层明确重组为以层输入为参考的学习残差函数,而不是学习无关的函数。我们提供了全面的实证证据,表明这些残差网络更容易优化,并且可以通过大幅增加深度获得更高的准确性。在ImageNet数据集上,我们评估了深度达152层的残差网络,比VGG网络[40]深8倍,但复杂度仍然较低。这些残差网络的组合在ImageNet测试集上实现3.57%的错误率。这个结果赢得了ILSVRC 2015分类任务的第一名。我们还对具有100和1000层的CIFAR-10进行了分析。表示的深度对于许多视觉识别任务具有重要意义。仅仅由于我们极其深的表示,我们在COCO物体检测数据集上获得了28%的相对改进。深层残差网络是我们提交给ILSVRC和COCO 2015比赛的基础,其中我们还赢得了ImageNet检测、ImageNet定位、COCO检测和COCO分割任务的第一名。

2、ResNet网络结构

1.本文介绍了一种深度神经网络的学习框架,使得更深层次的网络更容易训练和优化。

2.本文的研究背景是深度神经网络在训练和优化方面的困难。

3.本文的主要论点是,通过引入残差学习框架,可以更轻松地训练和优化深度神经网络,并在多个视觉识别任务中获得更高的准确率。

4.过去的研究主要采用传统的前向传播方法来训练和优化神经网络,但在深度增加时很难解决梯度消失和梯度爆炸的问题。这些方法也容易导致网络过拟合和训练缓慢。

5.本文提出的方法是引入残差学习框架,重定义网络层的学习方式为学习残差函数,从而实现更深层次网络的训练和优化。

6.研究发现,在多个视觉识别任务中,通过残差学习框架训练的深度神经网络可以获得更高的准确率。该方法在2015年的多个比赛中获得了第一名的好成绩。但是,由于本文研究主要关注残差学习框架在视觉识别任务中的应用,因此在其他领域的适用性还需要进一步探讨。

ResNet网络解决了深度神经网络训练过程中的梯度消失和梯度爆炸问题。梯度消失问题是由于当神经网络过深时,反向传播算法中的梯度值会变得非常小,可能会接近于0。这使得神经网络无法更新权重,从而无法继续学习。而梯度爆炸则是相反的问题,即梯度值过大,导致权重更新过快,从而使得网络失去稳定性。 ResNet的创新点是通过引入残差连接(residual connections)解决了梯度消失和梯度爆炸问题。残差连接直接连接了网络中前一层和后一层的输出,使得后一层不仅学习新的特征,还能保留前一层学习到的特征。这种方法使得神经网络的训练变得更加稳定,深度网络也变得更易于训练。此外,ResNet还创新性地使用了1x1卷积来降低特征图的维度,从而减少了模型参数,提高了效率。

1.ResNet_Model.py(pytorch实现)

import torch.nn as nn
import torch

class AlexNet(nn.Module):
    def __init__(self,num_classes=1000,init_weights=False):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(48, 128, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(128, 192, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2)
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128*6*6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes)
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

2.train.py

import os
import sys
import json
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm
from model import AlexNet

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}
    data_root = 'D:/100_DataSets/'
    image_path = os.path.join(data_root, "03_flower_data")
    assert os.path.exists(image_path), "{} path does not exits.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),transform = data_transform['train'])
    train_num = len(train_dataset)
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)
    batch_size = 6
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloder workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=nw)
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform['val'])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                 batch_size=4,
                                                 shuffle=False,
                                                 num_workers=nw)
    print("using {} image for train, {} images for validation.".format(train_num, val_num))
    net = AlexNet(num_classes=5, init_weights=True)
    net.to(device)
    loss_fuction = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0002)
    epochs = 10
    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_fuction(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:,.3f}".format(epoch+1, epochs, loss)
        net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        print('[epoch % d] train_loss: %.3f val_accuracy: %.3f' %
              (epoch+1, running_loss / train_steps, val_accurate))
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(),save_path)
    print("Finished Training")

if __name__ == '__main__':
    main()






3.predict.py

import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model import AlexNet

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))]
    )
    img_path = "D:/20_Models/01_AlexNet_pytorch/image_predict/tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    img = data_transform(img)
    img = torch.unsqueeze(img, dim=0)
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
    with open(json_path,"r") as f:
        class_indict = json.load(f)
    model = AlexNet(num_classes=5).to(device)
    weights_path = "./AlexNet.pth"
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path))
    model.eval()
    with torch.no_grad():
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()
    print_res = "class: {} prob: {:.3f}".format(class_indict[str(predict_cla)],
                                                predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10} prob: {:.3}".format(class_indict[str(i)],predict[i].numpy()))
    plt.show()

if __name__ == '__main__':
    main()

4.predict.py

import os
from shutil import copy, rmtree
import random

def mk_file(file_path: str):
    if os.path.exists(file_path):
        rmtree(file_path)
    os.makedirs(file_path)

def main():
    random.seed(0)
    split_rate = 0.1
    #cwd = os.getcwd()
    #data_root = os.path.join(cwd, "flower_data")
    data_root = 'D:/100_DataSets/03_flower_data'
    origin_flower_path = os.path.join(data_root, "flower_photos")
    assert os.path.exists(origin_flower_path), "path '{}' does not exist".format(origin_flower_path)
    flower_class = [cla for cla in os.listdir(origin_flower_path) if os.path.isdir(os.path.join(origin_flower_path, cla))]
    train_root = os.path.join(data_root,"train")
    mk_file(train_root)
    for cla in flower_class:
        mk_file(os.path.join(train_root, cla))
    val_root = os.path.join(data_root, "val")
    mk_file(val_root)
    for cla in flower_class:
        mk_file(os.path.join(val_root,cla))
    for cla in flower_class:
        cla_path = os.path.join(origin_flower_path,cla)
        images = os.listdir(cla_path)
        num = len(images)
        eval_index = random.sample(images, k=int(num*split_rate))
        for index, image in enumerate(images):
            if image in eval_index:
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(val_root, cla)
                copy(image_path, new_path)
            else:
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(train_root, cla)
                copy(image_path, new_path)
            print("\r[{}] processing [{} / {}]".format(cla, index+1, num), end="")
        print()
    print("processing done!")
    
if __name__ == "__main__":
    main()

总结

提示:这里对文章进行总结:

每天一个网络,网络的学习往往具有连贯性,新的网络往往是基于旧的网络进行不断改进。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/672820.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

.net 软件开发模式——三层架构

三层架构是一种常用的软件开发架构模式,它将应用程序分为三个层次:表示层、业务逻辑层和数据访问层。每一层都有明确的职责和功能,分别负责用户交互、业务处理和数据存储等任务。这种架构模式的优点包括易于维护和扩展、更好的组织结构和代码…

Vue的详细安装教程,使用NVM安装是我所推荐的方式

第一步:卸载之前安装的node,安装NVM 进入网站:Releases coreybutler/nvm-windows GitHub 选择一个版本进行安装即可 安装的路径我就不用多说了,全英文路径,尽量不要安装在c盘上,计算机人都懂为什么&…

vmware17+ubuntu18.04通过qemu8.0.2启动arm64虚拟机-测试vsock

文章目录 一、环境搭建1.qemu-8.0.22.buildroot配置 3.编译工具链gcc-linaro-7.2.1下载交叉编译工具链 4.linux kernel 5.16config_kernel.sh配置内核build_kernel.sh 5.启动虚拟机(1)创建磁盘镜像文件(2)拷贝内核镜像和根文件系统…

ESP32(MicroPython)端午节项目

本程序致敬了屏幕驱动例程,依次以4种字体显示Happy Dragon Boat Festival!,并重复一次。 代码如下 from ili934xnew import ILI9341, color565 from machine import Pin, SPI import m5stack import tt14 import glcdfont import tt14 import tt24 imp…

JDK自带的构建线程池的方式之newWorkStealingPool

newWorkStealingPool和之前的几种线程池的创建方式有很大的不同,之前定长、单例、缓存、定时任务的四大线程池都是基于ThreadPoolExecutor去实现的。newWorkStealingPool则是基于ForkJoinPool的方式构建出来的。 ThreadPoolExecutor的核心特点 只有一个阻塞队列Dela…

深度学习(24)——YOLO系列(3)

深度学习(24)——YOLO系列(3) 文章目录 深度学习(24)——YOLO系列(3)1. BOF(bag of freebies)2. Mosaic data augmentation3. 数据增强4. self-adversarial-training(SAT…

深入浅出MySQL索引

索引 索引在MySQL中是举足轻重的。在添加索引后,我们在MySQL的查询上会极大的提高我们的查询效率,这也是慢查询解决办法之一。 数据结构 最初的时候MySQL中是采用二叉树进行插入数据的,这样的缺点很明显,就是树太高了&#xff…

C++QT入门

CQT 文章目录 CQT1. QT概述1.1 什么是QT1.2 QT的发展史1.3 支持的平台1.4 QT版本1.5 Qt 的下载与安装1.6 QT的优点1.7 成功案例 2. 创建QT项目2.1 使用向导创建2.2 手动创建2.3 .pro 文件2.4 设置父对象2.5 按钮设置属性2.5.1 按钮设置文本2.5.2 设置移动2.5.3 设置固定大小 2.…

[进阶]网络通信:UDP通信,一发一收、多发多收

UDP通信 特点:无连接、不可靠通信。不事先建立连接;发送端每次把要发送的数据(限制在64KB内)、接收端1P、等信息封装成一个数据包,发出去就不管了。Java提供了一个java.net.Datagramsocket类来实现UDP通信。 Datagram…

Qt/C++使用QUiLoader动态加载ui资源文件

目录 动态对话框使用场景注意事项动态对话框加载获取动态对话框的控件对象与动态对话框建立关联动态修改ui资源文件效果测试 动态对话框 动态对话框(dynamic dialog)就是在程序运行时使用的从Qt设计师的.ui文件创建而来的那些对话框。动态对话框不需要通过uic把 .ui文件转换成…

【ARM裸机编程 | 海思SS528】- 操作 GPIO 寄存器输出低电平点亮 LED 灯

目录 一、概述二、看原理图,找LED灯的GPIO管脚三、使能 GPIO 管脚功能四、配置 GPIO 管脚为输出五、设置 GPIO 管脚输出高、低电平 一、概述 这篇文章主要介绍在 海思SS528 开发板,去操作某个 GPIO 寄存器输出高、低电平,来熄灭或点亮 LED 灯…

Spring框架中,什么是控制反转?什么是依赖注入?使用控制反转与依赖注入有什么优点

目录 一、Spring 二、控制反转 三、依赖注入 四、控制反转与依赖注入有什么优点 一、Spring Spring框架是一款开源的Java应用程序框架,它为企业级Java应用提供了全面的基础设施支持和编程模型。通过Spring框架,开发人员可以快速地搭建出高效、可维护…

C++【红黑树】

✨个人主页: 北 海 🎉所属专栏: C修行之路 🎃操作环境: Visual Studio 2019 版本 16.11.17 文章目录 🌇前言🏙️正文1、认识红黑树1.1、红黑树的定义1.2、红黑树的性质1.3、红黑树的特点 2、红黑…

chatgpt赋能python:Python中是否有局部变量?-完全解析

Python 中是否有局部变量?- 完全解析 Python 是一种高级编程语言,它因其易学、可读性强、开发速度快、功能丰富、能够快速交互、具有跨平台特性等方面而备受欢迎。其中一块关键功能是变量,变量可以存储值,以供稍后使用&#xff0…

[保姆级啰嗦教程] Tesseract OCR 5在Windows 10下编译安装及测试 (亲测成功)

作为一个优秀的文字识别(OCR)库,Tesseract最早并非开源软件,它是HP实验室在1985-1994年开发的专属软件,直到2005年,HP及内华达大学拉斯维加斯分校以开源的形式发布,然后由Google从2006年开始赞助…

[SpringBoot 分布式调度elasticjob 整合 ]

目录 🥫前言: 🥫配置作业 🥫实现任务处理类 🥫启动SpringBoot应用程序 🥬下面是代码是我另一个文章看见 记录的笔记, 我前面也使用了elastic-job做重试机制,有兴趣可以看一下 🥬依赖: 🥬…

基于MATLAB的CFAR检测仿真程序分享

基于MATLAB的CFAR检测仿真,得到平均CFAR检测。 完整程序: clc; clear; close all; warning off; addpath(genpath(pwd)); cfar phased.CFARDetector(NumTrainingCells,200,NumGuardCells,50,Method,CA); % Expected probability of False Alarm (no u…

【瑞萨RA_FSP】CTSU——电容按键检测

文章目录 一、1. 电容按键介绍二、电容按键原理三、瑞萨QE在电容按键上面的运用四、电容按键实验1. 硬件设计2. FSP配置3.复制文件4.主函数 一、1. 电容按键介绍 电容式感应触摸按键可以穿透绝缘材料外壳 8mm (玻璃、塑料等等)以上,准确无误…

OpenStack(2)--项目(租户)、用户、角色

一、项目(租户)、用户、角色的关系 重点理解项目(project/租户)、用户(user)、角色(role)三者之间的关系,首先这三者都可以单独新建,但是绑定关系是通过open…

10 分钟玩转Elastcisearch——数据可视化分析

在当今这个快速发展的科技时代,Elasticsearch 已经成为企业和开发者的重要技术工具。随着数据的爆发式增长,Elasticsearch 可以帮助个人和企业更好的理解数据、发现数据中的规律趋势和模式、并从海量数据中洞察业务价值。 为了帮助开发者能够快速上手&am…