使用ResNet34实现CIFAR10数据集的训练

news2025/1/10 23:31:19

 如果对你有用的话,希望能够点赞支持一下,这样我就能有更多的动力更新更多的学习笔记了。😄😄     

        使用ResNet进行CIFAR-10数据集进行测试,这里使用的是将CIFAR-10数据集的分辨率扩大到32X32,因为算力相关的问题所以我选择了较低的训练图像分辨率。但是假如你自己的算力比较充足的话,我建议使用训练的使用图像的分辨率设置为224X224(这个可以在代码里面的transforms.RandomResizedCrop(32)和transforms.Resize((32, 32)),进行修改,很简单),因为在测试训练的时候,发现将CIFAR10数据集的分辨率拉大可以让模型更快地进行收敛,并且识别的效果也是比低分辨率的更加好。

首先来介绍一下,ResNet:

1.论文下载地址:https://arxiv.org/pdf/1512.03385.pdf 

2.ResNet的介绍:

 

代码实现:

数据集的处理:
        调用torchvision里面封装好的数据集进行数据的训练,并且利用官方已经做好的数据集分类是数据集的划分大小。进行了一些简单的数据增强,分别是随机的随机剪切和随机的水平拉伸操作。

模型的代码结构目录:

train.py文件内容:

# -*- coding:utf-8 -*-
# @Time : 2023-01-11 20:25
# @Author : DaFuChen
# @File : CSDN写作代码笔记
# @software: PyCharm



import torchvision

from model import resnet34
import os
import parameters
import function
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from tqdm import tqdm



def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))
    epochs = parameters.epoch
    save_model = parameters.resnet_save_model
    save_path = parameters.resnet_save_path_CIFAR10


    data_transform = {
        "train": transforms.Compose([ transforms.RandomResizedCrop(32),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),

        "val": transforms.Compose([transforms.Resize((32, 32)),  # cannot 224, must (224, 224)
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        }

    train_dataset = torchvision.datasets.CIFAR10(root='./data/CIFAR10', train=True,
                                                download=True, transform=data_transform["train"])

    val_dataset = torchvision.datasets.CIFAR10(root='./data/CIFAR10', train=False,
                                           download=False, transform=data_transform["val"])


    train_num = len(train_dataset)
    val_num = len(val_dataset)
    print("using {} images for training, {} images for validation.".format(train_num, val_num))
    # #################################################################################################################

    batch_size = parameters.batch_size

    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    # ##################################################################################################################
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               )

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             )

    model = resnet34(num_classes=parameters.CIFAR10_class)
    model.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=parameters.resnet_lr)
    best_acc = 0.0

    # 为后面制作表图
    train_acc_list = []
    train_loss_list = []
    val_acc_list = []

    for epoch in range(epochs):
        # train
        model.train()
        running_loss_train = 0.0
        train_accurate = 0.0
        train_bar = tqdm(train_loader)
        for images, labels in train_bar:
            optimizer.zero_grad()

            outputs = model(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            predict = torch.max(outputs, dim=1)[1]
            train_accurate += torch.eq(predict, labels.to(device)).sum().item()
            running_loss_train += loss.item()

        train_accurate = train_accurate / train_num
        running_loss_train = running_loss_train / train_num
        train_acc_list.append(train_accurate)
        train_loss_list.append(running_loss_train)

        print('[epoch %d] train_loss: %.7f  train_accuracy: %.3f' %
              (epoch + 1, running_loss_train, train_accurate))

        # validate
        model.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_loader = tqdm(val_loader)
            for val_data in val_loader:
                val_images, val_labels = val_data
                outputs = model(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        val_acc_list.append(val_accurate)
        print('[epoch %d] val_accuracy: %.3f' %
              (epoch + 1, val_accurate))

        function.writer_into_excel_onlyval(save_path, train_loss_list, train_acc_list, val_acc_list,"CIFAR10")

        # 选择最best的模型进行保存 评价指标此处是acc
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(model.state_dict(), save_model)



if __name__ == '__main__':
    main()

model.py文件:

# -*- coding:utf-8 -*-
# @Time : 2023-01-11 20:24
# @Author : DaFuChen
# @File : CSDN写作代码笔记
# @software: PyCharm


import torch.nn as nn
import torch


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """
    注意:原论文中,在虚线残差结构的主分支上,第一个1x1卷积层的步距是2,第二个3x3卷积层步距是1。
    但在pytorch官方实现过程中是第一个1x1卷积层的步距是1,第二个3x3卷积层步距是2,
    这么做的好处是能够在top1上提升大概0.5%的准确率。
    可参考Resnet v1.5 https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch
    """
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()

        width = int(out_channel * (width_per_group / 64.)) * groups

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self,
                 block,
                 blocks_num,
                 num_classes=1000,
                 include_top=True,
                 groups=1,
                 width_per_group=64):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64

        self.groups = groups
        self.width_per_group = width_per_group

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)

        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(512 * block.expansion, 512),  # [2 512 1 1]
            nn.ReLU(inplace=True),
            # nn.Linear(512, num_classes),

        )

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel,
                            channel,
                            downsample=downsample,
                            stride=stride,
                            groups=self.groups,
                            width_per_group=self.width_per_group))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            # x = self.fc(x)
            # print((x.shape()))
            x = self.classifier(x)

        return x


