图像增广:强化深度学习的视觉表现力

news2024/12/23 10:23:41

目录

摘要:

1. 图像增广简介

2. 图像增广的原理

3. 常见的图像增广技术

4. 如何在实际项目中应用图像增广

5.实际应用


摘要:

当今,深度学习已经在计算机视觉领域取得了令人瞩目的成就。图像增广作为一种数据处理技术,让我们在使用有限的图像数据集时能够充分挖掘图像特征,提高模型的泛化能力。本文将详细介绍图像增广的概念、原理以及如何在实际项目中应用。

1. 图像增广简介

图像增广(Image Augmentation)是一种通过对原始图像进行各种变换来生成新的图像的方法。这些变换包括旋转、翻转、缩放、剪切、色彩变换等。通过图像增广,我们可以扩大数据集的规模,增加模型训练时的输入样本。这有助于提高模型的泛化能力,从而在面对新的、未知的数据时,也能达到较高的准确性。

2. 图像增广的原理

深度学习模型在训练过程中需要大量的数据来学习特征表达。然而,在实际应用中,我们并不总是能获得足够多的数据。图像增广通过对原始图像进行各种变换,创造出具有不同视觉特征的新图像。这样一来,模型在训练时可以接触到更多样的数据,从而学习到更丰富的特征表达,提高泛化能力。

值得注意的是,图像增广并不能完全解决数据不足的问题,但它可以在一定程度上缓解这个问题,提高模型的性能。

3. 常见的图像增广技术

以下是一些常见的图像增广技术:

- **旋转**:将图像按一定的角度进行旋转。
- **翻转**:对图像进行水平或垂直翻转。
- **缩放**:对图像进行放大或缩小。
- **剪切**:在图像上随机选择一块区域,将其裁剪为新的图像。
- **色彩变换**:改变图像的亮度、对比度、饱和度等色彩属性。
- **噪声添加**:在图像中添加随机噪声。
- **仿射变换**:对图像进行平移、旋转、缩放等操作。

4. 如何在实际项目中应用图像增广

许多深度学习框架都提供了图像增广的相关工具,例如 TensorFlow、PyTorch、Keras 等。在使用这些框架时,我们可以轻松地将图像增广技术应用到我们的项目中。以下是一个使用 Keras 进行图像增广的简单示例:

from keras.preprocessing.image import ImageDataGenerator

# 创建一个图像数据生成器
datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)

# 将数据生成器应用到训练集
train_generator = datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical')

在上述代码中,我们定义了一个图像数据生成器,并设置了一些增广参数。然后,我们使用这个数据生成器对训练集进行处理。

5.实际应用

%matplotlib inline
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

导入图片: 图片大小为400*500

d2l.set_figsize()
img = d2l.Image.open('./img/1.jpg') #务必将图片放到该根目录
d2l.plt.imshow(img)

 大多数图像增广方法都具有一定的随机性。为了便于观察图像增广的效果,我们下面定义辅助函数apply。 此函数在输入图像img上多次运行图像增广方法aug并显示所有结果。

def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):#aug增广
    Y = [aug(img) for _ in range(num_rows * num_cols)]
    d2l.show_images(Y, num_rows, num_cols, scale=scale)

 左右翻转图像通常不会改变对象的类别。这是最早且最广泛使用的图像增广方法之一。 使用transforms模块来创建RandomFlipLeftRight实例,这样就各有50%的几率使图像向左或向右翻转。

 上下翻转并不常用:

apply(img, torchvision.transforms.RandomVerticalFlip())

 随机剪切函数:剪切大小200*200像素,在10%-100%的范围内剪切,长宽比为1:2

shape_aug = torchvision.transforms.RandomResizedCrop((200,200),scale=(0.1,1),ratio=(0.5,2))
apply(img,shape_aug)

随机改变亮度

apply(img,torchvision.transforms.ColorJitter(brightness=0.5,contrast=0,saturation=0,hue=0))

 

 改变色调:

apply(img, torchvision.transforms.ColorJitter(
    brightness=0, contrast=0, saturation=0, hue=0.5))

 创建一个RandomColorJitter实例,并设置如何同时随机更改图像的亮度(brightness)、对比度(contrast)、饱和度(saturation)和色调(hue)。

color_aug = torchvision.transforms.ColorJitter(
    brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
apply(img, color_aug)

 最常用的就是将各种增广方法结合起来:

augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(), color_aug, shape_aug])
apply(img, augs)

 下载一个常用数据集:

#下载数据集
all_images = torchvision.datasets.CIFAR10(train=True, root="./data",
                                          download=True)
d2l.show_images([all_images[i][0] for i in range(32)], 4, 8, scale=0.8);

 通常对训练样本只进行图像增广,且在预测过程中不使用随机操作的图像增广。使用ToTensor实例将一批图像转换为深度学习框架所要求的格式,即形状为(批量大小,通道数,高度,宽度)的32位浮点数,取值范围为0~1。

