基于ResNet-18实现Cifar-10图像分类

news2025/1/13 13:15:30

目录

  • 1、作者介绍
  • 2、数据集介绍
    • 2.1Cifar-10数据集介绍:
  • 3、ResNet网络介绍
    • 3.1Residual Network残差网络
    • 3.2ResNet18网络结构
  • 4、代码复现及实验结果
    • 4.1训练代码
    • 4.2测试代码
    • 4.3实验结果

1、作者介绍

安耀辉,男,西安工程大学电子信息学院,22级研究生
研究方向:小样本图像分类算法
电子邮箱:1349975181@qq.com

张思怡,女,西安工程大学电子信息学院,2022级研究生,张宏伟人工智能课题组
研究方向:机器视觉与人工智能
电子邮件:981664791@qq.com

2、数据集介绍

2.1Cifar-10数据集介绍:

CIFAR-10 数据集由 60000张图片,每张为32✖32大小的彩色图像组成,每类 6000 张图像。有 50000 张训练图像和 10000 张测试图像(共10个类,如图所示)。
数据集分为五个训练批次和一个测试批次,每个批次有 10000 张图像。测试批次包含来自每个类的 1000 个随机选择的图像。训练批次以随机顺序包含剩余的图像,但某些训练批次可能包含来自一个类的图像多于另一个类的图像。在它们之间,训练批次正好包含来自每个类的 5000 张图像。
在这里插入图片描述

3、ResNet网络介绍

3.1Residual Network残差网络

ResNet全名Residual Network残差网络。Kaiming He 的《Deep Residual Learning for Image Recognition》获得了CVPR最佳论文。他提出的深度残差网络在2015年可以说是洗刷了图像方面的各大比赛,以绝对优势取得了多个比赛的冠军。而且它在保证网络精度的前提下,将网络的深度达到了152层,后来又进一步加到1000的深度。论文的开篇先是说明了深度网络的好处:特征等级随着网络的加深而变高,网络的表达能力也会大大提高。因此论文中提出了一个问题:是否可以通过叠加网络层数来获得一个更好的网络呢?作者经过实验发现,单纯的把网络叠起来的深层网络的效果反而不如合适层数的较浅的网络效果。因此何恺明等人在普通平原网络的基础上增加了一个shortcut, 构成一个residual block。此时拟合目标就变为F(x),F(x)就是残差:
在这里插入图片描述

3.2ResNet18网络结构

这里给出比较详细的ResNet18网络具体参数和执行流程图:
在这里插入图片描述

4、代码复现及实验结果

4.1训练代码

import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from utils.readData import read_dataset
from utils.ResNet import ResNet18

# set device
device = 'cuda' if torch.cuda.is_available() else 'cpu' # gpu使能
# 读数据
batch_size = 128
train_loader,valid_loader,test_loader = read_dataset(batch_size=batch_size,
                                                     pic_path='D:\AnYaohui\Cifar10\data')
# 加载模型(使用预处理模型,修改最后一层,固定之前的权重)
n_class = 10
model = ResNet18()
"""
ResNet18网络的7x7降采样卷积和池化操作容易丢失一部分信息,
所以在实验中我们将7x7的降采样层和最大池化层去掉,替换为一个3x3的降采样卷积,
同时减小该卷积层的步长和填充大小
"""
model.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1,
                        bias=False)
model.fc = torch.nn.Linear(512, n_class) # 将最后的全连接层改掉
model = model.to(device)
# 使用交叉熵损失函数
criterion = nn.CrossEntropyLoss().to(device)

