第五章 ResNeXt网络详解

news2025/1/22 17:51:43

系列文章目录

第一章 AlexNet网络详解

第二章 VGG网络详解

第三章 GoogLeNet网络详解 

第四章 ResNet网络详解 

第五章 ResNeXt网络详解 

第六章 MobileNetv1网络详解 

第七章 MobileNetv2网络详解 

第八章 MobileNetv3网络详解 

第九章 ShuffleNetv1网络详解 

第十章 ShuffleNetv2网络详解 

第十一章 EfficientNetv1网络详解 

第十二章 EfficientNetv2网络详解 

第十三章 Transformer注意力机制

第十四章 Vision Transformer网络详解 

第十五章 Swin-Transformer网络详解 

第十六章 ConvNeXt网络详解 

第十七章 RepVGG网络详解 

第十八章 MobileViT网络详解 


文章目录

  • ResNeXt网络详解
  • 0. 前言
  • 1. 摘要
  • 2. ResNeXt网络详解网络架构
    • 1. ResNeXt_Model.py(pytorch实现)
    • 2.
  • 总结


0、前言


1、摘要

      我们提出了一种简单、高度模块化的图像分类网络架构。我们的网络由重复的构建块构建,这些构建块聚集了一组具有相同拓扑的变换。我们的简单设计导致了一个具有少量超参数的同质、多分支架构。这种策略展示了一个新的维度,我们称之为“基数”(变换集的大小),它是除了深度和宽度之外的一种关键因素。在ImageNet-1K数据集上,我们经验证明,即使在保持复杂性的限制条件下,增加基数也能提高分类准确性。此外,当我们增加容量时,增加基数比增加深度或宽度更有效。我们的模型名为ResNeXt,是我们参加ILSVRC 2016分类任务的基础,我们获得了第二名。我们进一步对ResNeXt进行了ImageNet-5K集和COCO检测集的研究,结果比ResNet更好。代码和模型可在网上公开获取。

2、ResNeXt网络结构

1.本文介绍了一个高度模块化的图像分类网络结构,名为ResNeXt,通过增加变换的张量大小(cardinality)提高准确率。

2.本文研究的背景是图像分类网络结构设计和性能优化。

3.本文的主要论点是增加变换的张量大小可以提高图像分类网络的准确率。

4.以往的研究主要集中在增加网络深度或宽度来提高性能,但这样会增加计算复杂度和运算时间。本文提出的方法是增加变换的张量大小,这样可以在保持网络复杂度不变的前提下提高分类准确率。

5.本文的研究方法是构建一个多分支的图像分类网络结构,通过增加变换的张量大小来提高准确率。实验数据来自ImageNet-1K数据集、ImageNet-5K数据集和目标检测数据集COCO。

6.本文的发现是增加变换的张量大小是提高图像分类网络性能的一种有效方法,但由于实验数据集有限,该方法是否适用于其他数据集需要进一步研究。

1.ResNeXt_Model.py(pytorch实现)

import torch.nn as nn
import torch

class AlexNet(nn.Module):
    def __init__(self,num_classes=1000,init_weights=False):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(48, 128, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(128, 192, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2)
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128*6*6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes)
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

2.train.py

import os
import sys
import json
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm
from model import AlexNet

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}
    data_root = 'D:/100_DataSets/'
    image_path = os.path.join(data_root, "03_flower_data")
    assert os.path.exists(image_path), "{} path does not exits.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),transform = data_transform['train'])
    train_num = len(train_dataset)
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)
    batch_size = 6
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloder workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=nw)
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform['val'])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                 batch_size=4,
                                                 shuffle=False,
                                                 num_workers=nw)
    print("using {} image for train, {} images for validation.".format(train_num, val_num))
    net = AlexNet(num_classes=5, init_weights=True)
    net.to(device)
    loss_fuction = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0002)
    epochs = 10
    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_fuction(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:,.3f}".format(epoch+1, epochs, loss)
        net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        print('[epoch % d] train_loss: %.3f val_accuracy: %.3f' %
              (epoch+1, running_loss / train_steps, val_accurate))
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(),save_path)
    print("Finished Training")

if __name__ == '__main__':
    main()






3.predict.py