train_augs = torchvision.transforms.Compose([
     torchvision.transforms.RandomHorizontalFlip(),
     torchvision.transforms.ToTensor()]) #变成一个4d矩阵

test_augs = torchvision.transforms.Compose([
     torchvision.transforms.ToTensor()])
def load_cifar10(is_train, augs, batch_size):
    dataset = torchvision.datasets.CIFAR10(root="./data", train=is_train,
                                           transform=augs, download=True)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                    shuffle=is_train, num_workers=d2l.get_dataloader_workers())
    return dataloader

 使用gpu进行训练数据:

#@save
#使用多GPU对模型进行训练和评估
def train_batch_ch13(net, X, y, loss, trainer, devices):
    """用多GPU进行小批量训练"""
    if isinstance(X, list):
        # 微调BERT中所需
        X = [x.to(devices[0]) for x in X]
    else:
        X = X.to(devices[0])
    y = y.to(devices[0])
    net.train()
    trainer.zero_grad()
    pred = net(X)
    l = loss(pred, y)
    l.sum().backward()
    trainer.step()
    train_loss_sum = l.sum()
    train_acc_sum = d2l.accuracy(pred, y)
    return train_loss_sum, train_acc_sum

#@save
def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
               devices=d2l.try_all_gpus()):
    """用多GPU进行模型训练"""
    timer, num_batches = d2l.Timer(), len(train_iter)
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
                            legend=['train loss', 'train acc', 'test acc'])
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    for epoch in range(num_epochs):
        # 4个维度:储存训练损失,训练准确度,实例数,特点数
        metric = d2l.Accumulator(4)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = train_batch_ch13(
                net, features, labels, loss, trainer, devices)
            metric.add(l, acc, labels.shape[0], labels.numel())
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0] / metric[2], metric[1] / metric[3],
                              None))
        test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    print(f'loss {metric[0] / metric[2]:.3f}, train acc '
          f'{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on '
          f'{str(devices)}')

 可以定义train_with_data_aug函数,使用图像增广来训练模型。该函数获取所有的GPU,并使用Adam作为训练的优化算法,将图像增广应用于训练集,最后调用刚刚定义的用于训练和评估模型的train_ch13函数。

batch_size, devices, net = 256, d2l.try_all_gpus(), d2l.resnet18(10, 3)

def init_weights(m):
    if type(m) in [nn.Linear, nn.Conv2d]:
        nn.init.xavier_uniform_(m.weight)

net.apply(init_weights)

def train_with_data_aug(train_augs, test_augs, net, lr=0.001):
    train_iter = load_cifar10(True, train_augs, batch_size)
    test_iter = load_cifar10(False, test_augs, batch_size)
    loss = nn.CrossEntropyLoss(reduction="none")
    trainer = torch.optim.Adam(net.parameters(), lr=lr)
    train_ch13(net, train_iter, test_iter, loss, trainer, 10, devices)
train_with_data_aug(train_augs, test_augs, net)

        总之,图像增广作为一种数据处理技术,在深度学习领域具有重要的意义。通过应用图像增广,我们能够充分挖掘图像特征,提高模型的泛化能力。在实际项目中,我们可以根据需求选择不同的增广技术,从而优化模型的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/726257.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Cache】Redis的高可用与持久化

文章目录 一、Redis 高可用1. 概念2. 高可用技术以及作用2.1 持久化2.2 主从复制2.3 哨兵2.4 集群 二、Redis 持久化1. 持久化的功能2. Redis 持久化方式 三、RDB 持久化1. 概述2. 触发条件2.1 手动触发2.2 自动触发2.3 其他自动发机制 3. 执行流程4. 启动时加载 四、AOF 持久化…

【UEFI实战】UEFI图形显示(字符输出)

HII Font 接下来介绍EFI_HII_FONT_PROTOCOL,它在UEFI代码中完成了字符到像素的转换,本节主要介绍这个转换关系,它的实现代码在edk2\MdeModulePkg\Universal\HiiDatabaseDxe\HiiDatabaseDxe.inf中,除了EFI_HII_FONT_PROTOCOL&…

【Axure教程】多选树穿梭选择器

多选树在有分层的领域是经常用到的,例如不同城市下的门店、不同部门的员工等等,用多选树就可以让我们在不同层级快速挑选到对应的对象。 今天作者就教大家在Axure中如何制作多选树穿梭选择器的原型模板,我们会以不同部门之间挑选员工位案例。…

leetcode极速复习版-第二章链表

目录 链表 203.移除链表元素 707.设计链表 206.反转链表 24. 两两交换链表中的节点 19.删除链表的倒数第N个节点 面试题 02.07. 链表相交 链表部分总结 链表 203.移除链表元素 题意:删除链表中等于给定值 val 的所有节点。 示例 1: 输入&a…

基于Java在线电影评价系统设计实现(源码+lw+部署文档+讲解等)

博主介绍:✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、Java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专…

【Linux】十分钟理解软硬链接

