[PyTorch][chapter 40][数据增强]

news2025/1/12 20:45:49

前言:

      深度学习对数据量要求非常大,

      我们通常会遇到图像的数据集比较小,影响Train效果。

这个时候可以通过transformer 方法,增加图像的多样性,达到数据

增强的效果。

transformer 不会单独使用,通常和其它torch 其他类一起使用

transformer 常用方法如下

方法

说明

Resize

调整图片大小

Normalize

按照指定的均值,方差 正规化

ToTensor

convert a PIL image to tensor

ToPILImage

convert a tensor to PIL imageScale

ResizeCenterCrop

在图片的中间区域进行裁剪

RandomCrop

在一个随机的位置进行裁剪

RandomHorizontalFlip

0.5的概率水平翻转给定的PIL图像

RandomVerticalFlip

0.5的概率竖直翻转给定的PIL图像

RandomResizedCrop

PIL图像裁剪成任意大小和纵横比

Grayscale

将图像转换为灰度图像

RandomGrayscale

将图像以一定的概率转换为灰度图像

FiceCrop

把图像裁剪为四个角和一个中心T

enCropPad

填充ColorJitter:随机改变图像的亮度对比度和饱和度

这里结合summaryWriter,torchvision.datasets, torch.utils.data.DataLoader

介绍一下其使用方法


目录:

  1.     summaryWriter
  2.     torchvision.datasets
  3.     torch.utils.data.DataLoader

    


