实战ResNet:CIFAR-10数据集分类

news2025/2/28 21:22:52

本节将使用ResNet实现CIFAR-10数据集的分类。

7.2.1  CIFAR-10数据集简介

CIFAR-10数据集共有60 000幅彩色图像,这些图像是32×32像素的,分为10类,每类6 000幅图,如图7-9所示。这里面有50 000幅图用于训练,构成了5个训练批,每一批10 000幅图;另外,10 000幅用于测试,单独构成一批。测试批的数据取自100类中的每一类,每一类随机取1000幅。抽剩下的就随机排列组成训练批。注意,一个训练批中的各类图像的数量并不一定相同,总的来看,训练批每一类都有5 000幅图。

图7-9  CIFAR-10数据集

读者自行搜索CIFAR-10数据集下载地址,进入下载页面后,选择下载方式,如图7-10所示。

图7-10  下载方式

由于PyTorch 2.0采用Python语言编程,因此选择Python Version的版本下载。下载之后解压缩,得到如图7-11所示的文件。

图7-11  得到的文件

data_batch_1~data_batch_5是划分好的训练数据,每个文件中包含10 000幅图片,test_batch是测试集数据,也包含10 000幅图片。

读取数据的代码如下:

import pickle

def load_file(filename):

    with open(filename, 'rb') as fo:

        data = pickle.load(fo, encoding='latin1')

    return data

首先定义读取数据的函数,这几个文件都是通过 pickle 产生的,所以在读取的时候也要用到这个包。返回的data是一个字典,先来看这个字典里面有哪些键。

data = load_file('data_batch_1')

print(data.keys())

输出结果如下:

dict_key3(['batch_label', 'labels', 'data', 'filenames'])

具体说明如下。

  1. batch_label:对应的值是一个字符串,用来表明当前文件的一些基本信息。
  2. labels:对应的值是一个长度为10 000的列表,每个数字取值范围为0~9,代表当前图片所属的类别。
  3. data:10000×3072的二维数组,每一行代表一幅图片的像素值。
  4. filenames:长度为10 000的列表,里面每一项是代表图片文件名的字符串。

完整的数据读取函数如下:

【程序7-1】

i

import pickle
import numpy as np
import os

def get_cifar10_train_data_and_label(root=""):
    def load_file(filename):
        with open(filename, 'rb') as fo:
            data = pickle.load(fo, encoding='latin1')
        return data

    data_batch_1 = load_file(os.path.join(root, 'data_batch_1'))
    data_batch_2 = load_file(os.path.join(root, 'data_batch_2'))
    data_batch_3 = load_file(os.path.join(root, 'data_batch_3'))
    data_batch_4 = load_file(os.path.join(root, 'data_batch_4'))
    data_batch_5 = load_file(os.path.join(root, 'data_batch_5'))
    dataset = []
    labelset = []
    for data in [data_batch_1, data_batch_2, data_batch_3, data_batch_4, data_batch_5]:
        img_data = (data["data"])
        img_label = (data["labels"])
        dataset.append(img_data)
        labelset.append(img_label)
    dataset = np.concatenate(dataset)
    labelset = np.concatenate(labelset)
    return dataset, labelset

def get_cifar10_test_data_and_label(root=""):
    def load_file(filename):
        with open(filename, 'rb') as fo:
            data = pickle.load(fo, encoding='latin1')
        return data

    data_batch_1 = load_file(os.path.join(root, 'test_batch'))
    dataset = []
    labelset = []
    for data in [data_batch_1]:
        img_data = (data["data"])
        img_label = (data["labels"])
        dataset.append(img_data)
        labelset.append(img_label)
    dataset = np.concatenate(dataset)
    labelset = np.concatenate(labelset)
    return dataset, labelset

def get_CIFAR10_dataset(root=""):
    train_dataset, label_dataset = get_cifar10_train_data_and_label(root=root)
    test_dataset, test_label_dataset = get_cifar10_train_data_and_label(root=root)
    return train_dataset, label_dataset, test_dataset, test_label_dataset

if __name__ == "__main__":
    train_dataset, label_dataset, test_dataset, test_label_dataset = get_CIFAR10_dataset(root="../dataset/cifar-10-batches-py/")

    train_dataset = np.reshape(train_dataset,[len(train_dataset),3,32,32]). astype(np.float32)/255.
    test_dataset = np.reshape(test_dataset,[len(test_dataset),3,32,32]). astype(np.float32)/255.
    label_dataset = np.array(label_dataset)
    test_label_dataset = np.array(test_label_dataset)

其中的root是下载数据解压后的目录参数,os.join函数将其组合成数据文件的位置。最终返回训练文件和测试文件以及它们对应的label。需要说明的是,提取出的文件数据格式为[-1,3072],因此需要重新对数据维度进行调整,使之适用于模型的输入。

7.2.2  基于ResNet的CIFAR-10数据集分类

前面对ResNet模型以及CIFAR-10数据集进行了介绍,本小节开始使用前面定义的ResNet模型进行分类任务。

