使用ResNet18实现CIFAR10数据集的训练

news2025/1/11 22:49:17

 如果对你有用的话,希望能够点赞支持一下,这样我就能有更多的动力更新更多的学习笔记了。😄😄     

        使用ResNet进行CIFAR-10数据集进行测试,这里使用的是将CIFAR-10数据集的分辨率扩大到32X32,因为算力相关的问题所以我选择了较低的训练图像分辨率。但是假如你自己的算力比较充足的话,我建议使用训练的使用图像的分辨率设置为224X224(这个可以在代码里面的transforms.RandomResizedCrop(32)和transforms.Resize((32, 32)),进行修改,很简单),因为在测试训练的时候,发现将CIFAR10数据集的分辨率拉大可以让模型更快地进行收敛,并且识别的效果也是比低分辨率的更加好。

首先来介绍一下,ResNet:

1.论文下载地址:https://arxiv.org/pdf/1512.03385.pdf 

2.ResNet的介绍:

 

代码实现:

数据集的处理:
        调用torchvision里面封装好的数据集进行数据的训练,并且利用官方已经做好的数据集分类是数据集的划分大小。进行了一些简单的数据增强,分别是随机的随机剪切和随机的水平拉伸操作。

模型的代码结构目录:

train.py文件内容:

# -*- coding:utf-8 -*-
# @Time : 2023-01-11 20:25
# @Author : DaFuChen
# @File : CSDN写作代码笔记
# @software: PyCharm



import torchvision

from model import resnet18
import os
import parameters
import function
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from tqdm import tqdm



def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))
    epochs = parameters.epoch
    save_model = parameters.resnet_save_model
    save_path = parameters.resnet_save_path_CIFAR10


    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(32),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),

        "val": transforms.Compose([transforms.Resize((32, 32)),  # cannot 224, must (224, 224)
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        }

    train_dataset = torchvision.datasets.CIFAR10(root='./data/CIFAR10', train=True,
                                                download=True, transform=data_transform["train"])

    val_dataset = torchvision.datasets.CIFAR10(root='./data/CIFAR10', train=False,
                                           download=False, transform=data_transform["val"])


    train_num = len(train_dataset)
    val_num = len(val_dataset)
    print("using {} images for training, {} images for validation.".format(train_num, val_num))
    # #################################################################################################################

    batch_size = parameters.batch_size

    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    # ##################################################################################################################
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               )

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             )

    model = resnet18(num_classes=parameters.CIFAR10_class)
    model.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=parameters.resnet_lr)
    best_acc = 0.0

    # 为后面制作表图
    train_acc_list = []
    train_loss_list = []
    val_acc_list = []

    for epoch in range(epochs):
        # train
        model.train()
        running_loss_train = 0.0
        train_accurate = 0.0
        train_bar = tqdm(train_loader)
        for images, labels in train_bar:
            optimizer.zero_grad()

            outputs = model(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            predict = torch.max(outputs, dim=1)[1]
            train_accurate += torch.eq(predict, labels.to(device)).sum().item()
            running_loss_train += loss.item()

        train_accurate = train_accurate / train_num
        running_loss_train = running_loss_train / train_num
        train_acc_list.append(train_accurate)
        train_loss_list.append(running_loss_train)

        print('[epoch %d] train_loss: %.7f  train_accuracy: %.3f' %
              (epoch + 1, running_loss_train, train_accurate))

        # validate
        model.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_loader = tqdm(val_loader)
            for val_data in val_loader:
                val_images, val_labels = val_data
                outputs = model(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        val_acc_list.append(val_accurate)
        print('[epoch %d] val_accuracy: %.3f' %
              (epoch + 1, val_accurate))

        function.writer_into_excel_onlyval(save_path, train_loss_list, train_acc_list, val_acc_list,"CIFAR10")

        # 选择最best的模型进行保存 评价指标此处是acc
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(model.state_dict(), save_model)



if __name__ == '__main__':
    main()

model.py文件:

# -*- coding:utf-8 -*-
# @Time : 2023-01-11 20:24
# @Author : DaFuChen
# @File : CSDN写作代码笔记
# @software: PyCharm


import torch.nn as nn
import torch


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """
    注意:原论文中,在虚线残差结构的主分支上,第一个1x1卷积层的步距是2,第二个3x3卷积层步距是1。
    但在pytorch官方实现过程中是第一个1x1卷积层的步距是1,第二个3x3卷积层步距是2,
    这么做的好处是能够在top1上提升大概0.5%的准确率。
    可参考Resnet v1.5 https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch
    """
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()

        width = int(out_channel * (width_per_group / 64.)) * groups

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self,
                 block,
                 blocks_num,
                 num_classes=1000,
                 include_top=True,
                 groups=1,
                 width_per_group=64):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64

        self.groups = groups
        self.width_per_group = width_per_group

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)

        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(512 * block.expansion, 512),  # [2 512 1 1]
            nn.ReLU(inplace=True),
            # nn.Linear(512, num_classes),

        )

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel,
                            channel,
                            downsample=downsample,
                            stride=stride,
                            groups=self.groups,
                            width_per_group=self.width_per_group))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            # x = self.fc(x)
            # print((x.shape()))
            x = self.classifier(x)

        return x


