使用make_grid多批次显示网格图像(使用CIFAR数据集介绍)

news2024/11/15 15:50:46

背景介绍

在机器学习的训练数据集中,我们经常使用多批次的训练来实现更好的训练效果,具体到cv领域,我们的训练数据集通常是[B,C,W,H]格式,其中,B是每个训练批次的大小,C是图片的通道数,如果是1则为灰度图像,如果是3则为彩色图像,W,H分别是图像的像素宽和像素高,在torchvision中,为我们提供了方便的方法显示多通道的图像显示成网格的格式

数据集介绍

这里使用机器学习中经典的CIFAR10数据集,具体可以参考博客CIFAR-10数据集详解与可视化_cifar10数据集可视化-CSDN博客

数据集读取

我们假设已经下载好CIFAR数据集保存在本地计算机的路径中,可以通过CIFAR函数进行读取

# 依赖的库环境
import torchvision
import torch
from torchvision.datasets import CIFAR10
import matplotlib.pyplot as plt
from torchvision.transforms import ToTensor,Compose,Resize

读取CIFAR数据集中的训练数据集

train_dataset = CIFAR10(r'D:\deep_learning\12_16\data', train=True, download=False,transform=ToTensor())

这里的转换方式是使用简单的ToTensor()将图片格式转换成经典的[C,W,H]格式,方便后续的可视化操作

此时我们可以简单地对数据集中的第一张图片进行可视化

img,label = train_dataset[0]
plt.imshow(img.permute(1,2,0))
plt.show()

构造批次数据集

如何构造批次的训练数据集呢?可以通过DataLoader的方式获得批次生成器,也可以通过torch.stack函数自定义地构成

cifar_img = torch.stack([train_dataset[i][0] for i in range(4)], dim=0)

这里使用列表推导式获得前4张图片组成的数据列表,通过torch.stack指定dim=0进行多个数据的堆加,这里需要注意的是,stack是在指定的维度新增一个维度进行多矩阵的合并,cat是在指定的维度上合并多个矩阵而不增加新的维度

cat与stack的区别

我们来具体看看两者的区别

cat_img = torch.cat([train_dataset[i][0] for i in range(4)],dim=0)
stack_img = torch.stack([train_dataset[i][0] for i in range(4)],dim=0)
print(f'cat_shape:{cat_img.shape}')
print(f'stack_shape:{stack_img.shape}')
cat_shape:torch.Size([12, 32, 32])
stack_shape:torch.Size([4, 3, 32, 32])

train_dataset[i][0]的形状为[3,32,32],当使用cat时,直接在第一维度上进行累加获得[12,32,32];使用stack时,在指定的第一维度上新增一个维度进行累加,有[4,3,32,32]

进行网格化显示

使用torchvision.utils.make_grid函数进行网格格式转换

train_dataset = CIFAR10(r'D:\deep_learning\12_16\data', train=True, download=False,transform=ToTensor())
cifar_img = torch.stack([train_dataset[i][0] for i in range(4)], dim=0)
img_grid = torchvision.utils.make_grid(cifar_img,nrow=4,normalize=True,pad_value=0.9,padding=1)
plt.imshow(img_grid.permute(1,2,0))
plt.show()

nrow是指定每一行的图片的数量,这里只有四张图片,所以是4,默认nrow=8

normalize是对图片数据进行标准化

pad_value是对图片间隔之间的像素进行填充的像素值

padding是指定图片之间的像素间隔数量

同时显示100张图片

train_dataset = CIFAR10(r'D:\deep_learning\12_16\data', train=True, download=False,transform=ToTensor())
cifar_img = torch.stack([train_dataset[i][0] for i in range(100)], dim=0)
img_grid = torchvision.utils.make_grid(cifar_img,nrow=10,normalize=True,pad_value=0.9,padding=1)
plt.imshow(img_grid.permute(1,2,0))
plt.show()

批次图片可视化

我们对使用DataLoader生成的批次数据进行可视化

