【从零开始学习深度学习】44. 图像增广的几种常用方式并使用图像增广训练模型【Pytorch】

news2025/1/20 12:02:24

大规模数据集是成功应用深度神经网络的前提,图像增广(image augmentation)技术通过对训练图像做一系列随机改变,来产生相似但又不同的训练样本,从而扩大训练数据集的规模。图像增广的另一种解释是,随机改变训练样本可以降低模型对某些属性的依赖,从而提高模型的泛化能力。例如,我们可以对图像进行不同方式的裁剪,使感兴趣的物体出现在不同位置,从而减轻模型对物体出现位置的依赖性。我们也可以调整亮度、色彩等因素来降低模型对色彩的敏感度。本文将详细介绍该项技术。

目录

  • 1. 常用的图像增广方法
    • 1.1 翻转和裁剪
    • 1.2 变化颜色
    • 1.3 叠加多个图像增广方法
  • 2 使用图像增广训练模型
    • 2.1 训练模型
  • 总结

首先,导入实验所需的包或模块。

%matplotlib inline
import time
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from PIL import Image

import sys
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

1. 常用的图像增广方法

我们来读取一张形状为 400 × 500 400\times 500 400×500(高和宽分别为400像素和500像素)的图像作为实验的样例。

d2l.set_figsize()
img = Image.open('./img/cat1.jpg')
d2l.plt.imshow(img)

在这里插入图片描述

下面定义绘图函数show_images

def show_images(imgs, num_rows, num_cols, scale=2):
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    for i in range(num_rows):
        for j in range(num_cols):
            axes[i][j].imshow(imgs[i * num_cols + j])
            axes[i][j].axes.get_xaxis().set_visible(False)
            axes[i][j].axes.get_yaxis().set_visible(False)
    return axes

大部分图像增广方法都有一定的随机性。为了方便观察图像增广的效果,接下来我们定义一个辅助函数apply。这个函数对输入图像img多次运行图像增广方法aug并展示所有的结果。

def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):
    # aug为传入的图像增广方法函数
    Y = [aug(img) for _ in range(num_rows * num_cols)]
    show_images(Y, num_rows, num_cols, scale)

1.1 翻转和裁剪

左右翻转图像通常不改变物体的类别。它是最早也是最广泛使用的一种图像增广方法。下面我们通过torchvision.transforms模块创建RandomHorizontalFlip实例来实现一半概率的图像水平(左右)翻转。

# transforms.RandomHorizontalFlip(p=0.5)默认p=0.5表示一半概率翻转
apply(img, torchvision.transforms.RandomHorizontalFlip())

在这里插入图片描述

上下翻转不如左右翻转通用。但是至少对于样例图像,上下翻转不会造成识别障碍。下面我们创建RandomVerticalFlip实例来实现一半概率的图像垂直(上下)翻转。

apply(img, torchvision.transforms.RandomVerticalFlip())

在这里插入图片描述

在我们使用的样例图像里,猫在图像正中间,但一般情况下可能不是这样。之前介绍了池化层能降低卷积层对目标位置的敏感度。除此之外,我们还可以通过对图像随机裁剪来让物体以不同的比例出现在图像的不同位置,这同样能够降低模型对目标位置的敏感性。

在下面的代码里,我们每次随机裁剪出一块面积为原面积 10 % ∼ 100 % 10\% \sim 100\% 10%100%的区域,且该区域的宽和高之比随机取自 0.5 ∼ 2 0.5 \sim 2 0.52,然后再将该区域的宽和高分别缩放到200像素。若无特殊说明,文中 a a a b b b之间的随机数指的是从区间 [ a , b ] [a,b] [a,b]中随机均匀采样所得到的连续值。

shape_aug = torchvision.transforms.RandomResizedCrop(200, scale=(0.1, 1), ratio=(0.5, 2))
apply(img, shape_aug)

在这里插入图片描述

1.2 变化颜色

另一类增广方法是变化颜色。我们可以从4个方面改变图像的颜色:亮度(brightness)、对比度(contrast)、饱和度(saturation)和色调(hue)。在下面的例子里,我们将图像的亮度随机变化为原图亮度的 50 % 50\% 50% 1 − 0.5 1-0.5 10.5 ∼ 150 % \sim 150\% 150% 1 + 0.5 1+0.5 1+0.5)。