# 开始训练
n_epochs = 250
valid_loss_min = np.Inf # track change in validation loss 表示+∞,是没有确切的数值的,类型为浮点型
accuracy = []
lr = 0.1
counter = 0
for epoch in tqdm(range(1, n_epochs+1)):

    # keep track of training and validation loss
    train_loss = 0.0
    valid_loss = 0.0
    total_sample = 0
    right_sample = 0
    
    # 动态调整学习率
    if counter/10 ==1:
        counter = 0
        lr = lr*0.5
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    ###################
    # 训练集的模型 #
    ###################
    model.train() #作用是启用batch normalization和drop out
    for data, target in train_loader:
        data = data.to(device)
        target = target.to(device)
        # clear the gradients of all optimized variables(清除梯度)
        optimizer.zero_grad()
        # forward pass: compute predicted outputs by passing inputs to the model
        # (正向传递:通过向模型传递输入来计算预测输出)
        output = model(data).to(device)  #(等价于output = model.forward(data).to(device) )
        # calculate the batch loss(计算损失值)
        loss = criterion(output, target)
        # backward pass: compute gradient of the loss with respect to model parameters
        # (反向传递:计算损失相对于模型参数的梯度)
        loss.backward()
        # perform a single optimization step (parameter update)
        # 执行单个优化步骤(参数更新)
        optimizer.step()
        # update training loss(更新损失)
        train_loss += loss.item()*data.size(0)
        
    ######################    
    # 验证集的模型#
    ######################

    model.eval()  # 验证模型
    for data, target in valid_loader:
        data = data.to(device)
        target = target.to(device)
        # forward pass: compute predicted outputs by passing inputs to the model
        output = model(data).to(device)
        # calculate the batch loss
        loss = criterion(output, target)
        # update average validation loss 
        valid_loss += loss.item()*data.size(0)
        # convert output probabilities to predicted class(将输出概率转换为预测类)
        _, pred = torch.max(output, 1)    
        # compare predictions to true label(将预测与真实标签进行比较)
        correct_tensor = pred.eq(target.data.view_as(pred))
        # correct = np.squeeze(correct_tensor.to(device).numpy())
        total_sample += batch_size
        for i in correct_tensor:
            if i:
                right_sample += 1
    print("Accuracy:",100*right_sample/total_sample,"%")
    accuracy.append(right_sample/total_sample)
 
    # 计算平均损失
    train_loss = train_loss/len(train_loader.sampler)
    valid_loss = valid_loss/len(valid_loader.sampler)
        
    # 显示训练集与验证集的损失函数 
    print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
        epoch, train_loss, valid_loss))
    
    # 如果验证集损失函数减少,就保存模型。
    if valid_loss <= valid_loss_min:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(valid_loss_min,valid_loss))
        torch.save(model.state_dict(), 'checkpoint/resnet18_cifar10.pt')
        valid_loss_min = valid_loss
        counter = 0
    else:
        counter += 1

4.2测试代码

import torch
import torch.nn as nn
from utils.readData import read_dataset
from utils.ResNet import ResNet18
# set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
n_class = 10
batch_size = 100
train_loader,valid_loader,test_loader = read_dataset(batch_size=batch_size,pic_path='dataset')
model = ResNet18() # 得到预训练模型
model.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
model.fc = torch.nn.Linear(512, n_class) # 将最后的全连接层修改
# 载入权重
model.load_state_dict(torch.load('checkpoint/resnet18_cifar10.pt'))
model = model.to(device)

total_sample = 0
right_sample = 0
model.eval()  # 验证模型
for data, target in test_loader:
    data = data.to(device)
    target = target.to(device)
    # forward pass: compute predicted outputs by passing inputs to the model
    output = model(data).to(device)
    # convert output probabilities to predicted class(将输出概率转换为预测类)
    _, pred = torch.max(output, 1)    
    # compare predictions to true label(将预测与真实标签进行比较)
    correct_tensor = pred.eq(target.data.view_as(pred))
    # correct = np.squeeze(correct_tensor.to(device).numpy())
    total_sample += batch_size
    for i in correct_tensor:
        if i:
            right_sample += 1
print("Accuracy:",100*right_sample/total_sample,"%")

4.3实验结果

在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/560616.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

144 Tops,特斯拉如何低成本实现了城市NOA?

作者 | 树人 编辑 | 德新 根据特斯拉2022年Q4的财务文件披露&#xff1a;FSD Beta已有将近40万用户。 这是目前全世界部署规模最大的城市NOA系统。 而特斯拉实现这样一套系统&#xff0c;在车端几乎仅用了8个摄像头和144 Tops算力的FSD计算平台。这种性能压榨和成本控制能力让…

