使用ResNet50实现CIFAR10数据集的训练

news2025/1/12 16:06:27

  如果对你有用的话,希望能够点赞支持一下,这样我就能有更多的动力更新更多的学习笔记了。😄😄     

        使用ResNet进行CIFAR-10数据集进行测试,这里使用的是将CIFAR-10数据集的分辨率扩大到32X32,因为算力相关的问题所以我选择了较低的训练图像分辨率。但是假如你自己的算力比较充足的话,我建议使用训练的使用图像的分辨率设置为224X224(这个可以在代码里面的transforms.RandomResizedCrop(32)和transforms.Resize((32, 32)),进行修改,很简单),因为在测试训练的时候,发现将CIFAR10数据集的分辨率拉大可以让模型更快地进行收敛,并且识别的效果也是比低分辨率的更加好。

首先来介绍一下,ResNet:

1.论文下载地址:https://arxiv.org/pdf/1512.03385.pdf 

2.ResNet的介绍:

 

代码实现:

数据集的处理:
        调用torchvision里面封装好的数据集进行数据的训练,并且利用官方已经做好的数据集分类是数据集的划分大小。进行了一些简单的数据增强,分别是随机的随机剪切和随机的水平拉伸操作。

模型的代码结构目录:

train.py文件内容:

# -*- coding:utf-8 -*-
# @Time : 2023-01-11 20:25
# @Author : DaFuChen
# @File : CSDN写作代码笔记
# @software: PyCharm



import torchvision

from model import resnet50
import os
import parameters
import function
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from tqdm import tqdm



def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))
    epochs = parameters.epoch
    save_model = parameters.resnet_save_model
    save_path = parameters.resnet_save_path_CIFAR10


    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(32),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),

        "val": transforms.Compose([transforms.Resize((32, 32)),  # cannot 224, must (224, 224)
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        }

    train_dataset = torchvision.datasets.CIFAR10(root='./data/CIFAR10', train=True,
                                                download=True, transform=data_transform["train"])

    val_dataset = torchvision.datasets.CIFAR10(root='./data/CIFAR10', train=False,
                                           download=False, transform=data_transform["val"])


    train_num = len(train_dataset)
    val_num = len(val_dataset)
    print("using {} images for training, {} images for validation.".format(train_num, val_num))
    # #################################################################################################################

    batch_size = parameters.batch_size

    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    # ##################################################################################################################
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               )

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             )

    model = resnet50(num_classes=parameters.CIFAR10_class)
    model.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=parameters.resnet_lr)
    best_acc = 0.0

    # 为后面制作表图
    train_acc_list = []
    train_loss_list = []
    val_acc_list = []

    for epoch in range(epochs):
        # train
        model.train()
        running_loss_train = 0.0
        train_accurate = 0.0
        train_bar = tqdm(train_loader)
        for images, labels in train_bar:
            optimizer.zero_grad()

            outputs = model(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            predict = torch.max(outputs, dim=1)[1]
            train_accurate += torch.eq(predict, labels.to(device)).sum().item()
            running_loss_train += loss.item()

        train_accurate = train_accurate / train_num
        running_loss_train = running_loss_train / train_num
        train_acc_list.append(train_accurate)
        train_loss_list.append(running_loss_train)

        print('[epoch %d] train_loss: %.7f  train_accuracy: %.3f' %
              (epoch + 1, running_loss_train, train_accurate))

        # validate
        model.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_loader = tqdm(val_loader)
            for val_data in val_loader:
                val_images, val_labels = val_data
                outputs = model(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        val_acc_list.append(val_accurate)
        print('[epoch %d] val_accuracy: %.3f' %
              (epoch + 1, val_accurate))

        function.writer_into_excel_onlyval(save_path, train_loss_list, train_acc_list, val_acc_list,"CIFAR10")

        # 选择最best的模型进行保存 评价指标此处是acc
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(model.state_dict(), save_model)



if __name__ == '__main__':
    main()

model.py文件:

# -*- coding:utf-8 -*-
# @Time : 2023-01-11 20:24
# @Author : DaFuChen
# @File : CSDN写作代码笔记
# @software: PyCharm


import torch.nn as nn
import torch


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """
    注意:原论文中,在虚线残差结构的主分支上,第一个1x1卷积层的步距是2,第二个3x3卷积层步距是1。
    但在pytorch官方实现过程中是第一个1x1卷积层的步距是1,第二个3x3卷积层步距是2,
    这么做的好处是能够在top1上提升大概0.5%的准确率。
    可参考Resnet v1.5 https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch
    """
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()

        width = int(out_channel * (width_per_group / 64.)) * groups

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self,
                 block,
                 blocks_num,
                 num_classes=1000,
                 include_top=True,
                 groups=1,
                 width_per_group=64):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64

        self.groups = groups
        self.width_per_group = width_per_group

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)

        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(512 * block.expansion, 512),  # [2 512 1 1]
            nn.ReLU(inplace=True),
            # nn.Linear(512, num_classes),

        )

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel,
                            channel,
                            downsample=downsample,
                            stride=stride,
                            groups=self.groups,
                            width_per_group=self.width_per_group))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            # x = self.fc(x)
            # print((x.shape()))
            x = self.classifier(x)

        return x