上一节已经介绍了CIFAR-10数据集的基本构成,并讲解了ResNet的基本模型结构,接下来直接导入对应的数据和模型即可。完整的模型训练如下:

import torch
import resnet
import get_data
import numpy as np

train_dataset, label_dataset, test_dataset, test_label_dataset = get_data.get_CIFAR10_dataset(root="../dataset/cifar-10-batches-py/")

train_dataset = np.reshape(train_dataset,[len(train_dataset),3,32,32]). astype(np.float32)/255.
test_dataset = np.reshape(test_dataset,[len(test_dataset),3,32,32]). astype(np.float32)/255.
label_dataset = np.array(label_dataset)
test_label_dataset = np.array(test_label_dataset)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = resnet.resnet18()               #导入Unet模型
model = model.to(device)                #将计算模型传入GPU硬件等待计算
model = torch.compile(model)           #PyTorch 2.0的特性,加速计算速度
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)   #设定优化函数
loss_fn = torch.nn.CrossEntropyLoss()

batch_size = 128
train_num = len(label_dataset)//batch_size
for epoch in range(63):
    train_loss = 0.
    for i in range(train_num):
        start = i * batch_size
        end = (i + 1) * batch_size
        x_batch = torch.from_numpy(train_dataset[start:end]).to(device)
        y_batch = torch.from_numpy(label_dataset[start:end]).to(device)
        pred = model(x_batch)
        loss = loss_fn(pred, y_batch.long())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()  # 记录每个批次的损失值

    # 计算并打印损失值
    train_loss /= train_num
    accuracy = (pred.argmax(1) == y_batch).type(torch.float32).sum().item() / batch_size
    
    #2048可根据读者GPU显存大小调整
        test_num = 2048
    x_test = torch.from_numpy(test_dataset[:test_num]).to(device)
    y_test = torch.from_numpy(test_label_dataset[:test_num]).to(device)
    pred = model(x_test)
    test_accuracy = (pred.argmax(1) == y_test).type(torch.float32).sum().item() / test_num
    print("epoch:",epoch,"train_loss:", round(train_loss,2), ";accuracy:",round(accuracy,2),";test_accuracy:",round(test_accuracy,2))

在这里使用训练集数据对模型进行训练,之后使用测试集数据对其输出进行测试,训练结果如下:

可以看到,经过5轮训练后,模型在训练集的准确率达到0.99,而在测试集的准确率也达到了0.98,这是一个较好的成绩,模型的性能达到较高水平。

其他层次的模型请读者自行尝试,根据不同的硬件设备,模型的参数和训练集的batch_size都需要做出调整,具体数值读者可以根据需要进行设置。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/990075.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

概念:推理 训练 模型

训练 训练是通过从已有的数据中学习到某种能力; 推理 推理是简化并使用该能力,使其能快速、高效地对未知的数据进行操作,以获得预期的结果。 模型 训练是计算密集型操作,模型一般都需要使用大量的数据来进行训练,通…

婚恋相亲交友红娘小程序源码开发搭建方法

目前婚恋市场基本处于兵家必争之地,从一二线城市到四五线城市单身男女多,传统婚恋相亲很多已经不满足现在年轻人市场,因此我们推出婚恋相亲交友小程序。 注意:小程序过审需ICP经营许可证。 程序支持多端:H5端、小程序…

go开发之个微机器人的二次开发

请求URL: http://域名/addRoomMemberFriend 请求方式: POST 请求头Headers: Content-Type:application/jsonAuthorization:login接口返回 参数: 参数名必选类型说明wId是String登录实例标识chatRoom…

国家网络安全周2023时间是什么时候?有什么特点?谁举办的?

国家网络安全周2023时间是什么时候? 2023年国家网络安全宣传周将于9月11日至17日在全国范围内统一开展。其中开幕式等重要活动将在福建省福州市举行。今年网安周期间,除开幕式外,还将举行网络安全博览会、网络安全技术高峰论坛、网络安全微视…

计算机竞赛 基于深度学习的动物识别 - 卷积神经网络 机器视觉 图像识别

文章目录 0 前言1 背景2 算法原理2.1 动物识别方法概况2.2 常用的网络模型2.2.1 B-CNN2.2.2 SSD 3 SSD动物目标检测流程4 实现效果5 部分相关代码5.1 数据预处理5.2 构建卷积神经网络5.3 tensorflow计算图可视化5.4 网络模型训练5.5 对猫狗图像进行2分类 6 最后 0 前言 &#…

SpringMVC的增删改查的案例

目录 前言: 1.总体思路: 2.前期准备 3.前台页面 前言: 我们今天来学习研究SpringMVC的增删改查,希望这篇博客能够帮助正在学习,工作的你们!!! 1.总体思路: 首先我们得…

在linux上挂载windows共享目录

挂载要求 非root用户(普通用户)能够读写windows共享目录,比如查看文件、创建文件、修改文件、删除文件 # 让普通用户也可以正常读写 uidvalue and gidvalue Set the owner and group of the root of the file system (default: uidgid0, bu…