apply(img, torchvision.transforms.ColorJitter(brightness=0.5))

在这里插入图片描述

我们也可以随机变化图像的色调。

apply(img, torchvision.transforms.ColorJitter(hue=0.5))

在这里插入图片描述

类似地,我们也可以随机变化图像的对比度。

apply(img, torchvision.transforms.ColorJitter(contrast=0.5))

在这里插入图片描述

我们也可以同时设置如何随机变化图像的亮度(brightness)、对比度(contrast)、饱和度(saturation)和色调(hue)。

color_aug = torchvision.transforms.ColorJitter(
    brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
apply(img, color_aug)

在这里插入图片描述

1.3 叠加多个图像增广方法

实际应用中我们会将多个图像增广方法叠加使用。我们可以通过Compose实例将上面定义的多个图像增广方法叠加起来,再应用到每张图像之上。

augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(), color_aug, shape_aug])
apply(img, augs)

在这里插入图片描述

2 使用图像增广训练模型

下面我们来看一个将图像增广应用在实际训练中的例子。这里我们使用CIFAR-10数据集,而不是之前我们一直使用的Fashion-MNIST数据集。这是因为Fashion-MNIST数据集中物体的位置和尺寸都已经经过归一化处理,而CIFAR-10数据集中物体的颜色和大小区别更加显著。下面展示了CIFAR-10数据集中前32张训练图像。

all_imges = torchvision.datasets.CIFAR10(train=True, root="~/Datasets/CIFAR", download=True)
# all_imges的每一个元素都是(image, label)
show_images([all_imges[i][0] for i in range(32)], 4, 8, scale=0.8);

在这里插入图片描述

为了在预测时得到确定的结果,我们通常只将图像增广应用在训练样本上,而不在预测时使用含随机操作的图像增广。在这里我们只使用最简单的随机左右翻转。此外,我们使用ToTensor将小批量图像转成PyTorch需要的格式,即形状为(批量大小, 通道数, 高, 宽)、值域在0到1之间且类型为32位浮点数。

flip_aug = torchvision.transforms.Compose([
     torchvision.transforms.RandomHorizontalFlip(),
     torchvision.transforms.ToTensor()])

no_aug = torchvision.transforms.Compose([
     torchvision.transforms.ToTensor()])

接下来我们定义一个函数来读取图像并应用图像增广。

num_workers = 0 if sys.platform.startswith('win32') else 4
def load_cifar10(is_train, augs, batch_size, root="~/Datasets/CIFAR"):
    dataset = torchvision.datasets.CIFAR10(root=root, train=is_train, transform=augs, download=True)
    return DataLoader(dataset, batch_size=batch_size, shuffle=is_train, num_workers=num_workers)

2.1 训练模型

我们在CIFAR-10数据集上训练之前介绍的残差网络ResNet-18模型。

我们先定义train函数使用GPU训练并评价模型。

def train(train_iter, test_iter, net, loss, optimizer, device, num_epochs):
    net = net.to(device)
    print("training on ", device)
    batch_count = 0
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum += l.cpu().item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            n += y.shape[0]
            batch_count += 1
        test_acc = d2l.evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
              % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))

然后就可以定义train_with_data_aug函数使用图像增广来训练模型了。该函数使用Adam算法作为训练使用的优化算法,然后将图像增广应用于训练数据集之上,最后调用刚才定义的train函数训练并评价模型。

def train_with_data_aug(train_augs, test_augs, lr=0.001):
    batch_size, net = 256, d2l.resnet18(10)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = torch.nn.CrossEntropyLoss()
    # 对训练数据使用图像增广
    train_iter = load_cifar10(True, train_augs, batch_size)
    # 对测试数据不使用图像增广
    test_iter = load_cifar10(False, test_augs, batch_size)
    train(train_iter, test_iter, net, loss, optimizer, device, num_epochs=10)

下面使用随机左右翻转的图像增广来训练模型。

train_with_data_aug(flip_aug, no_aug)

输出:

