pytorch基础模块:Tensorboard、Dataset、Transforms、Dataloader

news2024/9/20 8:14:10

Tensorboard、Dataset、Transforms、Dataloader

该文档主要参考【土堆】的视频教程:pytorch入门教程–土堆

一、Tensorboard

安装tensorboardpip install tensorboard

使用步骤

  • 引入相关库:from torch.utils.tensorboard import SummaryWriter
  • 构建SummaryWriter对象:writer = SummaryWriter(log_dir="logs")
    • 在工程目录下创建一个名为logs的文件夹,用于存放Tensorboard绘图所用的文件
  • 打开tensorboard
    • 命令行执行:tensorboard --logdir=logs
      • 如果有错误,使用logs的绝对地址
    • 点击链接,即可查看

[常用函数]

add_scalaradd_imagesadd_graph

1.1、add_scalar

功能:添加标量数据(例如记录训练epoch及对应的loss)

常用参数:

  • 标题:tag
  • 标量数据(Y轴):scalar_value
  • 计步数据(X轴):global_step
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="logs")
for i in range(100):
    writer.add_scalar(tag="y=3x", scalar_value=3 * i, global_step=i)

注意:使用同一个SummaryWriter对象且tag相同时,会绘制在同一幅图上,为避免该情况可以删除logs中的内容,或者每次都新建文件夹(log_dir="logs_1"

图片如下(左图为只绘制一次的结果y=3x,右图为在同一幅图上分别绘制y=2x以及y=10x的结果)
在这里插入图片描述

在这里插入图片描述

1.2、add_image

功能:添加图片数据(例如记录每个batch的输入图片)

常用参数:

标题:tag

图片数据:img_tensor(一般是torch.Tensor或者numpy.array类型)

计步数据:global_step

图片格式:dataformats(CHW, HWC, HW, WH等,默认为'CHW')

from torch.utils.tensorboard import SummaryWriter
import os
import cv2


project_path = os.getcwd()
file_name_1 = 'dog_1.png'
file_name_2 = 'dog_2.png'

file_path = os.path.join(project_path, r'data\dog')
full_file_path_1 = os.path.join(file_path, file_name_1)
full_file_path_2 = os.path.join(file_path, file_name_2)
# 使用cv2(即OpenCV库)读取图片时,图片通常是以HWC(高度、宽度、通道)格式存储的,并且每个像素的颜色值(对于RGB图像)都是0到255之间的整数
image_data_1 = cv2.imread(full_file_path_1)
image_data_2 = cv2.imread(full_file_path_2)
# 将图片从BGR转换为RGB(因为cv2默认通道顺序为BGR,使用PIL读取图片的通道顺序为RGB):如果不进行调整,则图片颜色会失真
image_data_1 = cv2.cvtColor(image_data_1, cv2.COLOR_BGR2RGB)
image_data_2 = cv2.cvtColor(image_data_2, cv2.COLOR_BGR2RGB)

writer = SummaryWriter(log_dir="logs")
writer.add_image(tag='dog', img_tensor=image_data_1, global_step=0, dataformats='HWC')
writer.add_image(tag='dog', img_tensor=image_data_2, global_step=1, dataformats='HWC')
writer.close()

图片如下(通过拖动进度条可以查看不同step对应的图片)
在这里插入图片描述
在这里插入图片描述

1.3、add_graph

功能:添加模型结构数据(例如记录神经网络的结构)

常用参数:

模型:model(可以构建自己的模型或者使用公开的经典模型,例如VGG-16

模型输入:input_to_model(要求图片数据是Tensor类型)

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
import cv2
from torchvision import transforms


if __name__ == '__main__':
    # 使用内置的分类模型VGG-16(后续会更新如何搭建模型的文章)
    vgg_16 = torchvision.models.vgg16(progress=False)
    print(vgg_16)

    image = cv2.imread('data\\dog\\dog_1.png')
    # 图片类型转换为Tensor并调整尺寸为224*224(vgg16的标准输入尺寸)
    image = transforms.Compose([transforms.ToTensor(), transforms.Resize([224, 224])])(image)
    # 添加一个批次维度(模型通常期望输入具有批次维度),使得形状从(C, H, W)变为(1, C, H, W),其中1表示批次大小。
    image = torch.unsqueeze(image, 0)

    writer = SummaryWriter('model')
    # image提供模型在前向传播过程中所需的输入数据,TensorBoard据此生成模型的计算图
    writer.add_graph(model=vgg_16, input_to_model=image)

    writer.close()

在这里插入图片描述
在这里插入图片描述

二、Dataset

2.1、使用公开数据集

  • 常用的数据集:MNISTCIFAR10

  • 这些封装好的数据集都继承了torch.utils.data中的Dataset类,该类有两个重要的方法:getitem()len()

  • 可以通过参数transform以及target_transform在加载数据时进行实时的数据增强操作(如旋转、裁剪、缩放等);

    • 对图像数据的增强操作详见章节三Transforms
  • 可以通过继承Dataset类并重写getitem()len()方法创建自己的数据集类(使用自己的数据)

from torchvision import datasets


if __name__ == '__main__':
    # 指定数据集路径(下载好的数据集会自动解压到路径下)
    data_path = 'common_dataset'
    # train=True表示为训练集,download=True表示下载数据集(若已经下载好则自动加载本地数据集)
    # 若在线下载速度慢,可进入CIFAR10类中,直接通过数据集的下载链接下载(下载好放在data_path下即可)
    train_data = datasets.CIFAR10(root=data_path, train=True, download=True)
    test_data = datasets.CIFAR10(root=data_path, train=False, download=True)

    # 打印数据集所包含的数据个数
    print(len(train_data), len(test_data))

    # 获取第一个数据的图片(PIL Image类型)及标签(类别)
    img, label = train_data[0]
    # 打印类别索引及真实的类别
    print(label, train_data.classes[label])
    img.show()

2.2、使用自己的数据

from torch.utils.data import Dataset
import cv2
import os


class MyDataset(Dataset):
    def __init__(self, data_path, label):
        # super.__init__()
        self.data_path = data_path
        self.label = label
        self.full_path = os.path.join(self.data_path, self.label)
        self.images_name = os.listdir(self.full_path)

    def __getitem__(self, item):
        image_data = cv2.imread(os.path.join(self.full_path, self.images_name[item]))
        # BGR转换为RGB,不然会失真
        image_data = cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB)
        return image_data, self.label

    def __len__(self):
        return len(self.images_name)
    

if __name__ == '__main__':
    data_path = os.path.join(os.getcwd(), 'data')
    label = 'dog'

    dataset_instance = MyDataset(data_path, label)
    print(len(dataset_instance))
    image, label = dataset_instance[0]
    print(image.shape, label)
    print(type(image))

三、Transforms

Transforms是用于处理图片的库,内置的类基本可以满足图片处理的需求,例如图片类型转换(PIL Imagendarraytensor)、尺寸调整、裁剪等

  • 若没有torchvision则需要先安装:pip install torchvision

  • [常用功能(类)]

    ToTensorNormalizeResizeRandomCropCompose

  • 使用方式:根据需求选择类,创建类的实例,使用类的实例完成图片处理

3.1、ToTensor

  • 功能:将PIL Imagendarray类型的图片转换为Tensor类型(Convert a PIL Image or ndarray to tensor and scale the values accordingly

  • 输入:PIL Imagendarray类型的图片(PIL Image or numpy.ndarray (H x W x C) in the range [0, 255]

  • 输出:Tensor类型,shapeCHW,每个元素均为[0.0, 1.0]之间的数(torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]

from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import cv2
import os

# 创建Dataset的子类
class MyDataset(Dataset):
    def __init__(self, data_path, label):
        # super.__init__()
        self.data_path = data_path
        self.label = label
        self.full_path = os.path.join(self.data_path, self.label)
        self.images_name = os.listdir(self.full_path)

    def __getitem__(self, item):
        image_data = cv2.imread(os.path.join(self.full_path, self.images_name[item]))
        # BGR转换为RGB
        image_data = cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB)
        return image_data, self.label

    def __len__(self):
        return len(self.images_name)


if __name__ == '__main__':
    data_path = os.path.join(os.getcwd(), 'data')
    label = 'dog'
	
    # 创建dataset子类实例,用于读取图片
    dataset_instance = MyDataset(data_path, label)
    # 输出图片数量
    print(len(dataset_instance))
    image, label = dataset_instance[0]
    # 根据索引获取图片
    print(image.shape, label)

    writer = SummaryWriter(log_dir='transforms_logs')

    # 使用ToTensor
    to_tensor = transforms.ToTensor()
    image_tensor = to_tensor(image)
    writer.add_image(tag='dog', img_tensor=image_tensor, global_step=0)

3.2、Normalize

  • 功能:对每一个通道(channel)分别根据其均值、标准差进行标准化(Normalize a tensor image with mean and standard deviation

  • 输入:Tensor类型的图片(This transform does not support PIL Image

  • 输出:标准化后的Tensor类型的图片,output[channel] = (input[channel] - mean[channel]) / std[channel]

# 使用Normalize
# 创建对象的时候给定均值、标准差
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
image_tensor = normalize(image_tensor)
writer.add_image(tag='dog', img_tensor=image_tensor, global_step=1)

3.3、Resize

  • 功能:调整图片HWResize the input image to the given size

    • size为序列(例如size=[500, 800]),则调整后的图片H=500W=800

    • size为整数(size=500),则根据HW中较小的值确定调整后的尺寸

      • 例如H=600W=1200,则调整后的H=500,调整后的W 1200 / 600 ∗ 500 = 1000 1200/600*500=1000 1200/600500=1000
  • 输入:PIL ImageTensor类型的图片

  • 输出:与输入的类型相同

# 使用Resize
resize = transforms.Resize(size=(500, 1000))
image_tensor = resize(image_tensor)
writer.add_image(tag='dog', img_tensor=image_tensor, global_step=2)

3.4、RandomCrop

  • 功能:对图片进行随机裁剪(Crop the given image at a random location
    • size为序列(例如size=[500, 800]),则裁剪后的图片H=500W=800
    • size为整数(size=500),则裁剪后的图片H=500W=500
    • 裁剪后的图片H、W均不大于原有图片的H、W,否则会报错
  • 输入:PIL ImageTensor类型的图片
  • 输出:与输入的类型相同
# 使用RandomCrop
random_crop = transforms.RandomCrop(size=(300, 800))
image_tensor = random_crop(image_tensor)
writer.add_image(tag='dog', img_tensor=image_tensor, global_step=3)

3.5、结果展示

从上到下分别是NormalizeResizeRandomCrop顺序执行后的结果
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

3.6、Compose

  • 功能:指定一系列图片处理步骤,对图片进行流式处理(Composes several transforms together

    • 使用列表指定需要对图片进行的处理,[transform_1, transform_2,...]
  • 输入:PIL ImagendarrayTensor类型的图片

  • 输出:由图片类型、指定的处理步骤决定

# 使用
compose = transforms.Compose([to_tensor, normalize, resize, random_crop])
image_tensor = compose(image_tensor)
writer.add_image(tag='dog_compose', img_tensor=image_tensor, global_step=0)

四、Dataloader

Dataloader用于批量加载和处理数据,能数据集分成小批量,并在训练过程中按需加载这些小批量数据,以提高训练效率并节省内存。

  • 批量加载数据:参数batch_size,每次加载batch_size个数据,而不是一次性加载整个数据集;
  • 数据“洗牌”:参数shuffle,在每个训练周期开始时随机打乱数据顺序,防止模型过拟合;
  • 并行处理:参数num_workers,利用多个线程或进程加快数据加载过程;
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


if __name__ == '__main__':
    data_path = 'common_dataset'
    
    # 为了方便使用tensorboard进行展示,使用transform=transforms.ToTensor()将图片由PIL类型转换为Tensor类型
    train_data = datasets.CIFAR10(root=data_path, train=True, transform=transforms.ToTensor(), download=True)
    test_data = datasets.CIFAR10(root=data_path, train=False, transform=transforms.ToTensor(), download=True)

    # 打印数据集所包含的数据个数
    print(len(train_data), len(test_data))

    # 获取第一个数据的图片及标签(类别)
    img, label = train_data[0]
    # 打印类别索引及真实的类别
    print(label, train_data.classes[label])
    # img.show()

    writer = SummaryWriter(log_dir='CIFAR10_logs')

    # dataloader示例
    # drop_last=True可以舍弃最后的不足一批(batch_size)的图片
    data_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, drop_last=False)
    # 指定训练的最大epoch
    max_epoch = 1
    for epoch in range(max_epoch):
        i = 0
        for images, labels in data_loader:
            writer.add_images(tag=f'CIFAR10_{epoch}', img_tensor=images, global_step=i)
            i += 1

    writer.close()

tensorboard中部分batch的图片如下:
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1975098.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

DDL、DML、DQL、DCL具体实例与关系

一、DDL、DCL、DML、DQL 通过二维表的形式,更加清晰直观的学习、对比其关系。 DDL DCL DML DQL 英文释义 Data Defination Language 数据库定义语言 Data Control Language 数据库控制语言 Data Manipulation Language 数据操作语言 Data Query Language 数…

PyMuPDF-Guide

本文翻译整理自: https://pymupdf.readthedocs.io/en/latest/how-to-open-a-file.html 文章目录 一、打开文件1、支持的文件类型2、如何打开文件打开一个错误的文件扩展名 3、打开远程文件从云服务打开文件 4、以文本形式打开文件例子打开一个C#文件打开一个XML文件…

按摩行业的革新者:从挑战到辉煌的转型之路

在时代浪潮的推动下,一个勇于创新的团队于2018年毅然踏入按摩服务市场,创立了一家颠覆传统的按摩店。面对行业内的激烈竞争与瞬息万变的市场环境,他们凭借独树一帜的经营模式和不懈的努力,不仅稳固了市场地位,更在去年…

使用Greenhills生成Lib并使用Lib的两种方法

文章目录 前言GHS工程生成libmake方式生成liblib的使用总结 前言 在软件交付过程,如果不交付源代码,可以将源码编译之后生成lib文件提供给客户。本文介绍GHS中生成lib的两种方法,一种基于GHS工程,一种基于make文件。生成完lib后的…

uniapp自定义网格布局用于选择金额、输入框焦点事件以及点击逻辑实战

样式 <view class="withdraw-section"><text class="section-title">提现金额</text><view class="amount-options"><view v-for="(item, index) in list" :key="index" class="amount-opt…

使用Leaflet进行船舶航行警告区域绘制实战

目录 前言 一、坐标格式转换 1、数据初认识 2、将区域分割成多个点 3、数据转换 4、数据转换调用 二、WebGIS展示空间位置信息 1、定义底图 2、Polygon的可视化 3、实际效果 三、总结 前言 通常而言&#xff0c;海事部门如海事局&#xff0c;通常会在所述的管辖区域内…

Java从入门到精通(十五) ~ IO流

晚上好&#xff0c;愿这深深的夜色给你带来安宁&#xff0c;让温馨的夜晚抚平你一天的疲惫&#xff0c;美好的梦想在这个寂静的夜晚悄悄成长。 目录 前言 什么是IO流&#xff1f; IO流的作用&#xff1a; 一、基础流 1. 字节流 1.1 字节输入流 FileInputStream 1.2 字节…

找到第一个满足条件的格值

表格第1列是科目&#xff0c;之后几列是每次的考试成绩&#xff0c;顺序排列。 ABCDE1Art03.676.27.82History3.786.217.29.83Maths5.66.36.68.9 要求根据指定的科目和成绩&#xff0c;找到该科目中大于等于该成绩的第1个格值&#xff0c;比如参数是Maths、6.5时&#xff0c;…

element-ui简单入门1.0.0

第一篇&#xff1a;table标签速用 总结&#xff1a;建楼前&#xff0c;先打地基<el-table></el-table>&#xff0c;打完地基看高度&#xff0c;一层楼4米&#xff0c;80米20个<el-table-column></el-table-column>&#xff0c;每次楼的名字是label 第…

[翻译] Asset Administration Shells

关于资产管理外壳 (AAS) 资产管理外壳 (AAS) 是工业4.0中的关键概念&#xff0c;为产品、资源&#xff08;如设备&#xff09;和过程提供信息隐藏和更高层次的抽象。AAS 是技术和设备无关的机器可读描述&#xff0c;提供访问资产属性和功能的统一接口。与现有解决方案不同&…

C# 下的限定符运算详解(全部,任意,包含)与示例

文章目录 1.限定符概述2. 全部限定符运算&#xff08;All&#xff09;3. 任意限定符运算&#xff08;Any&#xff09;4. 包含限定符运算&#xff08;Contains&#xff09;总结 当我们在C#编程中需要进行条件判断或集合操作时&#xff0c;限定符&#xff08;qualifiers&#xff…

Vue项目启动ESLint报错no-unused-vars解决办法

目录 原因分析解决方法 Vue项目启动时报错如下 ✘ http://eslint.org/docs/rules/no-unused-vars index is assigned a value but never usedsrc\views\friend\list.vue:206:17const index this.tableList.indexOf(v)^原因分析 ESLint是一个在JavaScript代码中识别和报告问…

【传知代码】辅助任务改进社交帖子多模态分类(论文复现)

在当今数字化社交时代&#xff0c;社交媒体平台如同人们生活的一部分&#xff0c;每天数以亿计的帖子在网络上涌现。这些帖子不仅仅是信息的载体&#xff0c;更是人们思想、情感和行为的折射。然而&#xff0c;要准确理解和分析这些多样化的社交帖子&#xff0c;仅依靠文本内容…

请问如何做好软件测试工作呢?

一、明确测试目标和范围 理解测试目的&#xff1a;在开始测试之前&#xff0c;首先要明确测试的目标和范围&#xff0c;确保测试计划 与需求相匹配。这有助于测试人员聚焦在关键功能上&#xff0c;避免浪费时间和资源。制定详细的测试计划&#xff1a;根据项目需求&#xff0…

【Python】爬取网易新闻今日热点列表数据并导出

1. 需求 从网易新闻的科技模块爬取今日热点的列表数据&#xff0c;其中包括标题、图片、标签、发表时间、路径、详细文本内容&#xff0c;最后导出这些列表数据到Excel中。 网易科技新闻网址&#xff1a;https://tech.163.com 2. 解决步骤 2.1 前期准备 爬虫脚本中需要引用…

Visio新手安装及超全快捷指令合集

Microsoft Visio是一款专业的流程图和图表绘制软件&#xff0c;是微软旗下的一款图表和矢量图形应用程序&#xff0c;属于Microsoft 365系列的一部分。但Visio需要单独安装&#xff0c;安装完成之后可与Word联用。 一、Visio软件介绍 Visio 是一款用途多样的绘图工具&#xff…

全球氢钎焊市场规划预测:未来六年CAGR为3.4%

随着全球制造业的持续发展和消费者对高质量产品的需求增加&#xff0c;氢钎焊作为一种高效的焊接技术&#xff0c;正逐渐受到市场的广泛关注。本文旨在通过深度分析氢钎焊行业的各个维度&#xff0c;揭示行业发展趋势和潜在机会。 【市场趋势的演变】 1. 市场规模与增长&#…

【uniapp】集成第三方插件示例

文章目录 uniapp芯套Android壳app目录下/libs目录导入全部aar工程目录下导入rewriter文件夹 uniapp芯套Android壳 https://blog.csdn.net/xzzteach/article/details/140800350 app目录下/libs目录导入全部aar工程目录下导入rewriter文件夹 本地引入包内容 在 project 级别的…

解决com.alibaba.csp.sentinel.slots.block.flow.FlowException: null

springboot项目配置sentinel&#xff0c;能限流成功但是不能限流方法 原因 名字没对应上

token和embedding

1. token 2. embedding 1.token token&#xff1a;词元/令牌/词 tokenization&#xff1a;分词 tokenizer&#xff1a;分词器 token是最小语义单元&#xff0c;通常可以是&#xff0c;一个字母、一个词、一个数字、一个汉字或任何其他有意义的字符组合&#xff0c;取决于文本处…