if __name__=='__main__':
    train_dataset = CIFAR10(r'D:\deep_learning\12_16\data', train=True, download=False,transform=ToTensor())
    trainloader = DataLoader(train_dataset,shuffle=True,batch_size=128,num_workers=8)
    trainloader = iter(trainloader)
    trainloader_first_batch = next(trainloader)

    imgs,labels = trainloader_first_batch
    batch_grid = torchvision.utils.make_grid(imgs)
    plt.imshow(batch_grid.permute(1,2,0))
    plt.show()

对训练数据集更好的了解是为了在训练的时候获得更好的模型性能,欢迎大家讨论交流~


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1414596.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

告别无法访问的Github

告别无法访问的Github 最近在使用github的时候又登不上去了,挂着VPN都没用 但是自己很多项目都存在github,登不上去那不得损失很大 所以一行必须整点儿特殊手段来访问,顺便分享一下 1.加速器 网上很多解决方案都是在分享各种加速器来登陆…

【Vue】1-2、Webpack 中的插件

一、Webpack 插件的作用 通过安装和配置第三方的插件,可以拓展 webpack 的能力,从而让 webpack 用起来更方便。 二、两个常用插件 1)webpack-dev-server 类似于 node.js 使用的 nodemon 工具 每当修改了源代码,webpack 会自动…

Python算法题集_接雨水

本文为Python算法题集之一的代码示例 题目42:接雨水 说明:给定 n 个非负整数表示每个宽度为 1 的柱子的高度图,计算按此排列的柱子,下雨之后能接多少雨水 示例 1: 输入:height [0,1,0,2,1,0,1,3,2,1,2,1]…

前端工程化基础(二):前端包管理工具npm/yarn/cnpm/npx/pnpm

前端包管理工具 代码共享方案 创建自己的官网, 将代码放到官网上面将代码提交到GitHub上面,负责让使用者下载将代码提交到npm registry上面 下载比较方便,使用npm install xxx即可下载相应的代码npm管理的包 npm配置文件 主要用于存储项目…

一篇文章带你了解C++中隐含的this指针