training on  cuda
epoch 1, loss 1.3615, train acc 0.505, test acc 0.493, time 123.2 sec
epoch 2, loss 0.5003, train acc 0.645, test acc 0.620, time 123.0 sec
epoch 3, loss 0.2811, train acc 0.703, test acc 0.616, time 123.1 sec
epoch 4, loss 0.1890, train acc 0.735, test acc 0.686, time 123.0 sec
epoch 5, loss 0.1346, train acc 0.765, test acc 0.671, time 123.1 sec
epoch 6, loss 0.1029, train acc 0.787, test acc 0.674, time 123.1 sec
epoch 7, loss 0.0803, train acc 0.804, test acc 0.749, time 123.1 sec
epoch 8, loss 0.0644, train acc 0.822, test acc 0.717, time 123.1 sec
epoch 9, loss 0.0526, train acc 0.836, test acc 0.750, time 123.0 sec
epoch 10, loss 0.0433, train acc 0.851, test acc 0.754, time 123.1 sec

总结

  • 图像增广基于现有训练数据生成随机图像从而应对过拟合。
  • 为了在预测时得到确定的结果,通常只将图像增广应用在训练样本上,而不在预测时使用含随机操作的图像增广。
  • 可以从torchvision的transforms模块中获取有关图片增广的类。

如果文章内容对你有帮助,感谢点赞+关注!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/152096.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PCB入门学习—PCB封装的创建2

3.2 IC类PCB封装的创建注:PCB封装的名字一定要和原理图上填写的封装名字一样,不然对不上。规格书里有最大值最小值,就按最大值来做。快捷键EA是特殊粘贴。SOP-8:焊盘比较多时(BGA)可以利用向导去创建。做封装从规格书需要读取的数据&#xff…

19-FreeRTOS 任务通知API

