笔记2:cifar10数据集获取及pytorch批量处理

news2025/1/15 6:42:35

(1)cifar10数据集预处理

CIFAR-10是一个广泛使用的图像数据集,它由10个类别的共60000张32x32彩色图像组成,每个类别有6000张图像。
CIFAR-10官网
以下为CIFAR-10数据集data_batch_*表示训练集数据,test_batch表示测试集数据
在这里插入图片描述
预处理结果(将CIFAR-10保存为图片格式)
在这里插入图片描述

#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@author: LIFEI
@time: 2024/5/8 15:00 
@file: 加载cifar10数据.py
@project: 深度学习(4):深度神经网络(DNN)
@describe: TEXT
@# ------------------------------------------(one)--------------------------------------
@# ------------------------------------------(two)--------------------------------------
"""
import glob
import pickle
import numpy as np
import cv2 as
import os
#%% md
cifar10官网处理函数:
#%%
def unpickle(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict
#%% md
利用上面的函数进行读取数据:
#%%
label = ["airplane","automobile", "bird","cat", 'deer',"dog","frog","horse","ship","truck"]  #标签矩阵
filepath = glob.glob("../../test_doucments/cifar-10-batches-py/data_batch_*") # 获取当前文件的路径,返回路径矩阵,获取test数据集时将data_batch——*改为test_batch*
write_path =["./train","./test"] #
print(filepath)
for file in filepath:
    if not file:
        print("空集出错")
    else:
        # print(file)
        data_dic = unpickle(file) # 将二进制表示形式转换回 Python 对象的反序列化过程,结果为字节型数据
        # print(data_dic.keys()) #此处的keys主要有b"data",b"labels",b"filenames"
        index = 0
        for im_data in data_dic[b"data"]:  # 遍历影像矩阵数据
            im_label = data_dic[b"labels"][index] # 赋值标签数据
            im_filename = data_dic[b"filenames"][index] # 赋值影像名字
            index +=1
            # print(f"图像的文件名为:{im_filename}\n",f"图像的所属标签为:{im_label}\n",f"图像的矩阵数据为:{im_data}\n")

            #开始存放数据
            im_label_name = label[im_label]
            im_data_data = np.reshape(im_data,(3,32,32)) # 将影像矩阵数据转换为图像形式

            # 由于需要opencv进行写出图像,因此需要转化通道
            im_data_data = np.transpose(im_data_data,(1,2,0))
            imgname = f"当前图像名称{im_label},所属标签{im_label_name}"
            cv.imshow(str( im_label_name),cv.resize(im_data_data,(500,500))) # 将显示时的图像变大,图像数据本身大小不变
            cv.waitKey(0)
            cv.destroyAllWindows()

            #创建文件夹
            for path in write_path:
                if not os.path.exists("{}/{}".format(path,im_label_name)): #查看存储路径中的文件夹是否存在
                    os.mkdir("{}/{}".format(path,im_label_name)) # 没有就创建文件
                else:
                    break
            cv.imwrite("{}/{}/{}".format(write_path[0],im_label_name,str(im_filename,'utf-8')),im_data_data)
            # #write_path[1]写出测试数据的时候将write_path[0]改为write_path[1]
#%% md
将cifar10数据转为图片格式并保存

(2)利用pytorch将图像转为张量数据

或是批量读取训练集和测试集数据
在这里插入图片描述

#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@author: LIFEI
@time: 2024/5/8 15:00 
@file: 加载cifar10数据.py
@project: 深度学习(4):深度神经网络(DNN)
@describe: TEXT
@# ------------------------------------------(one)--------------------------------------
@# ------------------------------------------(two)--------------------------------------
"""
# 导入库
import glob
from torchvision import transforms
from torch.utils.data import Dataset,DataLoader
import cv2 as cv
# DataLoader参考网址https://blog.csdn.net/sazass/article/details/116641511

from PIL import Image

label_name = ["airplane","automobile", "bird","cat", 'deer',"dog","frog","horse","ship","truck"]
label_list = {} # 创建一个字典用于存储标签和下标
index = 0
for name in label_name:  # 也可以采用for index,name in enumerate(label_name)
    label_list[name] = index # 字典的常规赋值操作
    index += 1

def default_loder(path):
    # return Image.open(path).convert("RGB") # 也可采用opencv读取
    img = cv.imread(path)
    return cv.cvtColor(img,cv.COLOR_BGR2RGB)


# 定义训练集数据的增强   下面的Compose表示拼接需要增强的操作
train_transform = transforms.Compose([
    transforms.RandomCrop(28,28), #进行随机裁剪为28*28大小
    transforms.RandomHorizontalFlip(), #垂直方向翻转
    transforms.RandomVerticalFlip(), #水平方向的翻转
    transforms.RandomRotation(90), #随机旋转90度
    transforms.RandomGrayscale(0.1), #灰度转化
    transforms.ColorJitter(0.3,0.3,0.3,0.3), #随机颜色增强
    transforms.ToTensor() #将数据转化为张量数据
])