一 summaryWriter

   1.1 功能简介

     Writes entries directly to event files in the log_dir to be
    consumed by TensorBoard.
 
    The `SummaryWriter` class provides a high-level API to create an event file
    in a given directory and add summaries and events to it. The class updates the
    file contents asynchronously. This allows a training program to call methods
    to add data to the file directly from the training loop, without slowing down
    training.
    """

 SummaryWriter` 类用于在给定目录中创建事件文件,并向其中添加摘要和事件。 然后通过cmd 命令启动该TensorBoard 服务,在浏览器中可以查看对应的图形化界面.

1.2 环境安装

      pip install tensorboard

      pip install tensorflow(不安装UI 显示不出来)

1.3 张量添加

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("ZCH_Tensorboard_Trying_logs")      #第一个参数指明 writer 把summary内容 写在哪个目录下
 
for i in range(100):
    writer.add_scalar("y=x",i,i)
 
for i in range(100):
    writer.add_scalar("2I",2*i,i)
 
for i in range(100):
    writer.add_scalar("5I",5*i,i)
 
for i in range(100):
    writer.add_scalar("9I",9*i,i)
 
writer.close()     #将event log写完之后,记得close()

  两步:

         step1  生成summaryWriter 对象writer

         step2  通过add_scalar 方法,添加数据

               重要的常用的其实就是前三个参数:

           ( 1)tag:要求是一个string,用以描述 该标量数据图的 标题

           (2)scalar_value :可以简单理解为一个y轴值的列表

          (3)global_step:可以简单理解为一个x轴值的列表,与y轴的值相对应

当启动Tensorboard 可以通过红色的部分过滤,切换想要查看的项目

1.4 启动tensorBoard

   windows 命令窗口中输入

      tensorboard --logdir=ZCH_Tensorboard_Trying_logs(最好是完整路径) 
  http://localhost:6006/

    程序执行前可以加上如下,把之前旧的删除掉
   if os.path.exists('logs'):
        shutil.rmtree('logs')# 如果文件存在,则递归的删除文件内容
        print('Remove log dir')

1.5 add_image

      


from PIL import Image
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
 

img_path = "src/1.jpeg"
img = Image.open(img_path)
 
writer = SummaryWriter("imglogs")
 
# ToTensor
trans_tensor = transforms.ToTensor()    # PIL Image or numpy.array
img_tensor = trans_tensor(img)
writer.add_image("src img", img_tensor)


# Compose中参数需要一个列表,列表形式为[数据1, 数据2, ...]
# 在Compose中,数据需要的是transforms类型, Compose([transforms参数1, transforms参数2, ...])
trans_resize_2 = transforms.Resize(256)     # Resize中一个数,为按照图片最小边进行缩放
trans_compose = transforms.Compose([trans_resize_2, trans_tensor])   # 第一个参数:改变图片大小,第二个参数:转换类型
img_resize_2 = trans_compose(img)
writer.add_image("reseize img", img_resize_2, 1)


writer.close()

   在cmd 命令中输入

  tensorboard --logdir=D:\AI\Image\imglogs


二  torchvision.datasets

      torch.utils.data.Dataset() (官方文档),它是 Pytorch 中表示数据集的抽象类

      datasets这个包有很多数据集,比如MINIST、COCO、CIFAR10 and CIFAR100、LSUN 、Classification、ImageFolder、Imagenet-12、STL10。torchvision.datasets中的数据集封装都是torch.utils.data.Dataset子类,它们都实现了__getitem__ 和 __len__方法,都可以用DataLoader进行数据加载。

  torchvision.datasets.MNIST(root,train = True,transform = None,target_transform = None,download = False )

参数

介绍

root

根目录

train

如果为True,训练集,否则是测试集

download

如果为true,根目录没有数据集就会自动在这个目录下载

transform

数据集预处理,比如归一化当图形转换类的操作

target_transform

接收目标并对其进行转换的函数/转换



from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import ssl

 
 
root_dir = "./data_cifar10"
print("\n step1: ")
ssl._create_default_https_context = ssl._create_unverified_context

dataset_train = datasets.CIFAR10(root=root_dir, train=True, transform=transforms.ToTensor(), download=True)
dataset_test = datasets.CIFAR10(root=root_dir,  train=False, transform=transforms.ToTensor(), download=True)
print("\n step2: ")
dataloader_train = DataLoader(dataset=dataset_train, batch_size=64, shuffle=True, drop_last=True)
dataloader_test = DataLoader(dataset=dataset_test, batch_size=64, shuffle=True, drop_last=True)
 
log_dir = "logs"
writer = SummaryWriter(log_dir=log_dir)
print("\n writer")
number = 0
for epoch in range(2):
    step = 0
    for data_imgs, data_targets in dataloader_test:
        print("\n data_imgs ",number)
        number+=1
        writer.add_images(f"epoch{number}", data_imgs, step) #
    step += 1
print("\n---end---")
writer.close()

输入: tensorboard –logdir=D:\AI\Image\imglogs


1 COCO数据集
     是一个可用于图像检测(image detection),语义分割(semantic segmentation)和图像标题生成(image captioning)的大规模数据集。它有超过330K张图像(其中220K张是有标注的图像),包含150万个目标,80个目标类别(object categories:行人、汽车、大象等),91种材料类别(stuff categoris:草、墙、天空等),每张图像包含五句图像的语句描述,且有250,000个带关键点标注的行人。
mscoco.org/

2 CIFAR-10数据集
    由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。一共包含 10 个类别的 RGB 彩色图 片:飞机( airplane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。图片的尺寸为 32×32 ,数据集中一共有 50000 张训练图片和 10000 张测试图片。 CIFAR-10 的图片样例如图所示。

3、LSUN数据集
PASCAL VOC和ImageNet ILSVRC比赛使用的数据集,数据领域包括卧室、冰箱、教师、厨房、起居室、酒店等多个主题。
它包含10个场景类别和20个对象类别中的每个类别的大约一百万张带标签的图像。
下载地址:https://www.yf.io/p/lsun


4、谷歌Open Images图像数据集
其中包括大约9百万标注图片、横跨6000个类别标签,平均每个图像拥有8个标签。
该数据集的标签涵盖比拥有1000个类别标签的ImageNet具体更多的现实实体,可用于计算机视觉方向的训练。

5、ImageNet数据集
ImageNet数据集是目前深度学习图像领域应用得非常多的一个领域,该数据集有1000多个图像,涵盖图像分类、定位、检测等应用方向。
Imagenet数据集文档详细,有专门的团队维护,在计算机视觉领域研究论文中应用非常广,几乎成为了目前深度学习图像领域算法性能检验的“标准”数据集。很多大型科技公司都会参加ImageNet图像识别大赛,包括百度、谷歌、微软等


三  torch.utils.data.DataLoader

一般来说PyTorch中深度学习训练的流程是这样的:

1. 创建Dateset

2. Dataset传递给DataLoader

3. DataLoader迭代产生训练数据提供给模型

torch.utils.data.DataLoader(
    dataset, 
    batch_size=1, 
    shuffle=False, 
    sampler=None,
    batch_sampler=None, 
    num_workers=0, 
    collate_fn=None,
    pin_memory=False, 
    drop_last=False, 
    timeout=0,
    worker_init_fn=None)

参数

说明

dataset

加载数据的数据集

batch_size

每个batch加载多少个样本

shuffle 

设置为True时会在每个epoch重新打乱数据(默认: False)

sampler

定义从数据集中提取样本的策略,即生成index的方式,可以顺序也可以乱序

num_workers

用多少个子进程加载数据。0:数据将在主进程中加载(默认: 0)

collate_fn 

将一个batch的数据和标签进行合并操作

pin_memory 

设置pin_memory=True,则意味着生成的Tensor数据最开始是属于内存中的锁页内存,这样将内存的Tensor转义到GPU的显存就会更快一些。

drop_last 

如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)

timeout

用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。

worker_init_fn

如果不是None,将在播种之后和数据加载之前,对每个worker子进程使用worker id (int in [0, num_workers - 1])作为输入调用。(默认值:None)

   

#案例 from https://blog.csdn.net/weixin_43981621/article/details/119685671

from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

# 准备的测试数据集  数据放在了CIFAR10文件夹下
root_dir = "./data_cifar10"
dataset_train = datasets.CIFAR10(root=root_dir, train=True, transform=transforms.ToTensor(), download=True)

 
train_loader = DataLoader(
                dataset=dataset_train, 
                batch_size=4, 
                shuffle=True, 
                num_workers=0, 
                drop_last=False
)

# 设置参数batch_size=4时,每次取了4张照片,并获得4个targets标签。
# 在定义test_loader时,设置了batch_size=4,表示一次性从数据集中取出4个数据
for data in train_loader:
    imgs, targets = data
    print(imgs.shape)
    print(targets)
 
# -*- coding: utf-8 -*-
"""
Created on Tue Jun 13 16:45:55 2023

