使用ResNet50实现CIFAR100数据集的训练

news2025/1/15 6:21:37

如果对你有用的话,希望能够点赞支持一下,这样我就能有更多的动力更新更多的学习笔记了。😄😄     

        使用ResNet进行CIFAR-10数据集进行测试,这里使用的是将CIFAR-10数据集的分辨率扩大到32X32,因为算力相关的问题所以我选择了较低的训练图像分辨率。但是假如你自己的算力比较充足的话,我建议使用训练的使用图像的分辨率设置为224X224(这个可以在代码里面的transforms.RandomResizedCrop(32)和transforms.Resize((32, 32)),进行修改,很简单),因为在测试训练的时候,发现将CIFAR10数据集的分辨率拉大可以让模型更快地进行收敛,并且识别的效果也是比低分辨率的更加好。

首先来介绍一下,ResNet:

1.论文下载地址:https://arxiv.org/pdf/1512.03385.pdf 

2.ResNet的介绍:

 

代码实现:

数据集的处理:
        调用torchvision里面封装好的数据集进行数据的训练,并且利用官方已经做好的数据集分类是数据集的划分大小。进行了一些简单的数据增强,分别是随机的随机剪切和随机的水平拉伸操作。

模型的代码结构目录:

train.py文件内容: 

# -*- coding:utf-8 -*-
# @Time : 2023-01-11 20:25
# @Author : DaFuChen
# @File : CSDN写作代码笔记
# @software: PyCharm



import torchvision

from model import resnet50
import os
import parameters
import function
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from tqdm import tqdm



def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))
    epochs = parameters.epoch
    save_model = parameters.resnet_save_model
    save_path = parameters.resnet_save_path_CIFAR100


    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(32),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),

        "val": transforms.Compose([transforms.Resize((32, 32)),  # cannot 224, must (224, 224)
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        }

    train_dataset = torchvision.datasets.CIFAR100(root='./data/CIFAR100', train=True,
                                                download=True, transform=data_transform["train"])

    val_dataset = torchvision.datasets.CIFAR100(root='./data/CIFAR100', train=False,
                                           download=False, transform=data_transform["val"])


    train_num = len(train_dataset)
    val_num = len(val_dataset)
    print("using {} images for training, {} images for validation.".format(train_num, val_num))
    # #################################################################################################################

    batch_size = parameters.batch_size

    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    # ##################################################################################################################
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               )

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             )

    model = resnet50(num_classes=parameters.CIFAR100_class)
    model.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=parameters.resnet_lr)
    best_acc = 0.0

    # 为后面制作表图
    train_acc_list = []
    train_loss_list = []
    val_acc_list = []

    for epoch in range(epochs):
        # train
        model.train()
        running_loss_train = 0.0
        train_accurate = 0.0
        train_bar = tqdm(train_loader)
        for images, labels in train_bar:
            optimizer.zero_grad()

            outputs = model(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            predict = torch.max(outputs, dim=1)[1]
            train_accurate += torch.eq(predict, labels.to(device)).sum().item()
            running_loss_train += loss.item()

        train_accurate = train_accurate / train_num
        running_loss_train = running_loss_train / train_num
        train_acc_list.append(train_accurate)
        train_loss_list.append(running_loss_train)

        print('[epoch %d] train_loss: %.7f  train_accuracy: %.3f' %
              (epoch + 1, running_loss_train, train_accurate))

        # validate
        model.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_loader = tqdm(val_loader)
            for val_data in val_loader:
                val_images, val_labels = val_data
                outputs = model(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        val_acc_list.append(val_accurate)
        print('[epoch %d] val_accuracy: %.3f' %
              (epoch + 1, val_accurate))

        function.writer_into_excel_onlyval(save_path, train_loss_list, train_acc_list, val_acc_list,"CIFAR100")

        # 选择最best的模型进行保存 评价指标此处是acc
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(model.state_dict(), save_model)



if __name__ == '__main__':
    main()

model.py文件:

# -*- coding:utf-8 -*-
# @Time : 2023-01-11 20:24
# @Author : DaFuChen
# @File : CSDN写作代码笔记
# @software: PyCharm


import torch.nn as nn
import torch


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """
    注意:原论文中,在虚线残差结构的主分支上,第一个1x1卷积层的步距是2,第二个3x3卷积层步距是1。
    但在pytorch官方实现过程中是第一个1x1卷积层的步距是1,第二个3x3卷积层步距是2,
    这么做的好处是能够在top1上提升大概0.5%的准确率。
    可参考Resnet v1.5 https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch
    """
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()

        width = int(out_channel * (width_per_group / 64.)) * groups

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self,
                 block,
                 blocks_num,
                 num_classes=1000,
                 include_top=True,
                 groups=1,
                 width_per_group=64):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64

        self.groups = groups
        self.width_per_group = width_per_group

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)

        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(512 * block.expansion, 512),  # [2 512 1 1]
            nn.ReLU(inplace=True),
            # nn.Linear(512, num_classes),

        )

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel,
                            channel,
                            downsample=downsample,
                            stride=stride,
                            groups=self.groups,
                            width_per_group=self.width_per_group))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            # x = self.fc(x)
            # print((x.shape()))
            x = self.classifier(x)

        return x