《算法竞赛·快冲300题》每日一题:“附近的牛”

《算法竞赛快冲300题》将于2024年出版,是《算法竞赛》的辅助练习册。 所有题目放在自建的OJ New Online Judge。 用C/C、Java、Python三种语言给出代码,以中低档题为主,适合入门、进阶。 文章目录 题目描述题解C代码Java代码Python代码 “ 附…

学习笔记——Java入门第三季

1.1 Java异常简介 异常:有异于常态,和正常情况不一样,有错误出现,阻止当前方法或作用域。 异常处理:将出现的异常提示给编程人员与用户,使原本将要中断的程序继续运行或者退出。并且能够保存数据和释放资源…

独家!网络机顶盒什么牌子好?热门网络电视机顶盒排名TOP5

电视机搭配网络机顶盒看剧是很多人的消遣方式,不过在挑选网络机顶盒时很多人踩过雷,像卡顿、死机、广告多等问题频发,近来很多人咨询我网络机顶盒什么牌子好,我以销量为基础盘点了网络电视机顶盒排名,哪些品牌最受欢迎…

OpenResume简历解析官方技术文档(翻译)

OpenResume简历解析官方技术文档(翻译) 本文是对OpenResume建立解析器官方技术文档《Resume Parser Playground》的翻译。 相关连接: OpenResume官网 OpenResume简历解析器的官方地址 OpenResume的Github 简历解析测试环境 该测试环境展示了 OpenResume 简历…

新型人工智能技术让机器人的识别能力大幅提升

原创 | 文 BFT机器人 在德克萨斯大学达拉斯分校的智能机器人和视觉实验室里,一个机器人在桌子上移动一包黄油玩具。通过达拉斯分校计算机科学家团队开发的新系统,机器人每推动一次,就能学会识别物体。 新系统允许机器人多次推动物体&#xf…

后端/DFT/ATPG/PCB/SignOff设计常用工具/操作/流程及一些文件类型

目录 1.PD/DFT常用工具及流程 1.1 FC和ICC2 1.2 LC (Library compiler) 1.3 PrimeTime 1.4 Redhawk与PA 1.5 Calibre和物理验证PV 1.6 芯片设计流程 2.后端、DFT、ATPG的一些常见文件 2.1 LEF和DEF 2.2 ATPG的CTL和STIL 2.3 BSDL 2.4 IPXCT 3.PCB设计的一些工作和工…

宏定义天坑记录

宏定义天坑记录 事件原委与推理过程 在编译一个使用了Protobuf的项目时出现了如下报错 [ybVM-8-7-centos boost_searcher]$ make g -o http_server http_server.cc data/raw_html.pb.cc -stdc11 -lboost_system -lboost_filesystem -lpthread -ljsoncpp -lprotobuf In file…

JAVA学习-IDEA创建父子项目

JAVA培训-创建父子项目 一、创建父模块 1、new一个新项目,如下图所示: 2、由于这里是父级Maven项目,所以什么都不用选,只需要将SpringBoot版本选成稳定的版本即可。后面带(SNAPSHOT),代表版本…

如何理解focal loss/GIOU(yolo改进损失函数)

Focal Loss的公式如下: Focal Loss -α(1 - p)^γ * log 其中,α是正样本的调节因子,γ是控制难易样本权重分配的参数,p是模型预测的概率值。 根据公式,可以看出当样本属于困难样本时,(1 - p) 的值较大…

如何全方位了解购房信息?VR全景技术为您解答

在存量房贷利率下调政策下,房子逐渐回归到居住属性,在对于有购房刚需的客户来说,无疑是一大利好政策,此类客户有着强烈的看房购房需求,那么该如何全方位的了解购房信息呢? 房企通过VR全景展示、3D样板房、V…

论文阅读 (100):Simple Black-box Adversarial Attacks (2019ICML)

文章目录 1 概述1.1 要点1.2 代码1.3 引用 2 背景2.1 目标与非目标攻击2.2 最小化损失2.3 白盒威胁模型2.4 黑盒威胁模型 3 简单黑盒攻击3.1 算法3.2 Cartesian基3.3 离散余弦基3.4 一般基3.5 学习率 ϵ \epsilon ϵ3.6 预算 1 概述 1.1 要点 题目:简单黑盒对抗攻…

Vue中的图标

Vue中的图标 https://iconpark.oceanengine.com/official 官方教程&#xff1a;icon-park/vue - npm 1.IconPark 2.基本使用 下载 yarn add icon-park/vue --save 启动 yarn run serve 项目中引用 <script> import { TableFile } from icon-park/vue; export defa…

微信小程序遇到的一些问题及解决方法(设备安装)

微信小程序遇到的一些问题及解决方法 1、[js将字符串按照换行符分隔成数组](https://blog.csdn.net/pgzero/article/details/108730175)2、[vue byte数组](https://www.yzktw.com.cn/post/1202765.html)3、使用vant-weapp的文件上传capture"camera" 无法直接调用摄像…