class AlexnetChange(nn.Module):
    def __init__(self, ):
        super(AlexnetChange, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=13, stride=2, padding=0),                  # output[128, 1, 1]
        )

        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 1 * 1, 512),        # [batchsize值 512 1 1]
            nn.ReLU(inplace=True),
            # nn.Linear(512, num_classes),

        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)      # output[512, 1, 1]

        return x



class Classifier(nn.Module):
    def __init__(self, num_classe=1000):
        super(Classifier, self).__init__()

        self.FC = nn.Sequential(
            nn.Linear(512 * 1 * 1, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, num_classe),
        )

    def forward(self, x1=None, x2=None):
        if x1 != None and x2 != None:
            x = x1.add(x2)
            x = self.FC(x)
            # print("x1 add x2  ")
        elif x1 != None and x2 == None:
            x = self.FC(x1)
            # print("only x1  ")
        elif x1 == None and x2 != None:
            x = self.FC(x2)
            # print("only x2  ")
        else:
            print("Alexnet_Con has wrong")

        return x



def resnet18(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, include_top=include_top)

def resnet34(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet50(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet50-19c8e357.pth
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)

function.py文件:

# -*- coding:utf-8 -*-
# @Time : 2023-01-11 20:25
# @Author : DaFuChen
# @File : CSDN写作代码笔记
# @software: PyCharm


import xlwt



def writer_into_excel_onlyval(excel_path,loss_train_list, acc_train_list, val_acc_list,dataset_name:str=""):
    workbook = xlwt.Workbook(encoding='utf-8')  # 设置一个workbook,其编码是utf-8
    worksheet = workbook.add_sheet("sheet1", cell_overwrite_ok=True)  # 新增一个sheet
    worksheet.write(0, 0, label='Train_loss')
    worksheet.write(0, 1, label='Train_acc')
    worksheet.write(0, 2, label='Val_acc')


    for i in range(len(loss_train_list)):  # 循环将a和b列表的数据插入至excel
        worksheet.write(i + 1, 0, label=loss_train_list[i])  # 切片的原来是传进来的Imgs是一个路径的信息
        worksheet.write(i + 1, 1, label=acc_train_list[i])
        worksheet.write(i + 1, 2, label=val_acc_list[i])


    workbook.save(excel_path + str(dataset_name) +".xls")  # 这里save需要特别注意,文件格式只能是xls,不能是xlsx,不然会报错
    print('save success!   .')



parameters.py文件:

# -*- coding:utf-8 -*-
# @Time : 2023-01-11 20:25
# @Author : DaFuChen
# @File : CSDN写作代码笔记
# @software: PyCharm




# 训练的次数
epoch = 2

# 训练的批次大小
batch_size = 4

# 数据集的分类类别数量
CIFAR10_class = 10

# 模型训练时候的学习率大小
resnet_lr = 0.002

# 保存模型权重的路径 保存xml文件的路径
resnet_save_path_CIFAR10 = './res/'
resnet_save_model = './res/best_model.pth'

其中部分参数,例如是学习率的大小,训练的批次大小,数据增强的一些小参数,可以根据自己的经验和算力的现实情况进行调整。

如果对你有用的话,希望能够点赞支持一下,这样我就能有更多的动力更新更多的学习笔记了。😄😄

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/156676.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C语言常用内存函数的深度解析

文章目录前言memcpymemcpy函数的使用memcpy函数的自我实现memmovememmove函数的使用memmove函数的自我实现memcmpmemcmp函数的使用memcmp函数的自我实现memsetmemset函数的使用memset函数的自我实现写在最后前言 内存函数的使用广泛度大于常用字符串函数的使用广泛度&#xff0…

教程- VTK.js的基本介绍

VTK.js的核心是标准可视化工具包(VTK)库的JavaScript移植,这是一个c库,旨在促进数据可视化,在此基础上构建了科学可视化应用程序Paraview。VTK.js没有使用OpenGL,而是利用WebGL,主要关注几何和体渲染。因此&#xff0c…

JavaFx程序使用Gloun打包成Android平台App教程

0. 提要 !!! 适合有Maven基础,对JavaFx或JavaFX移动端感兴趣的朋友 提示必须在Linux环境下进行,可以使用虚拟机 推荐使用CentOS系统进行,虚拟机硬盘大小推荐最少给30G 不要像我一样,搞一半又去给文件系统根目录扩大容量 如果容量不够可以看篇博客: http…

C++模板(第二版)笔记之第十八章:模板的多态性

文章目录一、动态多态(dynamic polymorphism)二、静态多态三、静态多态VS动态多态1.术语2.优点和缺点3.结合两种多态形式:CRTP四、使用concepts五、新形势的设计模式六、泛型编程七、总结一、动态多态(dynamic polymorphism&#…

【C语言】内存函数介绍

它们所在的头文件: (这里出现的arr都为char类型数组)strlen作用:计算一个字符串的长度本质:历经千辛找一个 \0 ,找到 \0 就立马停止。(就是找 \0 )易错:strlen 返回值为 …

物联网无线通信技术中蓝牙和WIFI有哪些区别?

在物联网快速发展的时代,联网运行的设备越来越多,无线通信技术在物联网中发挥着举足轻重的作用,无线通信技术的发展改变了信息传输的方式,人们在任何时间、任何地点都可以访问设备,目前最常用的两种无线通信技术分别是…

云服务器CentOS前后端部署流程记录

部署流程记录 购买云服务ecs服务器,建立CentOS系统 通过xftpxshell访问远程服务 doker部署(https://www.runoob.com/docker/centos-docker-install.html) docker docker部署环境(mysql) docker常用命令 1. docker i…

【Linux】进程状态与优先级

文章目录进程状态概念Linux中的进程状态R(running)状态S(sleeping)状态D(disk sleep)状态T(stopped)状态t(tracing stop)状态X(dead)状态Z(zombie)状态特殊的孤儿进程进程优先级进程性质补充进程状态概念 《现代操作系统》中给出的进程状态的定义如下: 进程状态反映…

Qt+C++窗体界面中英文多语言切换

程序示例精选 QtC窗体界面中英文语言切换 如需安装运行环境或远程调试&#xff0c;见文章底部个人微信名片&#xff0c;由专业技术人员远程协助&#xff01; 前言 这篇博客针对<<QtC窗体界面中英文语言切换>>编写代码&#xff0c;代码整洁&#xff0c;规则&#x…

【Linux】软件包管理器 yum

目录 一、什么是软件包 二、如何进行软件安装 1、yum 的使用 2、yum 配置 一、什么是软件包 在Linux下安装软件&#xff0c;一个通常的办法是下载到程序的源代码&#xff0c;并进行编译&#xff0c;得到可执行程序。但是这样太麻烦了&#xff0c;于是有些人把一些常用的软…

InnoDB数据存储结构

InnoDB数据存储结构 本专栏学习内容来自尚硅谷宋红康老师的视频 有兴趣的小伙伴可以点击视频地址观看 1. 数据库的存储结构&#xff1a;页 索引结构给我们提供了高效的索引方式&#xff0c;不过索引信息以及数据记录都是保存在文件上的&#xff0c;确切来说是存储在页结构中。…

不讨论颜色的前提下,如何证明自己不是色盲?神奇的零知识证明

0x01 一个小故事 《阿里巴巴与四十大盗》中有这样一段小故事&#xff1a; 阿里巴巴会芝麻开门的咒语&#xff0c;强盗向他拷问打开山洞石门的咒语&#xff0c;他不想让人听到咒语&#xff0c;又要向强盗证明他知道这个咒语。 那应该怎么办呢&#xff1f; 便对强盗说&#xf…

基于KVM安装部署RHCOS操作系统

参考&#xff1a;Openshift 4.4 静态 IP 离线安装系列&#xff1a;初始安装 - 米开朗基杨 - 博客园 一、Openshift OCP集群安装架构示意图 RHCOS 的默认用户是 core 如果安装有问题会进入 emergency shell&#xff0c;检查网络、域名解析是否正常&#xff0c;如果正常一般是以…

重修JAVA

程序员的差距是在构思上&#xff1a;思想决定了深度&#xff0c;思想的精髓高深是很多人学不来的&#xff01; 每一门语言都有它的特点&#xff0c;有优势也有劣势&#xff0c; 所以不必拘泥于招式&#xff0c;掌握底层原理即可&#xff01; 每一们语言实际上都是一个“工具”&…

如何在您的香港主机帐户上注册多个域名

注册多个域名非常普遍。事实上&#xff0c;香港主机服务提供商鼓励这样做&#xff0c;因为它既有意义又是必要的。下面将介绍决定为什么您可能需要在香港主机上注册多个域名的几个因素。注册多个域名的原因是什么?方便多个项目如果香港主机帐户的所有者在网络上有多个不同域名…

优化vue项目后, 启动编译项目过程中 报 javaScript heap out of memory 错误 及 nodejs内存溢出

项目场景&#xff1a; 提示&#xff1a;这里简述项目相关背景&#xff1a; 1、优化vue项目后&#xff0c;运行npm run serve 启动编译项目过程中 报 javaScript heap out of memory 错误 2、项目启动时&#xff0c;出现 nodejs 内存溢出错误 问题描述 提示&#xff1a;遇到…

分布式事务的背景和解决方案

在常用的关系型数据库&#xff0c;都是具备事务特性的。 那什么是事务呢&#xff1f;事务是数据库运行的一个逻辑工作单元&#xff0c;在这个工作单元内的一系列SQL命令具有原子性操作的特点&#xff0c;也就是说这一系列SQL指令要么全部执行成功&#xff0c;要么全部回滚不执…

经典算法之深度优先搜索(DFS)

&#x1f451;专栏内容&#xff1a;算法学习笔记⛪个人主页&#xff1a;子夜的星的主页&#x1f495;座右铭&#xff1a;日拱一卒&#xff0c;功不唐捐。 目录一、前言二、基本概念1.简单介绍2. 官方概念三、动图分析四、模板框架五、例题分析组合问题题干描述&#xff1a;思路…

leetcode146. LRU 缓存【python3哈希表+双向链表】利用OrderedDict以及自实现双向链表

题目&#xff1a; 请你设计并实现一个满足 LRU (最近最少使用) 缓存 约束的数据结构。实现LRUCache类&#xff1a; LRUCache(int capacity) 以正整数作为容量capacity初始化 LRU 缓存int get(int key) 如果关键字key存在于缓存中&#xff0c;则返回关键字的值&#xff0c;否则…

【论文速递】9位院士Science88页长文:人工智能的进展、挑战与未来

【论文速递】9位院士Science88页长文&#xff1a;人工智能的进展、挑战与未来 【论文原文】&#xff1a;Intelligent Computing: The Latest Advances, Challenges and Future 获取地址&#xff1a;https://spj.science.org/doi/10.34133/icomputing.0006摘要&#xff1a; ​…