文章目录 一、this指针的引出二、this指针的特性【面试题】 一、this指针的引出 我们先来定义一个日期类Date,下面这段代码执行的结果是什么呢? class Date { public:void Init(int year, int month, int day){_year year;_month month;_day day;}v…

高级自动驾驶LiDAR反射白板

随着自动驾驶技术的不断发展,激光雷达作为其核心传感器之一,正逐渐成为业界关注的焦点。激光雷达通过发射激光束并测量反射回来的时间来获取周围环境的三维信息。为了确保激光雷达能够准确、稳定地工作,对其进行标定是必不可少的环节。本文将…

开发微信小程序,将图片下载到相册的方法,saveImageToPhotosAlbum怎么用

在开发微信小程序的时候,经常能看到小程序里面有下载按钮,如何将小程序中的图片下载到手机相册中那,下面给大家说一下怎么做,代码如何去写。 一、到微信小程序后台开启“用户隐私保护指引” 1.进入小程序后台,侧拉拉到…

JSP在线阅读系统myeclipse定制开发SQLServer数据库网页模式java编程jdbc

一、源码特点 JSP 小说在线阅读系统是一套完善的web设计系统,对理解JSP java编程开发语言有帮助,系统具有完整的源代码和数据库 ,系统主要采用B/S模式开发。开发环境为 TOMCAT7.0,Myeclipse8.5开发,数据库为SQLServer2008&#…

【JavaWeb】监听器 Listener

文章目录 一、监听器是什么二、监听器的分类三、监听器的六个主要接口3.1 application域监听器测试代码 :3.1.1 定义监听器3.1.2 定义触发监听器的代码 3.2 session域监听器测试代码 :3.2.1 定义监听器3.2.2 定义触发监听器的代码 3.3 request域监听器测试代码:3.3.…

大创项目推荐 题目:基于卷积神经网络的手写字符识别 - 深度学习

文章目录 0 前言1 简介2 LeNet-5 模型的介绍2.1 结构解析2.2 C1层2.3 S2层S2层和C3层连接 2.4 F6与C5层 3 写数字识别算法模型的构建3.1 输入层设计3.2 激活函数的选取3.3 卷积层设计3.4 降采样层3.5 输出层设计 4 网络模型的总体结构5 部分实现代码6 在线手写识别7 最后 0 前言…

静态时序分析:传播延迟与转换时间

相关阅读 静态时序分析https://blog.csdn.net/weixin_45791458/category_12567571.html?spm1001.2014.3001.5482 一、传播延迟 在数字集成电路中,一个门的传播延迟(Propagation Time)定义为从输入的转变发生到输出转变发生的时间&#xff0…

IDEA创建一个web项目部署到tomcat

在 IntelliJ IDEA 中创建并部署一个 Web 项目到 Tomcat,您可以按照以下步骤操作: 1.安装 IntelliJ IDEA: 如果尚未安装 IntelliJ IDEA,请从官方网站 JetBrains 下载并安装 IntelliJ IDEA。 2.启动 IntelliJ IDEA: 打开 IntelliJ IDEA,并确保您已经安装了合适的插件,例如…

Cesium反向遮罩指定区域挖空---Primitive、PolygonGeometry、PolylineGeometry实现

PolylineRegionalExcavationFun2() {import("./data/安徽省.json").then((res) => {console.log(`res`, res);let features = res.features;let positionArray = [];let borderLinePositionArray = [];// 获取区域的经纬度坐标if (features[0]?.geometry?.coord…

一篇带你学会Git基础操作

📙 作者简介 :RO-BERRY 📗 学习方向:致力于C、C、数据结构、TCP/IP、数据库等等一系列知识 📒 日后方向 : 偏向于CPP开发以及大数据方向,欢迎各位关注,谢谢各位的支持 目录 1.认识⼯作区、暂存区…

CMU15-445 Project0

CMU14445 Task #1 - Copy-On-Write Trie Get()思路: 获取根节点指针,顺着key逐字符往下找节点,最后根据题意可以使用dynamic_cast检查是否是TrieNodeWithValue(dynamic_pointer_cast也可以),以下为两者用法&#xff1…

OpenHarmony关系型数据库

1 概述 关系型数据库(Relational Database, 以下简称RDB)是一种基于关系模型来管理数据的数据库,是在SQLite基础上提供一套完整的对本地数据库进行管理的机制,为开发者提供无需编写原生SQL语句即可实现数据增、删、改、查等接口,同时开发者也…

css设置不可点击

文章目录 一、前言二、MDN三、使用四、注意五、总结六、最后 一、前言 在网页开发中,经常会遇到一种情况,就是需要将某个元素的点击事件屏蔽,使其在用户点击时没有任何反应。这时候,我们可以通过CSS的pointer-events属性设置为no…

Jmeter接口测试总结

🍅 视频学习:文末有免费的配套视频可观看 🍅 关注公众号【互联网杂货铺】,回复 1 ,免费获取软件测试全套资料,资料在手,涨薪更快 Jmeter介绍&测试准备 Jmeter介绍:Jmeter是软件…

PositiveSSL多域名通配符证书买一年送一月

SSL数字证书是一种安全协议,用于在网络通信中提供加密和身份验证服务,是保护网站安全的重要手段之一。PositiveSSL是Sectigo旗下的子品牌,经营着各种SSL证书,例如,单域名SSL证书、多域名SSL证书、通配符SSL证书和多域名…

Java通过模板替换实现excel的传参填写

以模板为例子 将上面$转义的内容替换即可 package com.gxuwz.zjh.util;import org.apache.poi.ss.usermodel.*; import org.apache.poi.xssf.usermodel.XSSFWorkbook; import java.io.*; import java.util.HashMap; import java.util.Map; import java.io.IOException; impor…