import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model import AlexNet

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))]
    )
    img_path = "D:/20_Models/01_AlexNet_pytorch/image_predict/tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    img = data_transform(img)
    img = torch.unsqueeze(img, dim=0)
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
    with open(json_path,"r") as f:
        class_indict = json.load(f)
    model = AlexNet(num_classes=5).to(device)
    weights_path = "./AlexNet.pth"
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path))
    model.eval()
    with torch.no_grad():
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()
    print_res = "class: {} prob: {:.3f}".format(class_indict[str(predict_cla)],
                                                predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10} prob: {:.3}".format(class_indict[str(i)],predict[i].numpy()))
    plt.show()

if __name__ == '__main__':
    main()

4.predict.py

import os
from shutil import copy, rmtree
import random

def mk_file(file_path: str):
    if os.path.exists(file_path):
        rmtree(file_path)
    os.makedirs(file_path)

def main():
    random.seed(0)
    split_rate = 0.1
    #cwd = os.getcwd()
    #data_root = os.path.join(cwd, "flower_data")
    data_root = 'D:/100_DataSets/03_flower_data'
    origin_flower_path = os.path.join(data_root, "flower_photos")
    assert os.path.exists(origin_flower_path), "path '{}' does not exist".format(origin_flower_path)
    flower_class = [cla for cla in os.listdir(origin_flower_path) if os.path.isdir(os.path.join(origin_flower_path, cla))]
    train_root = os.path.join(data_root,"train")
    mk_file(train_root)
    for cla in flower_class:
        mk_file(os.path.join(train_root, cla))
    val_root = os.path.join(data_root, "val")
    mk_file(val_root)
    for cla in flower_class:
        mk_file(os.path.join(val_root,cla))
    for cla in flower_class:
        cla_path = os.path.join(origin_flower_path,cla)
        images = os.listdir(cla_path)
        num = len(images)
        eval_index = random.sample(images, k=int(num*split_rate))
        for index, image in enumerate(images):
            if image in eval_index:
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(val_root, cla)
                copy(image_path, new_path)
            else:
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(train_root, cla)
                copy(image_path, new_path)
            print("\r[{}] processing [{} / {}]".format(cla, index+1, num), end="")
        print()
    print("processing done!")
    
if __name__ == "__main__":
    main()

总结

提示:这里对文章进行总结:

每天一个网络,网络的学习往往具有连贯性,新的网络往往是基于旧的网络进行不断改进。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/674916.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

网络套接字函数 | socket、bind、listen、accept、connect

欢迎关注博主 Mindtechnist 或加入【Linux C/C/Python社区】一起学习和分享Linux、C、C、Python、Matlab,机器人运动控制、多机器人协作,智能优化算法,滤波估计、多传感器信息融合,机器学习,人工智能等相关领域的知识和…

CSS3-补充-结构伪类选择器

结构伪类选择器 作用:在HTML中定位元素 优势:减少对于HTML中类的依赖,有利于保持代码整洁 场景:常用于查找某父级选择器中的子元素 选择器: 选择器 …

SAC算法小结

算法SAC 基于动态规划的贝尔曼方城如下所示: 则,基于最大熵的软贝尔曼方程可以描述为如下的形式: 可以这么理解soft贝尔曼方程,就是在原有的贝尔曼方程的基础上添加了一个熵项。 另外一个角度理解soft-贝尔曼方程: …

Vue-组件自定义事件(绑定和解绑)

组件自定义事件(绑定) 像click,change这些都是js的内置事件,我们可以直接使用,本次我们学习自己根据需求打造全新的事件,但是js内置的是给html元素用的,本次的自定义事件是给组件用的 注意:组件上也可以绑定原生DOM事件&#xf…

(十一)CSharp-LINQ(1)

一、LINQ 数据库可以通过 SQL 进行访问,但在程序中,数据要被保存在差异很大的类对象或结构中。由于没有通用的查询语言来从数据结构中获取数据。所以可以使用 LINQ 可以很轻松地查询对象集合。 LINQ 高级特性: LINQ 代表语言集成查询。LIN…

【机器学习】信息熵和信息度量

一、说明 信息熵是概率论在信息论的应用,它简洁完整,比统计方法更具有计算优势。在机器学习中经常用到信息熵概念,比如决策树、逻辑回归、EM算法等。本文初略介绍一个皮毛,更多细节等展开继续讨论。 二、关于信息熵的概念 2.1 …

尚硅谷课程vue学习(一)

目录 data两种写法el两种写法由vue管理的函数,一定不要写箭头函数,不然this指向windows实例了MVVM模型defineProperty属性数据代理v-on: v-bind:键盘事件keyup keydowncomputed计算属性监视属性watch监视属性和计算属性区别绑定class和style属性条件渲染…