2023年内网穿透常用的几个工具

作为一名开发者&#xff0c;先给大家普及一下什么是内网&#xff0c;什么是外网。 所谓内网就是内部建立的局域网络或办公网络。比如一家公司或一个家庭有多台计算机&#xff0c;他们利用不同网络布局将这一台或多台计算机或其它设备连接起来构成一个局部的办公或者资源共享网…

这可能是最全面的Java面试八股文了

Java的特点 Java是一门面向对象的编程语言。面向对象和面向过程的区别参考下一个问题。 Java具有平台独立性和移植性。 Java有一句口号&#xff1a;Write once, run anywhere&#xff0c;一次编写、到处运行。这也是Java的魅力所在。而实现这种特性的正是Java虚拟机JVM。已编…

2024王道数据结构考研丨第五篇:树、图

2024王道数据结构考研笔记专栏将持续更新&#xff0c;欢迎 点此 收藏&#xff0c;共同交流学习… 文章目录 第五章&#xff1a;树5.1树的基本概念5.1.1树的定义5.1.2基本术语5.1.3树的性质 5.2二叉树的概念5.2.1 二叉树的定义与特性5.2.2几种特殊的二叉树5.2.3二叉树的存储结构…

2022级云曦实验室考试(一)reverse

一.Reverse 打开后是个rar压缩包&#xff0c;解压后 发现这玩意儿&#xff0c;我也不知道是个啥&#xff0c;之前没做过这类题 浅搜一下 啊&#xff0c;看不懂 用一下自己的歪办法 用txt打开看看有没有啥&#xff1f; 发现两个里面都有相同的flag&#xff0c;改成正确格式&…

C语言小游戏--三子棋

目录 问题描述 逻辑分析 具体实现 1.进入菜单界面 2.初始化棋盘 3.打印棋盘 4.玩家下棋 5.电脑下棋 6.判断输赢 运行结果 完整代码 game.h game.c test.c 问题描述 结合C语言所学知识&#xff0c;简单实现一个三子棋小游戏。 逻辑分析 进入菜单界面初始化棋盘…

帅地这些年看过的书

大家好&#xff0c;我是帅地。 好久没有给大家推荐书籍了&#xff0c;我一般很少给大家推荐书籍&#xff0c;因为自己没看过的&#xff0c;基本不推&#xff0c;只推荐我自己看过且自己自认为不错的书籍。 因为我自己本身是凭借着扎实的基础拿到大厂 offer 的&#xff0c;所以…

nodej+vues汽车销售4s店服务平台商城系统购物车积分兑换7z9d2

在经济快速发展的带动下&#xff0c;汽车服务平台的发展也是越来越快速。用户对汽车服务信息的获取需求很大。在互联网飞速发展的今天&#xff0c;制作一个汽车服务平台系统是非常必要的。本系统是借鉴其他人的开发基础上&#xff0c;用MySQL数据库和nodejs定制了汽车服务平台系…

Windows安装多个Mysql服务

1、正常安装好第一个 正常安装即可 2、第二个安装方法 1、官网下载zip包 MySQL :: MySQL Downloads 2、解压下载好的压缩包 &#xff08;注意修改文件夹名称&#xff0c;此时文件夹内并没有data文件夹&#xff09; 3、编写my.ini 注意修改端口号port以及安装目录basedir…

龙芯2K1000实战开发-平台介绍

文章目录 概要整体架构流程技术名词解释技术细节小结概要 龙芯 2K1000 处理器主要面向于网络应用,兼顾平板应用及工控领域应 用。采用 40nm 工艺,片内集成 2 个 GS264 处理器核,主频 1GHz,64 位 DDR3 控制器,以及各种系统 IO 接口。 整体架构 龙芯 2K1000 的结构如图 所…

《Oracle高级数据库》期末复习一文总结