class AlexnetChange(nn.Module):
    def __init__(self, ):
        super(AlexnetChange, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=13, stride=2, padding=0),                  # output[128, 1, 1]
        )

        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 1 * 1, 512),        # [batchsize值 512 1 1]
            nn.ReLU(inplace=True),
            # nn.Linear(512, num_classes),

        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)      # output[512, 1, 1]

        return x



class Classifier(nn.Module):
    def __init__(self, num_classe=1000):
        super(Classifier, self).__init__()

        self.FC = nn.Sequential(
            nn.Linear(512 * 1 * 1, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, num_classe),
        )

    def forward(self, x1=None, x2=None):
        if x1 != None and x2 != None:
            x = x1.add(x2)
            x = self.FC(x)
            # print("x1 add x2  ")
        elif x1 != None and x2 == None:
            x = self.FC(x1)
            # print("only x1  ")
        elif x1 == None and x2 != None:
            x = self.FC(x2)
            # print("only x2  ")
        else:
            print("Alexnet_Con has wrong")

        return x



def resnet18(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, include_top=include_top)

def resnet34(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet50(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet50-19c8e357.pth
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)

function.py文件:

# -*- coding:utf-8 -*-
# @Time : 2023-01-11 20:25
# @Author : DaFuChen
# @File : CSDN写作代码笔记
# @software: PyCharm


import xlwt



def writer_into_excel_onlyval(excel_path,loss_train_list, acc_train_list, val_acc_list,dataset_name:str=""):
    workbook = xlwt.Workbook(encoding='utf-8')  # 设置一个workbook,其编码是utf-8
    worksheet = workbook.add_sheet("sheet1", cell_overwrite_ok=True)  # 新增一个sheet
    worksheet.write(0, 0, label='Train_loss')
    worksheet.write(0, 1, label='Train_acc')
    worksheet.write(0, 2, label='Val_acc')


    for i in range(len(loss_train_list)):  # 循环将a和b列表的数据插入至excel
        worksheet.write(i + 1, 0, label=loss_train_list[i])  # 切片的原来是传进来的Imgs是一个路径的信息
        worksheet.write(i + 1, 1, label=acc_train_list[i])
        worksheet.write(i + 1, 2, label=val_acc_list[i])


    workbook.save(excel_path + str(dataset_name) +".xls")  # 这里save需要特别注意,文件格式只能是xls,不能是xlsx,不然会报错
    print('save success!   .')



parameters.py文件:

# -*- coding:utf-8 -*-
# @Time : 2023-01-11 20:25
# @Author : DaFuChen
# @File : CSDN写作代码笔记
# @software: PyCharm




# 训练的次数
epoch = 2

# 训练的批次大小
batch_size = 4

# 数据集的分类类别数量
CIFAR100_class = 100

# 模型训练时候的学习率大小
resnet_lr = 0.002

# 保存模型权重的路径 保存xml文件的路径
resnet_save_path_CIFAR100 = './res/'
resnet_save_model = './res/best_model.pth'

其中部分参数,例如是学习率的大小,训练的批次大小,数据增强的一些小参数,可以根据自己的经验和算力的现实情况进行调整。

如果对你有用的话,希望能够点赞支持一下,这样我就能有更多的动力更新更多的学习笔记了。😄😄

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/156249.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

蓝桥杯2019年第十届省赛C++B组

文章目录A:组队(5分 √)B:年号字串(5分 √)C: 数列求值(10分 √)D: 数的分解(10分 )F: 特别数的和(15分 √)A:组队&#x…

【蓝桥杯简单篇】Python组刷题日寄Part05

刷题日记?刷题日寄! 萌新备战蓝桥杯python组 🌹 发现有需要纠正的地方,烦请指正! 🚀 欢迎小伙伴们的三连关注! 往期系列: 【蓝桥杯简单篇】Python组刷题日寄Part01 【蓝桥杯简单篇】…

Oracle Mysql审计日志等保测评

mysql oracle的审计日志 mysql的审计日志说是有两种方法,一种是需要安装插件模式的审计,一种直接开一个参数 1.安装插件模式 一、 mysql 日志 配置永久配置 : 保存时间及大小 https://blog.csdn.net/m0_51197424/article/details/12432840…

设计模式(二)----软件设计原则

在软件开发中,为了提高软件系统的可维护性和可复用性,增加软件的可扩展性和灵活性,要尽量根据7条原则来开发程序,从而提高软件开发效率、节约软件开发成本和维护成本。 1、单一职责原则 ( 核心:尽量保证类&#xff0…

毫米波雷达和视觉融合的学习路线

了解各个传感器的成像原理,知其所以然,同时了解每种传感器的对比及优缺点,为什么要用这几种融合,可以通过去看一些德州仪器的雷达原理视频(b站),雷达工作手册等。先广而后深:了解经典…

C#-WPF介绍-创建项目-添加按钮等及其事件处理、属性设置

微软官网指导链接:适用于 .NET 5 的 Windows Presentation Foundation 文档 | Microsoft Learn WPF框架介绍:Windows Presentation Foundation 简介 - WPF .NET | Microsoft Learn WPF介绍 WPF(Windows Presentation Foundation&#xff09…

python3在window上运行或安装模块各种问题

1. 在window上运行celery各种奇怪的问题 如出现错误: ValueError: not enough values to unpack (expected 3, got 0) 请先安装如下模块 pip install eventlet 启动时,带上改模块,指定为运行参数 celery -A tasks worker --loglevelinfo -P …

基于小程序云开发的智慧物业、智慧小区微信小程序,在线报修报检,重大事项投票,报名参加小区活动,小区公告通知,业委会公示、租售房屋

功能介绍 完整代码下载地址:基于小程序云开发的智慧物业、智慧小区微信小程序 当前小区的物业事务越来越多、越来越杂,而很多业主工作繁重,加班很晚,以往对于重大事项投票,报修报检,装修申请,…

大数据技术之SparkSQL(超级详细)

第1章 Spark SQL概述 1.1什么是Spark SQL Spark SQL是Spark用来处理结构化数据的一个模块,它提供了2个编程抽象:DataFrame和DataSet,并且作为分布式SQL查询引擎的作用。 它是将Hive SQL转换成MapReduce然后提交到集群上执行,大大…

易盾滑块再再再试

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录前言文章推荐自己的坑js部分效果展示前言 声明:本文只作学习研究,禁止用于非法用途,否则后果自负,如有侵权&#xff…

php宝塔搭建部署实战响应式自动化设备科技企业网站源码

大家好啊,我是测评君,欢迎来到web测评。 本期给大家带来一套php开发的响应式自动化设备科技企业网站源码,感兴趣的朋友可以自行下载学习。 技术架构 PHP7.2 nginx mysql5.7 JS CSS HTMLcnetos7以上 宝塔面板 文字搭建教程 下载源码&…

我眼中的偶数数据库 OushuDB

各位大家好,在论坛跟大家学习也有一段时间了,今天来聊聊我眼中的偶数数据库 ~ 首先,先来介绍介绍我和偶数的故事(其实没有什么故事,只是一些交集片段)。 2015 年我开始接触 Greenplum&#xf…

Spring Boot 配置文件

Spring Boot 配置文件一、配置文件作用二、配置文件的格式三、properties 配置文件说明3.1 properties 基本语法3.2 读取配置文件3.3 properties 优缺点分析四、yml 配置文件说明4.1 yml 优点分析4.2 yml 基本语法4.3 yml 基本配置读取4.4 配置对象与读取4.5 配置集合与读取五、…

时序预测 | Python实现XGBoost极限梯度提升树股票价格预测

时序预测 | Python实现XGBoost极限梯度提升树股票价格预测 目录 时序预测 | Python实现XGBoost极限梯度提升树股票价格预测预测效果基本描述环境配置模型描述程序设计参考资料预测效果 基本描述 Python实现XGBoost极限梯度提升树股票价格预测 环境配置 XGboost (0.7) numpy (1.…

负载均衡器Ribbon原理及实战演练

目录 一、负载均衡原理 二、Ribbon 原理及使用 三、Loadbalancer 原理及使用 负载均衡器Ribbon在微服务领域是很常用的服务调用、负载均衡的中间件,其面包含Loadbalancer专门负载负载均衡;比如Eureka、Fegin,Nacos的注册中心jar包里面均包含Ribbon相关的jar,如图…

python实战案例:采集某漫客《网游之近战法师》所有章节

前言 嗨喽~大家好呀,这里是魔王呐 ❤ ~! 环境使用: Python 3.8 Pycharm 模块使用: requests >>> pip install requests 数据请求模块 parsel >>> pip install parsel 数据解析模块 如果安装python第三方模块: win R 输入 cmd 点击确定, 输…

C++学习之旅 第五章(字符应用:小写字母转大写字母)

开头: 上一节我们讲了关于char类型许多知识,今天我们来更深层的学习一下字符实际上面的应用! ASCII码简介: 我们要进行字符的应用,首先就是要了解一下ACSII码: ASCII(全名:American Standard Code for Inf…

win10如何安装多个jdk并实时进行切换【建议收藏】

在windows10的系统中,如何安装jdk或者安装多个jdk版本,博主在这里整理了一份非常完美的jdk版本安装教程,且jdk版本可以随时切换,切换过程不超过10秒,让你在jdk版本中穿梭自如,直接可以食用,掌握…

【前端】ES6

let 和 const 类似var定义变量&#xff0c;但是let修饰的变量仅在声明的代码块中有效&#xff1b; var声明的变量&#xff0c;在全局有效 for (let i 0; i < 3; i) {let i abc;console.log(i); }js中的for循环声明循环变量的部分也作为一个父作用域&#xff0c;即(let i…

虹科方案|数据中心虚拟化和 HK-ATTO 产品—旨在协同工作的端到端解决方案

一、概述虚拟化技术正迅速成为现代数据中心的基础&#xff0c;因为IT 管理寻求显着提 高资源和运营效率以及对业务需求的响应能力。三个关键技术非常重要&#xff1a;服务器虚拟化、结构虚拟化和存储虚拟化。 本次介绍了 HK-ATTO 产品如何作为这些虚拟化解决方案中 每一个的关键…