目录 1.磁盘的物理结构2.磁盘的物理存储结构3.文件系统4.硬链接4.14.2 5.软链接6.三种时间 1.磁盘的物理结构 盘片:一片两面,有一摞盘片。磁头:一面一个磁头,一个磁头负责一面的读取(磁头是一起动的)。马达…

二叉树 — 返回二叉树最大距离

题目: 给定二叉树头结点head,任何两个节点之间都有距离,求整棵二叉树最大距离。 二叉树如下图所示,假设从x到b,中间节点只能走一次,我们人为规定距离就是整条路径的节点数量,所以距离是3&#x…

Spring Boot 中的 Spring Cloud Gateway

Spring Boot 中的 Spring Cloud Gateway Spring Cloud Gateway 是一个基于 Spring Boot 的网关框架,它提供了一种统一的入口,将所有的请求路由到不同的后端服务中。Spring Cloud Gateway 采用了 Reactive 编程模型,可以处理大量并发请求&…

idea闪退,端口占用处理

1、idea --> Terminal 2、 输入命令 jps 查看进程 3、找到对应的进程,使用 taskkill /pid 端口号 /f 4、 重启项目 ,即可

Golang快速鸟瞰

文章目录 引子知识图谱包代理设置关键字数据类型变量struct 和 interface控制语句字符串单引号、双引号、反引号数组与切片字典make和newjson与yaml基本语法指针Channeldeferinit函数类error, panic, recoverchannel与协程调试热加载Gin的热加载Iris的热加载 常用Golang框架常用…

数据库基础作业(linux系统)

数据库作业 在linux系统下的MySQL 创建数据库 使用数据库 查询当前默认的数据库以及使用的编码方式校验规则 查询创建数据的语句 删除数据库 创建数据表 定义多个字段,用上所有数据类型 mysql> SHOW CREATE TABLE multi_tb; -----------------------------------------…

重新理解z-index

一,前言 今天遇到一个布局兼容问题,调试了一番,发现z-index的表现和自己的认知不相符,才知道自己对z-index的认知有错误,于是写篇文章总结下这个z-index的具体使用。有基础的朋友可以直接看第四节。 二,标…

Android 内存治理之线程

1、 前言 当我们在应用程序中启动一个线程的时候,也是有可能发生OOM错误的。当我们看到以下log的时候,就说明系统分配线程栈失败了。 java.lang.OutOfMemoryError: pthread_create (1040KB stack) failed: Out of memory这种情况可能是两种原因导致的。…

行业追踪,2023-07-06,市场反馈平平

自动复盘 2023-07-06 成交额超过 100 亿 排名靠前,macd柱由绿转红 成交量要大于均线 有必要给每个行业加一个上级的归类,这样更能体现主流方向 rps 有时候比较滞后,但不少是欲杨先抑, 应该持续跟踪,等 macd 反转时参与…

rust 从转移说起

Rust 专门提出了所有权和转移的概念,第一次接触总感觉晦涩,不属于正常思维,但还是得耐下性子,观摩观摩 Rust 所谓的转移。 Rust 中,对大多数类型而言,给变量赋值、给函数传值或者从函数返回值,…

Eclipse显示层级目录结构(像IDEA一样)

有的小伙伴使用IDEA习惯了,可能进入公司里面要求使用eclipse,但是eclipse默认目录是并列显示,而不是层级显示。部分人用起来感觉十分不方便。我们可以更改一下设置。 1、打开eclipse,找到这里 2、选择PackagePresentation 3、选…

支持跨语言、人声狗吠互换,仅利用最近邻的简单语音转换模型有多神奇

AI 语音转换真的越复杂越好吗?本文就提出了一个方法简单但同样强大的语言转换模型,与基线方法相比自然度和清晰度毫不逊色,相似度更是大大提升。 AI 参与的语音世界真神奇,既可以将一个人的语音换成任何其他人的语音,…

express框架中间件

1.介绍 说明:Express框架中间件是指在处理HTTP请求前或后对请求和响应进行处理的函数。具体而言,中间件可以: 执行一些公共的逻辑,比如身份验证、日志记录、错误处理等。修改请求和响应,比如缓存、压缩等。控制请求流…

ModaHub魔搭社区:基于 Amazon EKS 搭建开源向量数据库 Milvus

目录 01 前言 02 架构说明 03 先决条件 04 创建 EKS 集群 05 部署 Milvus 数据库 06 优化 Milvus 配置 07 测试 Milvus 集群 08 总结 01 前言 生成式 AI(Generative AI)的火爆引发了广泛的关注,也彻底点燃了向量数据库&…

Ubuntu中删除LibreOffice方法

目录 删除LibreOffice套件 删除所有与LibreOffice相关的软件包 删除与LibreOffice相关的配置文件 删除LibreOffice套件 1、打开终端。您可以使用快捷键Ctrl Alt T来打开终端。 2、输入以下命令以卸载LibreOffice套件: sudo apt-get remove libreoffice* 删…