class AlexnetChange(nn.Module):
    def __init__(self, ):
        super(AlexnetChange, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=13, stride=2, padding=0),                  # output[128, 1, 1]
        )

        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 1 * 1, 512),        # [batchsize值 512 1 1]
            nn.ReLU(inplace=True),
            # nn.Linear(512, num_classes),

        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)      # output[512, 1, 1]

        return x



class Classifier(nn.Module):
    def __init__(self, num_classe=1000):
        super(Classifier, self).__init__()

        self.FC = nn.Sequential(
            nn.Linear(512 * 1 * 1, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, num_classe),
        )

    def forward(self, x1=None, x2=None):
        if x1 != None and x2 != None:
            x = x1.add(x2)
            x = self.FC(x)
            # print("x1 add x2  ")
        elif x1 != None and x2 == None:
            x = self.FC(x1)
            # print("only x1  ")
        elif x1 == None and x2 != None:
            x = self.FC(x2)
            # print("only x2  ")
        else:
            print("Alexnet_Con has wrong")

        return x



def resnet18(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, include_top=include_top)

def resnet34(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet50(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet50-19c8e357.pth
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)

function.py文件:

# -*- coding:utf-8 -*-
# @Time : 2023-01-11 20:25
# @Author : DaFuChen
# @File : CSDN写作代码笔记
# @software: PyCharm


import xlwt



def writer_into_excel_onlyval(excel_path,loss_train_list, acc_train_list, val_acc_list,dataset_name:str=""):
    workbook = xlwt.Workbook(encoding='utf-8')  # 设置一个workbook,其编码是utf-8
    worksheet = workbook.add_sheet("sheet1", cell_overwrite_ok=True)  # 新增一个sheet
    worksheet.write(0, 0, label='Train_loss')
    worksheet.write(0, 1, label='Train_acc')
    worksheet.write(0, 2, label='Val_acc')


    for i in range(len(loss_train_list)):  # 循环将a和b列表的数据插入至excel
        worksheet.write(i + 1, 0, label=loss_train_list[i])  # 切片的原来是传进来的Imgs是一个路径的信息
        worksheet.write(i + 1, 1, label=acc_train_list[i])
        worksheet.write(i + 1, 2, label=val_acc_list[i])


    workbook.save(excel_path + str(dataset_name) +".xls")  # 这里save需要特别注意,文件格式只能是xls,不能是xlsx,不然会报错
    print('save success!   .')



parameters.py文件:

# -*- coding:utf-8 -*-
# @Time : 2023-01-11 20:25
# @Author : DaFuChen
# @File : CSDN写作代码笔记
# @software: PyCharm




# 训练的次数
epoch = 2

# 训练的批次大小
batch_size = 1024

# 数据集的分类类别数量
CIFAR10_class = 10

# 模型训练时候的学习率大小
resnet_lr = 0.002

# 保存模型权重的路径 保存xml文件的路径
resnet_save_path_CIFAR10 = './res/'
resnet_save_model = './res/best_model.pth'

 其中部分参数,例如是学习率的大小,训练的批次大小,数据增强的一些小参数,可以根据自己的经验和算力的现实情况进行调整。

如果对你有用的话,希望能够点赞支持一下,这样我就能有更多的动力更新更多的学习笔记了。😄😄

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/156837.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Git使用详解(图文+代码):基础内容

基础内容前言版本控制本地版本控制系统集中化版本控制系统分布式控制系统Git使用详解Git基础理解Git基础指令取得项目的Git仓库记录每次更新的仓库检查当前文件状态跟踪文件暂存已修改文件忽略某些文件查看已暂存和未暂存的更新提交更新跳过使用暂存区域移除文件远程操作的使用…

熊市里再看GameFi,为什么说链游潜力巨大?

大方向上来看,区块链项目本质上分为两类,一类是金融资产属性的项目,比如我们常说的DeFi、DAO、公链等,另一类则具有娱乐艺术属性的,比如NFT、GameFi、元宇宙等,熊市环境下如何看待这两类项目,以…

【iOS】—— 初识GCD

GCD(Grand Central Dispatch) 文章目录GCD(Grand Central Dispatch)什么是GCDperformSelector方法:GCD的优点:任务和队列队列的创建方式任务的创建方法六种情况的例子1.并发队列 同步执行2. 并发队列 异步…

Vector - VT System - Ethernet板卡_VT6306

前面介绍了支持CAN&CANFD&LIN板卡,但是对于当前日益火爆的车载以太网来说,Vector也是提供了类似于VN5000系列一样的板卡,那就是VT6306。它给提供6路的百兆或者6路千兆的车载以太网(2022年之前选择后是固定的,有…

央视点赞百度智能云激活民营经济”数字“活力

2023年,对民营企业究竟意味着什么。 2022年12月,新华社发表重磅长文解读中国经济发展大势:“迎接更加壮阔的光明前程”。 随后央视新闻联播连续4天发声,关注民营经济发展。4条新闻中,“创新”一词共出现了29次&#…

小程序的运行机制以及安全机制

接触小程序有一段时间了,总得来说小程序开发门槛比较低,但其中基本的运行机制和原理还是要懂的. 了解小程序的由来 在小程序没有出来之前,最初微信WebView逐渐成为移动web重要入口,微信发布了一整套网页开发工具包,称…

199:vue+openlayers 添加删除修改feature信息,双向不同颜色指示互动

第199个 点击查看专栏目录 本示例的目的是介绍如何在vue+openlayers项目中绘制多边形,每绘制一个,左侧输出一个feature指示标志,双向颜色互动指示。 直接复制下面的 vue+openlayers源代码,操作2分钟即可运行实现效果; 注意如果OpenStreetMap无法加载,请加载其他来练习 …

前缀和讲解

目录 一、前言 二、前缀和 1、基本概念 2、前缀和与差分的关系 3、差分数组能提升修改的效率 三、例题 1、统计子矩阵(lanqiao2109,2022年省赛) (1)处理输入 (2)方法一:纯暴…

设计模式面试题

工厂模式是我们最常用的实例化对象模式了,是用工厂方法代替new操作的一种模式,工厂模式在Java程序中可以说是随处可见。本文来给大家详细介绍下工厂模式 面向对象设计的基本原则: OCP(开闭原则,Open-Closed Principle&#xff0…

字符串函数介绍——C语言

文章目录 一、引言 二、函数的介绍与模拟实现 2、1 求字符串长度strlen()函数 2、1、1 strlen()函数介绍 2、1、2 strlen()函数的模拟实现 2、2 字符串拷贝strcpy()函数 2、2、1 s…

「旷野俱乐部」在 The Sandbox 开业,SMCU 宫殿等你来体验!

简要概括 KWANGYAThe Sandbox 是「旷野俱乐部」在 The Sandbox 元宇宙中的虚拟空间; SMCU 宫殿体验呈现了 2022 年冬季 SM 小镇的视觉效果,SMCU 宫殿专辑封面将于 1 月 10 日發佈; 将向全球粉丝展示更多基于韩国文化内容的元宇宙体验。 The…

Appium+Pytest+pytest-testreport框架轻松实现app自动化

有任何环境问题,可以参考我的文章 Appium自动化测试<一>, Appium自动化测试<二>有任何定位问题、触屏操作、等待问题、Toast 信息操作问题、手机操作问题及H5页面的操作请参考我的文章:Appium自…

【论文速递】TNNLS2022 - 一种用于小样本分割的互监督图注意网络_充分利用有限样本的视角

【论文速递】TNNLS2022 - 一种用于小样本分割的互监督图注意网络_充分利用有限样本的视角 【论文原文】:A Mutually Supervised Graph Attention Network for Few-Shot Segmentation: The Perspective of Fully Utilizing Limited Samples 获取地址:ht…

Java设计模式-组合模式Composite

介绍 组合模式(Composite Pattern),又叫部分整体模式,它创建了对象组的树形结构,将对象组合成树状结构以表示“整体-部分”的层次关系。组合模式依据树形结构来组合对象,用来表示部分以及整体层次。这种类…

【Nginx】Nginx的常用命令和配置文件

1. 常用命令 1. 查看版本2. 查看 Nginx 配置语法的正确性3. 为Nginx指定一个配置文件4. 启动 Nginx 服务5. 开机自启动6. 重启 Nginx 服务7. 查看 Nginx 服务状态8. 重载 Nginx 服务9. 停止 Nginx 服务10. 查看命令帮助 2. 配置文件 第一部分:全局块第二部分&#x…

RT-Thread系列--内存池MEMPOOL源码分析

一、目的嵌入式RTOS中最重要也是最容易被忽略的一个组件就是内存管理,像FreeRTOS单单内存管理组件就提供了heap_1/2/3/4/5这五种方案,每种方案都有其特点和应用场景。一般情况下小系统所运行的芯片平台本身内存就很少,有些时候内存空间还不连…

libdrm-2.4.112

编译 这个版本使用了meson进行构建、ninja进行编译 ; 安装meson 编译 报错如上,查看meson.build文件, 我们的meson版本不正确, 查阅发现apt安装的版本过低; 安装meson sudo apt-get install python3 python3-pip …

LeetCode 111. 二叉树的最小深度

🌈🌈😄😄 欢迎来到茶色岛独家岛屿,本期将为大家揭晓LeetCode 111. 二叉树的最小深度,做好准备了么,那么开始吧。 🌲🌲🐴🐴 一、题目名称 二、…

程序的编译与链接——ARM可执行文件ELF

读书《嵌入式C语言自我修养》笔记 目录 读书《嵌入式C语言自我修养》笔记 ARM编译工具 使用readelf命令查看ELF Header 使用readelf命令查看ELF section header 程序编译 预处理器 编译器 (1)词法分析。 (2)语法分析。 …

班级人员可视化项目

页面分布文件分布index.html(搭建页面)index.css (修饰页面)fonts (放图标)images (放图片)jsjquery.js (调整页面的js)flexible.js (尺寸大小的js)echarts.min.js (charts图表的js)chinaMap…