class AlexnetChange(nn.Module):
    def __init__(self, ):
        super(AlexnetChange, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=13, stride=2, padding=0),                  # output[128, 1, 1]
        )

        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 1 * 1, 512),        # [batchsize值 512 1 1]
            nn.ReLU(inplace=True),
            # nn.Linear(512, num_classes),

        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)      # output[512, 1, 1]

        return x



class Classifier(nn.Module):
    def __init__(self, num_classe=1000):
        super(Classifier, self).__init__()

        self.FC = nn.Sequential(
            nn.Linear(512 * 1 * 1, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, num_classe),
        )

    def forward(self, x1=None, x2=None):
        if x1 != None and x2 != None:
            x = x1.add(x2)
            x = self.FC(x)
            # print("x1 add x2  ")
        elif x1 != None and x2 == None:
            x = self.FC(x1)
            # print("only x1  ")
        elif x1 == None and x2 != None:
            x = self.FC(x2)
            # print("only x2  ")
        else:
            print("Alexnet_Con has wrong")

        return x



def resnet18(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, include_top=include_top)

def resnet34(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet50(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet50-19c8e357.pth
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)

function.py文件:

# -*- coding:utf-8 -*-
# @Time : 2023-01-11 20:25
# @Author : DaFuChen
# @File : CSDN写作代码笔记
# @software: PyCharm


import xlwt



def writer_into_excel_onlyval(excel_path,loss_train_list, acc_train_list, val_acc_list,dataset_name:str=""):
    workbook = xlwt.Workbook(encoding='utf-8')  # 设置一个workbook,其编码是utf-8
    worksheet = workbook.add_sheet("sheet1", cell_overwrite_ok=True)  # 新增一个sheet
    worksheet.write(0, 0, label='Train_loss')
    worksheet.write(0, 1, label='Train_acc')
    worksheet.write(0, 2, label='Val_acc')


    for i in range(len(loss_train_list)):  # 循环将a和b列表的数据插入至excel
        worksheet.write(i + 1, 0, label=loss_train_list[i])  # 切片的原来是传进来的Imgs是一个路径的信息
        worksheet.write(i + 1, 1, label=acc_train_list[i])
        worksheet.write(i + 1, 2, label=val_acc_list[i])


    workbook.save(excel_path + str(dataset_name) +".xls")  # 这里save需要特别注意,文件格式只能是xls,不能是xlsx,不然会报错
    print('save success!   .')



parameters.py文件:

# -*- coding:utf-8 -*-
# @Time : 2023-01-11 20:25
# @Author : DaFuChen
# @File : CSDN写作代码笔记
# @software: PyCharm




# 训练的次数
epoch = 2

# 训练的批次大小
batch_size = 4

# 数据集的分类类别数量
CIFAR10_class = 10

# 模型训练时候的学习率大小
resnet_lr = 0.002

# 保存模型权重的路径 保存xml文件的路径
resnet_save_path_CIFAR10 = './res/'
resnet_save_model = './res/best_model.pth'

其中部分参数,例如是学习率的大小,训练的批次大小,数据增强的一些小参数,可以根据自己的经验和算力的现实情况进行调整。

如果对你有用的话,希望能够点赞支持一下,这样我就能有更多的动力更新更多的学习笔记了。😄😄

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/157152.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

摘要/哈希/散列算法MD5 SHA1 SHA256 SHA512的区别和MAC算法

一、摘要算法大致都要经过以下步骤 1. 明文数据预处理 1.1 填充比特 MD5、SHA1、SHA256 的分组长度都是512bit 需要填充比特使其位长对512求余的结果等于448 SHA512 的分组长度是 1024bit 需要填充比特使其对1024求余的结果等于896 相同&am…

ECharts基本使用

文章目录Echarts概述Echarts初体验ECharts基础配置Echarts社区介绍Echarts-map使用Echarts概述 常见的数据可视化库: D3.js 目前 Web 端评价最高的 Javascript 可视化工具库(入手难)ECharts.js 百度出品的一个开源 Javascript 数据可视化库Highcharts.js 国外的前…

项目合并后,font字体资源被替换导致TextMeshPro不能显示文字,抢救方法

一,字体消失 项目合并时,因为资源更替,导致TextMeshPro不能找到自己原来使用的font资源,以致不能显示文字。 二、抢救方式 1、找到所有用到TextMeshPro的物体2、把他们的字体重新设置成你要的字体 关键步骤: 1、找…

赛事推荐| 建筑物实例分割和高度估计的多任务学习——2023 IEEE GRSS 数据融合赛道2

1. 赛题名称 联合建筑物提取和高度估计的多任务学习 2. 赛题背景 该轨道定义了建筑物提取和高度估计的联合任务。两者都是建筑改造的两个非常基础和必不可少的任务。与轨道 1 相同,输入数据是多模态光学和 SAR 卫星图像。单视图卫星图像中的建筑物提取和高度估计…

记录redis连接被打满的踩坑之路

一、系统异常现象系统有一个功能向别的系统多线程推送用户数据信息,前几天发现该推送功能报内部错误,经过查看后台日志文件,发现org.redisson.client.RedisConnectionException: Unable to connect to Redis server:,io.netty.cha…

使用docker训练yolov5

使用docker训练yolov5 配置docker,配置的好处是docker中的环境或者说容器坏了不影响主机,并且可以减少配置环境的时间和精力 sudo apt update sudo apt install apt-transport-https ca-certificates curl gnupg-agent software-properties-common # c…

Docker 部署SQL Server 2017

Docker 部署SQL Server 2017 Docker部署 registry Docker搭建 svn Docker部署 Harbor Docker 部署SQL Server 2017 Docker 安装 MS SqlServer Docker部署 Oracle12c 文章目录Docker 部署SQL Server 2017一、部署步骤1.下载镜像2.创建容器并运行二、参考文档一、部署步骤 1.下…

Unity 之 资源加载 -- 可寻址系统概念介绍 -- 入门(一)

可寻址系统面板概念 -- 入门(一)一,可寻址系统概念介绍1.1 官方话术1.2 几个概念二,可寻址系统目录介绍2.1 导入工程2.2 目录介绍概述:本片文章带大家了解可寻址系统的相关概念,为大家介绍可寻址系统导入方…

生成数据分析报告pandas_profiling.ProfileReport

【小白从小学Python、C、Java】 【计算机等级考试500强双证书】 【Python-数据分析】 生成数据分析报告 pandas_profiling.ProfileReport 选择题 对于以下python代码表述错误的一项是? import pandas as pd import pandas_profiling as pp dfpd.DataFrame({ a:[23,18,21], b:[…

excel数据核对技巧:如何用函数公式标识输入正误

我们平时人工录入较长的文本数据时,稍不注意就容易出错。为了避免出错,通常我们会提前对单元格设置数据验证。有些时候,我们还会考虑列与列之间的关系,根据列关系自动判定数据的对错。比如下表,款号、货号、色号、条码的信息均存在…

【MySQL进阶教程】InnoDB引擎

前言 本文为 【MySQL进阶教程】InnoDB引擎 相关知识,下边将对InnoDB引擎介绍,InnoDB引擎架构,事务原理,MVCC等进行详尽介绍~ 📌博主主页:小新要变强 的主页 👉Java全栈学习路线可参考&#xff…

获取每年的周数据 第几周 开始日及结束日 思路

public static void main(String[] args) {int year 2023;SimpleDateFormat simpleDateFormat new SimpleDateFormat("yyyy-MM-dd");while (true) {int weekValue 1;Calendar calendar new GregorianCalendar();//***踩坑 // calendar.setFirstDayOfW…

冒泡排序终极版(模拟qsort)

目录 普通版冒泡排序 qosrt函数 终极版冒泡排序 终极版冒泡排序整体测试代码 普通版冒泡排序 冒泡排序想必大家都很了解了吧,冒泡排序的算法思想就是两两比大小,一轮一轮比,每比完一轮排出一个数字的顺序,那就让我们先来看一…

软件测试/测试开发丨从 0 开始学 Python 自动化测试开发(二):环境搭建

本文是「从 0 开始学 Python 自动化测试开发」专题系列文章第二篇 —— 环境搭建篇,适合零基础入门的同学。没有阅读过上一篇的同学,请戳蓝色字体阅读。作者方程老师,是前某跨国通信公司高级测试经理,目前为某互联网名企资深测试技…

【算法基础】1.4 高精度(模拟大数运算:整数加减乘除)

文章目录高精度加法题目描述解法高精度减法题目描述解法讲解高精度乘法题目描述解法讲解高精度除法题目描述解法讲解本文主要讲解高精度计算,包括加法、减法、乘法和除法。 对于Python选手,python自带高精度计算;Java也有BigInteger类。但是对…

javaEE 初阶 — 多线程— JUC(java.util.concurrent) 的常见类

文章目录1. Callable 接口1.1 Callable 的用法2. ReentrantLock2.1 ReentrantLock 的缺陷2.1 ReentrantLock 的优势3. 原子类4. 信号量 Semaphore5. CountDownLatch6. 相关面试题1. Callable 接口 类似于 Runnable 一样。 Runnable 用来描述一个任务,描述的任务没有…

我们一直在说数字化转型,什么才是数字化转型?

我们一直在说数字化转型,什么才是数字化转型?深度长文,4000字,融合了很多国内外专业期刊观点,一文讲清到底什么是企业数字化转型,心急的小伙伴可以先看目录: 关于定义——到底什么是“数字化转…

24 届秋招 | 高质量学习交流环境

大家好,我和一些计算机方向、背景非常优秀的、来自清华、新国立等知名大学的几位同学以及工作多年的高级研发工程师一起运营了一个知识星球。 星球里有大量国内top985、海外名校的同学在一起,目的是为了打造一个非常优质量的社群。 如果你也曾苦于在各…

PySimpleGUI图形化界面实现Office文件格式转换

PySimpleGUI图形化界面实现Office文件格式转换Python实现三种文件两个版本的格式转换1、doc与docx格式互相转换2、xls与xlsx格式互相转换3、ppt与pptx格式互相转换PythonPySimpleGUI实现综合版本Python实现三种文件两个版本的格式转换 1、doc与docx格式互相转换 这里主要运用…

excel求和技巧:如何忽略错误值进行求和

按照对应的订单号引用已有的收货金额,这种问题相信很多朋友都会处理,用VLOOKUP函数就能搞定。我们今天要讨论的是如何对含有错误值的数据进行求和。如果直接求和,得到的结果也是一个错误值,如下图:对于这种要对含有错误…