cocosCreator 3.3~6 安卓热更新官方详细示例

官方的热更新虽给出了示例和源码,但是一些细节的地方和步骤还是没说清楚,导致新手包括我死活是运行不起来,热更新失败!很打击人啊。这里有必要给出新手的热更新步骤,前提是你安装了Node.js和python环境,我装…

chatgpt赋能python:更新Python所有库,避免安全漏洞和兼容性问题!

更新 Python 所有库,避免安全漏洞和兼容性问题! Python 是当今最受欢迎的编程语言之一,拥有强大而多功能的 API 和丰富的第三方库来支持开发,如 numpy、pandas、tensorflow 等等。但是,这些库不断地更新与改进&#x…

端午作业1

只要文件存在,就会有唯一对应的inode号,且相应的会存在一个struct inode结构体。在应用层通过open()打开一个设备文件,会对应产生一个inode号,通过inode号可以找到文件的inode结构体 根据inode结构体中文件…

【Dart语言解密】想要深入了解Dart语法和类型变量吗?

快来读读这篇文章吧!本文从Dart信息表示的角度出发,详细讲解了Dart的基础语法和类型变量。通过本文的学习,你将会对Dart语言有更深入的认识和理解,更好地掌握Dart的开发技巧和实践应用。快来一起解密Dart语言吧! 1 Da…

数据透视表 - 学习笔记

教程资源:数据透视表_哔哩哔哩_bilibili 目录 一、内容概括 数据操作: 案例: 二、数据操作 (一)数据清洗 (二)创建数据透视表 1、数据格式 2、显示方式 3、分组 4、修改数据源 5、…

Web 安全之 HSTS 详解和使用

HSTS(HTTP Strict Transport Security) 是一种网络安全机制,可用于防范网络攻击,例如中间人攻击和 CSRF(Cross-Site Request Forgery)等攻击。本文将详细介绍 HSTS 的工作原理、应用场景以及如何在网站中开…

【计算机视觉 | 目标检测】arxiv 计算机视觉关于分类和分割的学术速递(6月 22 日论文合集)

文章目录 一、分类相关(4篇)1.1 Annotating Ambiguous Images: General Annotation Strategy for Image Classification with Real-World Biomedical Validation on Vertebral Fracture Diagnosis1.2 Benchmark data to study the influence of pre-training on explanation pe…

无需专业知识!学会用TensorFlow 2实现天气识别的秘诀

💡《目标识别100例》使用的是Python语言、TensorFlow框架,包含了几十种CNN算法案例💎 附有 🖥 源码 ,可一键运行,避免调试烦恼🏆 课程大作业、毕业论文可直接考借鉴🎈 同时 附带各种算法原理及对应的代码教程,用户可根据自身情况快速排列组合,在不同的数据集上实…

从零开始:入门双目视觉你需要了解的知识

文章目录 前言 双目相机标定去畸变极线校正(立体校正)立体匹配深度图生成文章已经同步更新在3D视觉工坊啦,原文链接如下: 前言 双目立体视觉是计算机视觉中的一个重要领域,它利用两个相机拍摄同一场景的不同视角的图像…

HDLBits笔记5:Circuits.Combinational Logic.Basic gates

Wire 实现一个电路完成in和out的连线 module top_module (input in, output out);assign out in; endmoduleGND 实现一个电路将out连到GND module top_module (output out);assign out 1b0; endmoduleNOR 实现或非门 module top_module (input in1,input in2,output ou…

Vue-消息订阅与发布(pub/sub)

消息订阅与发布(pub/sub) 消息订阅与发布和全局事件总线一样,也是一种组件间通信的方式 pub/sub全称为publisher(订阅)/subscriber(发布),一般需要数据的人订阅消息,提供数据的人发布消息 这个技术非常简单容易上手,主要有以下两步 1 订阅…

Java集合之ArrayList详解

Java集合之ArrayList 一、ArrayList类的继承关系1. 基类功能说明1.1. Iterator:提供了一种方便、安全、高效的遍历方式。1.2. Collection:为了使ArrayList具有集合的基本特性和操作。1.3. AbstractCollection:提供了一些通用的集合操作。1.4.…

Vue-动画效果

vue动画效果 vue中动画效果是很简单的一个东西,vue帮助我们做了一些动画封装,同时也支持自定义动画,过度,第三方库,这些方式都可以实现,我们一一举例说明 注意:下面的相关截图,由于…