文章目录 第一章&#xff1a;数据库基础1.数据库系统数据库数据库管理系统数据库系统 2.数据模型层次模型网状模型关系模型 3.关系型数据库&#xff08;1&#xff09;数据定义语言&#xff08;DDL&#xff09;&#xff08;2&#xff09;数据操纵语言&#xff08;DML&#xff09…

UC-OWOD: Unknown-Classified Open World Object Detection(论文翻译)

文章目录 UC-OWOD: Unknown-Classified Open World Object Detection摘要1.介绍2.相关工作3.未知分类的开放世界目标检测3.1 问题定义3.2 整体架构3.3 未知物体的检测3.4基于相似性的未知分类3.5未知聚类优化3.6训练和优化 4&#xff1a;实验4.1准备工作4.2结果和分析4.3消融研…

数学算法组合与排序

一句话总结&#xff1a;组合得次序是否重要&#xff0c;是否可重复&#xff0c;决定了组合数量 一、什么是组合&排序 组合可以是现实的一切事物、例如 [衣服&#xff0c;鞋子&#xff0c;眼镜...] 等等&#xff0c; 也可以表示一组数字 [1, 2, 3, 4, 5] &#xff0c;从个人…

STL常用容器_2

目录 一、stcak容器&#xff08;栈容器&#xff09; 1、基本概念 2、常用接口 二、queue容器&#xff08;队列容器&#xff09; 1、基本概念 2、常用接口函数 三、list容器&#xff08;链表&#xff09; 1、基本概念 2、构造函数 3、赋值与交换 4、大小操作 5、插入…

网络层和数据链路层

目录 网络层 IP协议 基本概念 协议头格式 ​编辑 网段划分 特殊的IP地址 IP地址的数量限制 私有IP地址和公网IP地址 路由 ​编辑数据链路层 以太网 以太网帧格式 认识MAC地址 对比理解MAC地址和IP地址 认识MTU MTU对IP协议的影响 ​编辑 MTU对UDP协议的影响 …

新产品上线前需要准备哪些产品文档呢

新产品上线前需要准备的产品文档非常重要&#xff0c;不仅有助于产品的开发过程中沟通和协作&#xff0c;而且对于后期的维护和升级也起到十分重要的作用。下面详细介绍新产品上线前需要准备哪些产品文档。 一、市场需求文档 市场需求文档&#xff08;Market Requirement Doc…

保姆级JAVA对接ChatGPT教程 使用 openai-gpt3-java

1. 前言 必须要有chatGTP 账号&#xff0c;如果需要测试账号可以关注公众号 疯狂的野猿 如果有chatGTP 账号就直接往下看。还需要一台外网服务器使用 nginx 代理来访问chatGTP 如果都没有&#xff0c;可以关注公众号联系作者。 还有笔者已经对接完成了&#xff0c;需要源码的关…

(电脑硬件)台式机主板音频端口功能详解

当你想给你的主机插上音响或者耳机时&#xff0c;你会发现主板上有6个接口&#xff0c;同样都是3.5mm接口&#xff0c;你知道该插哪个吗&#xff1f; 一般情况下&#xff0c;后置输入输出端口面板中&#xff0c;大多数的主板音频部分是彩色的。这一类颜色跟功能基本是固定的。当…

竟然支持在流程图、架构图中添加数学公式,安利一款纯免费的画图工具,真不错!

1. 简介 考虑到在绘图中需要添加数学表达式的场景&#xff0c;PDDON提供了LaTeX表达式编辑能力&#xff0c;可以在任何可以编辑的组件上启用LaTeX功能&#xff0c;使用LaTeX语法编写数学公式即可。 LaTeX表达式简介&#xff1a; LaTeX&#xff08;LATEX&#xff0c;音译“拉泰赫…

【偏门技巧】C语言编程实现对IPV4地址的合法性判断(使用正则表达式)

C语言编程实现对IPV4地址的合法性判断&#xff08;使用正则表达式&#xff09; 有了解过我的朋友&#xff0c;可能有点印象&#xff0c;我在N年前的博客中&#xff0c;就写了这个主题&#xff0c;当时确实是工作中遇到了这个问题。本想着等工作搞完之后&#xff0c;就把这个问题…