# 定义pytorh的dataset类
class MyData(Dataset):
    def __init__(self,im_list,
                 transform = None,
                 loder = default_loder):     #初始化函数
        super(MyData,self).__init__() #初始化这个类

        # 获取图片的路径以及标签号
        images = []
        for item_data in im_list:
            # 注意下面这一步,split("\\")根据不同的操作系统会不相同,有的是"/"
            img_label_name = item_data.split("\\")[-2] #通过遍历每一个路径进行获取当前图片的文字标签
            images.append([item_data,label_list[img_label_name]])

        self.images = images
        self.tranform =transform
        self.loder = loder

    def __getitem__(self, index_num): # 此处的index_num是在训练的时候反复传进来的值
        img_path , img_label = self.images[index_num] #这里的
        img_data = self.loder(img_path)  # 这里用到了self.loder(path)==>default_loder(path)外置函数

        if self.tranform is not None: # 判断数据是否增强
            img_data = self.tranform(img_data)
        return img_data,img_label

    def __len__(self):
         return len(self.images)

train_list = glob.glob("./train/*/*.png") # glob.glob 获取改路径下的所有文件路径并返回为列表
test_list = glob.glob("./test/*/*.png")

train_dataset = MyData(train_list,transform = train_transform)
test_dataset = MyData(test_list,transform = transforms.ToTensor()) #测试集无需进行图像增强操作,直接转为张量

train_data_loder = DataLoader(dataset =train_dataset,
                              batch_size=6,
                              shuffle=True,
                              num_workers=4)
test_data_loder = DataLoader(dataset =test_dataset,
                              batch_size=6,
                              shuffle=False,
                              num_workers=4)
print(f"训练集的大小:{len(train_dataset)}")
print(f"测试集的大小:{len(test_dataset)}")

注:以上代码非原创,仅供个人记录学习笔记,若有侵权,请我联系删除

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1654006.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【第13章】spring-mvc之validator

文章目录 前言一、准备1. 引入库2. add.jsp3. show.jsp 二、代码部分1.实体类2. 控制器类3. 效果4. 展示 总结 前言 【第20章】spring-validator 虽然前面已经在spring介绍过,但是为了保证代码可用,还是会从头讲到尾,尽量把关键点列出来讲给…

Spring后端参数校验——自定义校验方式(validation)