1-xTaskNotifyGive / xTaskNotifyGiveIndexed task.hBaseType_t xTaskNotifyGive( TaskHandle_t xTaskToNotify );BaseType_t xTaskNotifyGiveIndexed( TaskHandle_t xTaskToNotify, UBaseType_t uxIndexToNotify );每个任务都有一组“任务通知” (或仅“通知” &a…

Tomcat Connector运行模式

目录 1 Tomcat Connector运行模式 1.1 BIO 模式 1.2 NIO 模式 1.3 APR模式 2 修改Tomcat Connector运行模式为apr 或者解决问题The APR not found 问题 2.1 linux系统 2.2 windows处理 1 Tomcat Connector运行模式 1.1 BIO 模式 BIO模式(blocking I/O&am…

[C语言]运用函数指针数组构建一个简单计算器

1.函数指针数组 函数指针数组,即为存放函数首地址的数组,类型为函数指针类型。 2.运用函数指针数组构建简单计算器 1.人机交互,首先要用选择加减或乘除的菜单,再分别写出其功能 void menu() {printf("****************…

嵌入式实时操作系统的设计与开发(五)

线程退出 当线程代码执行完后,系统会隐式地调用acoral_thread_exit()函数进行线程退出相关的操作。 acoral_thread_exit()本质上是要执行acoral_kill_thread()。 void acoral_thread_exit(){acoral_kill_thread(acoral_cur_thread); }void acoral_kill_thread(aco…

ccbill 代码分析

ccbill目录概述需求:设计思路实现思路分析1.BillState2.DBList3.DBListAttr4.DBListDBSrc5.DBListDBSrcsDBListsSurvive by day and develop by night. talk for import biz , show your perfect code,full busy,skip hardness,make a better result,wai…

每日一题:Leetcode59. 螺旋矩阵 II

文章目录 系列:数组专练 语言:java & go 题目来源:Leetcode59. 螺旋矩阵 II 难度:中等 考点:边界问题的处理 & 圈数处理 思路和参考答案文章目录题目描述思路java解法:java参考代码go参考代码&…

OceanBase 4.0解读:从TPC-H性能测评看4.0与3.x差异

关于作者 肖帆 OceanBase技术专家 OceanBase技术专家,开源生态团队成员。毕业于华中科技大学软件工程专业,从事数据库领域的质量保障工作,曾就职于有赞、网易,参与关系型数据库、缓存数据库、对象存储相关产品的测试开发&#x…

.net core api调用webserver接口(详细)

这里废话不多说,我就不简述什么事webserver了,相信点进本博客的大佬都是为了解决问题。 .net core 调用webserver的话还挺简单。首先我们先有个.net core api的项目。 1.我们先注入这个HttpClient 这个内置对象,一会要用到。 // 注入HttpC…

Vue Hook Event 解读

前言 Hook Event (钩子事件)相信很多 Vue 开发者都没有使用过,甚至没听过,毕竟 Vue 官方文档中也没有提及。 Vue 提供了一些生命周期钩子函数,供开发者在特定的逻辑点添加额外的处理逻辑,比如: 在组件挂载阶段提供了beforeMount 和…

代码随想录day34

1005. K 次取反后最大化的数组和 题目 给你一个整数数组 nums 和一个整数 k ,按以下方法修改该数组: 选择某个下标 i 并将 nums[i] 替换为 -nums[i] 。 重复这个过程恰好 k 次。可以多次选择同一个下标 i 。 以这种方式修改数组后,返回数…

linux笔记(9):MangoPi-MQ(芒果派麻雀D1s)Tina系统编译烧录

文章目录1.下载相关资料1.1 WhyCan Forum(哇酷开发者社区)提供的sdk1.1.1 SDK解压过程1.2 WhyCan Forum(哇酷开发者社区)提供的补丁1.2.1 补丁包含的文件1.2.2 补丁文件和D1下面的相同文件进行合并1.2.3 引脚PD17被复用,导致LCD变暗,修改设备树2. 编译ti…

【node.js 安装】linux下安装node.js

下面我们介绍安装包安装方法 nodejs官网下载地址1 nodejs官网下载地址2 我们以官网下载地址2打开 直接下载源代码,rz上传到/opt/tools/ 目录下 tar -xJvf node-v18.13.0-linux-x64.tar.xz配置环境变量,vim /etc/profile ,配置内容如下&am…

SFP 收发器居然有那么多种?值得收藏学习

SFP 模块具有广泛的应用范围,可与大部分现代网络配合使用,大多数可以分为四大类:电缆类型、传输范围、传输速率、应用。 一、电缆类型 SFP 模块可以在光纤和铜线上工作,根据光纤的种类,SFP收发器可分为与单模光纤配合…

π122E60 5.0kVrms 200Mbps 双通道数字隔离器 兼容代替Si8622BT-IS

π122E60 5.0kVrms 200Mbps 双通道数字隔离器 兼容代替Si8622BT-IS 具有出色的性能特征和可靠性,整体性能优于光耦和基于其他原理的数字隔离器产品。 产品传输通道间彼此独立,可实现多种传输方向的配置,可实现 5.0kVrms 隔离耐压等级和 DC 到…

Loading 用户体验 - 加载时避免闪烁

🍓 前言 在切换详情页中有这么一个场景,点击上一条,会显示上一条的详情页,同理,点击下一条,会显示下一条的详情页。 伪代码如下所示: 我们定义了一个 switcher 模版, 用户点击上一…

TensorRT部署YOLOv5(03)-TensorRT介绍

TensorRT是本专栏中最重要的内容,绝大多数内容将围绕TensorRT来展开,本文对TensorRT进行一个基本的介绍,让不熟悉TensorRT的读者能够对TensorRT是什么,如何使用它有一个较为全面的认识 Nvidia TensorRT是一个用于Nvidia GPU上高性能机器学习推理的SDK,对开发者屏蔽了模型…

到底什么样的 REST 才是最佳 REST?

说起 REST API,小伙伴们多多少少都有听说过,但是如果让你详细介绍一下什么是 REST,估计会有很多人讲不出来,或者只讲出来其中一部分。 今天松哥就来和大家一起来聊一聊到底什么是 REST,顺便再来看下 Spring HATEOAS 的…

[算法与数据结构]——并查集

目录 1. 概论 定义: 主要构成: 作用: 2. 并查集的现实意义 故事引入: 数据结构的角度来看: 3. find( )函数的定义与实现 故事引入: 实现: 4. join( )函数的定义与实现 故事引入: 实现…

c++11 标准模板(STL)(std::forward_list)(三)

定义于头文件 <forward_list> template< class T, class Allocator std::allocator<T> > class forward_list;(1)(C11 起)namespace pmr { template <class T> using forward_list std::forward_list<T, std::pmr::polymorphic_…