@author: chengxf2
"""

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader

# 数据预处理
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081))])
# 训练集
train_dataset = datasets.MNIST(root='../data/mnist',train=True, download=True, transform=transform)
# 测试集
test_dataset=datasets.MNIST(root='../data/mnist',train=False, download=True, transform=transform)
# 数据集加载器 


train_loader= DataLoader(dataset = train_dataset,  # 数据加载
                        batch_size = 4,    # 送入多少张图片
                        shuffle = True,    #对原有数据排序是否打乱
                        num_workers = 0,   #是否进行多进程加载数据设置
                        drop_last = False) #最后的数据组不成一个batch_size 是否丢弃

参数:
dataset:数据加载
batch_size :送入多少张图片
shuffle :是否打乱数据
sampler :指定数据加载中使用的索引/键的序列
batch_sampler = None,#和sampler类似
num_workers :是否进行多进程加载数据设置
collat​​e_fn = None,#是否合并样本列表以形成一小批Tensor
pin_memory :数据加载器会在返回之前将Tensors复制到CUDA固定内存
drop_last :最后的数据组不成一个batch_size 是否丢弃

有时候也综合起来使用,如下:

参考:

https://zhuanlan.zhihu.com/p/463799442 [torchvision 介绍]
https://blog.csdn.net/qq_41764621/article/details/126210936【SummaryWriter类】
https://blog.csdn.net/m0_51233386/article/details/127645795【 SummaryWriter类】
https://blog.csdn.net/qq_43456016/article/details/130072202[图像增强(Transforms]
https://zhuanlan.zhihu.com/p/463799442【torchvision中的数据集使用】

https://blog.csdn.net/weixin_45464524/article/details/128043516

https://www.cnblogs.com/lucky-light/p/15535282.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/643508.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

可视管理 数字孪生智慧隧道一体化管控平台

前言 交通是国家发展的关键,四通八达的交通路线,对国家经济、社会等方面的发展起着至关重要的作用。 建设背景 随着社会经济的持续发展与城市化进程的平稳推进,我国公路工程规模逐步扩大,公路工程建设直接影响着城市未来发展与…

Vue 报错 error:0308010C:digital envelope routines::unsupported

症状 Vue 报错error:0308010C:digital envelope routines::unsupported 原因 出现这个错误是因为 node.js V17版本中最近发布的OpenSSL3.0, 而OpenSSL3.0对允许算法和密钥大小增加了严格的限制,可能会对生态系统造成一些影响. 解决方法 方法1 打开终端&#x…

React 应用 Effect Hook 函数式中操作生命周期

React Hook入门小案例 在函数式组件中使用state响应式数据给大家演示了最简单的 Hook操作 那么 我们继续 首先 Hook官方介绍 他没有破坏性是完全可选的 百分比兼容 也就说 我们一起的 类 class的方式也完全可以用 只要 react 16,8以上就可以使用 Hook本身不会影响你的react的理…

ESXi 7.0 U3m Hitachi (日立) 定制版 OEM Custom Installer CD

VMware ESXi 7.0 Update 3m - 领先的裸机 Hypervisor (All OEM Customized Installer CDs) ESXi 7.0 U3m Standard (标准版) ESXi 7.0 U3m Dell (戴尔) 定制版 OEM Custom Installer CD ESXi 7.0 U3m HPE (慧与) 定制版 OEM Custom Installer CD ESXi 7.0 U3m Lenovo (联想) 定…

4.单表查询

SQL句子中语法格式提示: 1.中括号([])中的内容为可选项; 2.[,...]表示,前面的内容可重复; 3.大括号({})和竖线(|)表示选择项,在选择…

chatgpt赋能python:Python怎么导入第三方库

Python怎么导入第三方库 如果你是Python开发者,你一定会使用各种第三方库来加速你的开发过程。这些库可能是Python标准库之外的代码,或由其他人编写的自定义代码。使用这些库可以让你的开发更高效、更易于管理,并且可以避免重复造轮子。 但…

RabbitMQ虚拟主机无法启动的原因和解决方案

RabbitMQ虚拟主机无法启动的原因和解决方案 摘要: RabbitMQ是一个广泛使用的开源消息代理系统,但在使用过程中可能会遇到虚拟主机无法启动的问题。本文将探讨可能导致该问题的原因,并提供相应的解决方案,以帮助读者解决RabbitMQ虚…

Learning C++ No.31 【线程库实战】

引言: 北京时间:2023/6/11/14:40,实训课中,实训场地有空调,除了凳子坐着不舒服之外,其它条件都挺好,主要是我带上了我自己的小键盘,并且教室可以充电,哈哈哈&#xff0c…

在做自动化测试之前你需要知道的

B站视频教程:Python自动化测试:7天练完这60个实战项目,年薪过35w。 什么是自动化测试? 做测试好几年了,真正学习和实践自动化测试一年,自我感觉这一个年中收获许多。一直想动笔写一篇文章分享自动化测试实践…

信息系统管理工程师-学习笔记1-信息化知识

考点1 信息与信息系统 信息的概念 信息的定义: 是有别与物质与能量的第三种东西,是对事物运动状态或存在方式的不确定行的描述 信息是按特定方式组织在一起的客体属性的集合,具有超出这些客体属性本身之外的价值两层次 1.本体论层次 : 纯客观的层次,只与客体本身的因素有关,与主…

python cv2的一些操作,如膨胀,画线,滤波等

目录 0. cv2简介1. 打开摄像头2. 画图,画线3. 滤波4. 获取角点5. 梯度边缘6. 图形匹配7. 形态学变化-膨胀腐蚀8. 二值化阈值10. 总结 0. cv2简介 在这里先简单介绍一下cv2吧。 cv2 是 OpenCV Python 库的主要模块,提供了许多图像处理和计算机视觉方面的函数和工具。…

vue2组件通信

父传子 传递静态或动态 Prop <!-- 传入静态值 --> <blog-post title"hai hai hai"></blog-post><!-- 传入变量值 --> <blog-post :title"info.title"></blog-post>传入一个对象的所有 property 数据 post: {id: 1…

进程管道:popen函数实例

基础知识 可能最简单的在两个程序之间传递数据的方法就是使用popen和pclose函数了。它们的原型如下所示&#xff1a; #include <stdio.h>FILE *popen(const char *command, const char *type);int pclose(FILE *stream); 1&#xff0e;popen函数 popen函数允许一个程…

因为Json,controller方法单参数 导致脑袋短路

对于单参数方法&#xff0c; 一直喜欢用parameter方式。今天不知道为啥&#xff0c;就想用Json方式&#xff0c;然后无法直接传递。各种自我怀疑&#xff0c;然后尝试。 突然醒悟过来&#xff0c;Json方式是key/value模式&#xff0c;单参数String类型&#xff0c;没有key。必…

TreeMap源码

介绍 如果我们希望Map可以保持key的大小顺序时&#xff0c;就需要利用TreeMap。底层使用了红黑树&#xff0c;左子树总小于root&#xff0c;右子树总大于root&#xff0c;具有很好的平衡性,操作速度达到log(n)。 TreeMap 相比于HashMap多实现了了NavigableMap接口&#xff08…

5. SpringCloudAlibab 集成 gateway

一、什么是 Spring Cloud Gateway 1、网关简介 网关作为流量的入口&#xff0c;常用的功能包括路由转发&#xff0c;权限校验&#xff0c;限流等等。 SpringCloud Gateway是 Spring Cloud 官方推出的第二代网关框架&#xff0c;定位取代 Netflix Zuul。相对Zuul来说&#xf…

【多线程】原子引用ABA问题

目录 一、代码示例二、执行结果截图三、说明四、AtomicStampedReference使用4.1 代码示例4.2 截图 一、代码示例 package com.learning.atomic;import lombok.extern.slf4j.Slf4j; import java.util.concurrent.atomic.AtomicReference; /*** Author wangyouhui* Description …

2023年软考-高级信息系统项目管理工程师考试大纲

高级信息系统项目管理工程师考试大纲 2023年软考高级信息系统项目管理工程师考试大纲已于2023年5月出版。您可以在 中国计算机技术职业资格网 上找到更多关于考试的信息 。 信息系统项目管理师是对从事信息系统项目管理工作的专业技术人员基本理论和实践能力的综合考核,该专业…

新手如何挑选一款合适的功率放大器?

ATA系列功率放大器是&#xff08;AB&#xff09;类功放&#xff0c;相比于甲类功率放大器&#xff0c;它小信号输入时效率更高&#xff0c;随着输出功率的增大&#xff0c;效率也增高&#xff0c;它的效率比以及保真度而言&#xff0c;都优于A类和B类功放。 因为具有这些优势&a…

助你更好的理解 Python 字典

助你更好的理解 Python 字典 字典是Python中的常用数据类型之一&#xff0c;可将数据存储在键/值对中&#xff0c;同 Java 中的 Map 相似。 1、什么是字典理解&#xff1f; 字典理解是创建字典的一种优雅简洁的方法。 字典理解优化 使用字典理解优化函数。 示例&#xff…