文章目录 开发场景技术名词解释——Spring Validation自定义校验 技术细节小结1.实体参数校验2.自定义校验 完整代码 开发场景 业务场景:新增文章 基本信息 请求路径:/article 请求方式:POST 接口描述:该接口用于新增文章(发布文…

STC8增强型单片机开发

1.C51版本Keil环境搭建 下载地址是 Keil Product Downloads 选择C51进行下载: 2.STC环境添加 STC-ISP下载 进入stc官网 深圳国芯人工智能有限公司-工具软件 3.将STC添加到Keil中 打开stc-isp工具 按照图例点击按钮 选择keil的安装目录,以实际安装目…

【SpringMVC 】什么是SpringMVC(三)?基于springmvc的文件上传、基于springmvc的拦截器、基于springmvc的邮件发送

文章目录 SpringMVC第五章1、SpringMVC文件上传1、基本步骤1-2345-82、邮件发送1、基本步骤1-234-5567-8 简单邮件带附件的邮件第六章1、拦截器的使用使用步骤232、调度的使用基本步骤1-56-8调度规则3、shiro安全框架核心概念基本语法1、基于ini文件的认证**测视类**2、基于rea…

电商API接口:品牌为提升价格竞争力做定价参考

品牌为了提升价格竞争力,在进行产品定价时,可以从以下几个方面作为参考依据: 市场调研: 分析同类竞品在各大电商平台的均价、最高价和最低价,了解市场行情和消费者心理预期价位。 成本核算: 精确计算产…

力扣41. 缺失的第一个正数

Problem: 41. 缺失的第一个正数 文章目录 题目描述思路复杂度Code 题目描述 思路 1.将nums看作为一个哈希表,每次我们将数字n移动到nums[n - 1]的位置(例如数字1应该存在nums[0]处…),则在实际的代码操作中应该判断nums[i]与nums[nums[i] - 1]是否相等,若…

【揭秘!】我国土地管理的基本国策与基本国情,你了解多少?

在这片古老而又充满活力的土地上,每一寸土地都承载着历史的记忆和未来的希望。我国的土地管理政策,正是在基本国情的基础上,精心编织的一张保障国家和人民利益的大网。今天,就让我们一起揭开我国土地管理的基本国策和基本国情的神…

论文分享[cvpr2018]Non-local Neural Networks非局部神经网络

论文 https://arxiv.org/abs/1711.07971 代码https://github.com/facebookresearch/video-nonlocal-net 非局部神经网络 motivation:受计算机视觉中经典的非局部均值方法[4]的启发,非局部操作将位置的响应计算为所有位置的特征的加权和。 非局部均值方法 NLM&#…

【管理咨询宝藏96】企业数字化转型的中台战略培训方案

本报告首发于公号“管理咨询宝藏”,如需阅读完整版报告内容,请查阅公号“管理咨询宝藏”。 【管理咨询宝藏96】企业数字化转型的中台战略培训方案 【格式】PDF版本 【关键词】SRM采购、制造型企业转型、数字化转型 【核心观点】 - 数字化转型是指&…

C++可变参数接口,批量写入和读取参数值的设计和实现

相关文章系列 手撕代码: C实现数据的序列化和反序列化-CSDN博客 目录 1.需求 2.问题分析 3.解决方案 3.1.类型抽象 3.2.参数配置 3.3.参数读取 1.需求 最近在做项目的时候,我们小组做的模块和另外一个小组做的模块的交付通过动态库接口的方式,他们…

模糊的图片文字,OCR能否正确识别?

拍照手抖、光线不足等复杂的环境下形成的图片都有可能会造成文字模糊,那这些图片文字对于OCR软件来说,是否能否准确识别呢? 这其中的奥秘,与文字的模糊程度紧密相连。想象一下,如果那些文字对于我们的双眼来说&#x…

【Android】源码解析Activity的结构分析

源码解析Activity的结构分析 目录 1、Activity、View、Window有什么关联?2、Activity的结构构建流程3 源码解析Activity的构成 3.1 Activity的Attach方法3.2 Activity的OnCreate 4、WindowManager与View的关系总结 1、一个Activity对应几个WindowManage&#xff0…

Linux cmake 初窥【3】

1.开发背景 基于上一篇的基础上,已经实现了多个源文件路径调用,但是没有库的实现 2.开发需求 基于 cmake 的动态库和静态库的调用 3.开发环境 ubuntu 20.04 cmake-3.23.1 4.实现步骤 4.1 准备源码文件 基于上个试验的基础上,增加了动态库…

pycharm中导入rospy(ModuleNotFoundError: No module named ‘rospy‘)

1. ubuntu安装对应版本ros ubuntu20.04可参考: https://wiki.ros.org/cn/noetic/Installation/Ubuntuhttps://zhuanlan.zhihu.com/p/515361781 2. 安装python3-roslib sudo apt-get install python3-roslib3.在conda环境中安装rospy pip install rospkg pip in…

4.26.7具有超级令牌采样功能的 Vision Transformer

Vision Transformer在捕获浅层的局部特征时可能会受到高冗余的影响。 在神经网络的早期阶段获得高效且有效的全局上下文建模: ①从超像素的设计中汲取灵感,减少了后续处理中图像基元的数量,并将超级令牌引入到Vision Transformer中。 超像素…

Python数据分析之绘制相关性热力图的完整教程

前言 文章将介绍如何使用Python中的Pandas和Seaborn库来读取数据、计算相关系数矩阵,并绘制出直观、易于理解的热力图。我们将逐步介绍代码的编写和执行过程,并提供详细的解释和示例,以便读者能够轻松地跟随和理解。 大家记得需要准备以下条…

谷歌十诫 Ten things we know to be true, Google‘s Core values

雷军曾经要求金山人人都必须能背谷歌十诫 我们所知的十件事 当谷歌刚成立几年时,我们首次写下了这“十件事”。我们时不时回顾这个列表,看看它是否仍然适用。我们希望它仍然适用——你也可以要求我们做到这点。 1. Focus on the user and all else wi…

视频号小店常见问题合集,准备做视频号小店的,赶紧收藏起来

大家好,我是电商花花。 现在视频号小店在电商行业中越来越受欢迎,视频号背后依靠者微信和腾讯强大的流量,拥有着超强的流量和市场,在今年的电商市场中有引起了一个热门话题,作为一个有流量有市场的新兴创业自然是吸引…

Springboot+vue项目人事管理系统

开发语言:Java 开发工具:IDEA /Eclipse 数据库:MYSQL5.7 应用服务:Tomcat7/Tomcat8 使用框架:springbootvue JDK版本:jdk1.8 文末获取源码 系统主要分为管理员和普通用户和员工三部分,主要功能包括个人中心,普通用户管理&…

【最经典的79个】软件测试面试题(内含答案)备战“金三银四”

001.软件的生命周期(prdctrm) 计划阶段(planning)-〉需求分析(requirement)-〉设计阶段(design)-〉编码(coding)->测试(testing)->运行与维护(running maintrnacne) 测试用例 用例编号 测试项目 测试标题 重要级别 预置条件 输入数据 执行步骤 预期结果 0002.问&…