李沐36_数据增广——自学笔记

news2024/9/24 7:19:47

数据增强

增强一个已有的数据集,使得有更多的多样性

1.在语言里面加入各种不同的背景噪音

2.改变图片的颜色和形状

一般是在线生成、随机增强

常见数据增强

1.左右翻转

2.上下翻转(不总可行)

3.切割:从图片中切割一块,变形到固定形状(随机高宽比[3/4 or 4/3]、随机大小[8% or 100%]、随机位置)

4.颜色:改变色调,饱和度,明亮度

总结

1.数据增广通过变形数据来获得多样性,从而使得模型泛化性能更好

2.常见操作:翻转、切割、变色

额外补充,colab读取图片遇到的问题

img = d2l.Image.open(‘…\img\cat1.jpg’)会遇到找不到图片的问题,可参考https://blog.csdn.net/weixin_47505105/article/details/128514242

在这里插入图片描述

%matplotlib inline
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
d2l.set_figsize()
img = d2l.Image.open('cat1.jpg')
d2l.plt.imshow(img);

# 此处图片保存在同个文件夹下,具体操作参考 http://t.csdnimg.cn/DP6au

在这里插入图片描述

# 或者读取URL路径

import requests
from PIL import Image
import io
import matplotlib.pyplot as plt

# 图片的URL
image_url = 'https://raw.githubusercontent.com/d2l-ai/d2l-en/master/img/catdog.jpg'  # 替换为实际的图片URL

# 发送请求获取图片内容
response = requests.get(image_url)
image_content = response.content

# 使用BytesIO来读取二进制流,然后转换成Image对象
image_stream = io.BytesIO(image_content)
img = Image.open(image_stream)

# 显示图片
plt.imshow(img)
plt.axis('off')  # 不显示坐标轴
plt.show()

在这里插入图片描述

定义辅助函数apply:此函数在输入图像img上多次运行图像增广方法aug并显示所有结果。

def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):
    Y = [aug(img) for _ in range(num_rows * num_cols)]
    d2l.show_images(Y, num_rows, num_cols, scale=scale)

翻转

左右:使用transforms模块来创建RandomFlipLeftRight实例,这样就各有50%的几率使图像向左或向右翻转。

apply(img, torchvision.transforms.RandomHorizontalFlip())

在这里插入图片描述

上下:RandomFlipTopBottom实例,使图像各有50%的几率向上或向下翻转。

apply(img, torchvision.transforms.RandomVerticalFlip())

在这里插入图片描述

裁剪

随机裁剪一个面积为原始面积10%到100%的区域,该区域的宽高比从0.5~2之间随机取值。 然后,区域的宽度和高度都被缩放到200像素。

shape_aug = torchvision.transforms.RandomResizedCrop(
    (200, 200), scale=(0.1, 1), ratio=(0.5, 2))
apply(img, shape_aug)

在这里插入图片描述

改变颜色

随机更改图像的亮度,随机值为原始图像的50%(1-0.5)到150%(1+0.5)之间。

apply(img, torchvision.transforms.ColorJitter(
    brightness=0.5, contrast=0, saturation=0, hue=0))

在这里插入图片描述

随机更改色调

apply(img, torchvision.transforms.ColorJitter(
    brightness=0, contrast=0, saturation=0, hue=0.5))

在这里插入图片描述

创建一个RandomColorJitter实例,并设置如何同时随机更改图像的亮度(brightness)、对比度(contrast)、饱和度(saturation)和色调(hue)

color_aug = torchvision.transforms.ColorJitter(
    brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
apply(img, color_aug)

在这里插入图片描述

多种图像增广

通过使用一个Compose实例来综合上面定义的不同的图像增广方法,并将它们应用到每个图像。

augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(), color_aug, shape_aug])
apply(img, augs)

在这里插入图片描述

图像增广进行训练

使用CIFAR-10数据集, CIFAR-10数据集中的前32个训练图像如下所示

all_images = torchvision.datasets.CIFAR10(train=True, root="../data",
                                          download=True)
d2l.show_images([all_images[i][0] for i in range(32)], 4, 8, scale=0.8);
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:04<00:00, 40330181.37it/s]


Extracting ../data/cifar-10-python.tar.gz to ../data

在这里插入图片描述

只使用最简单的随机左右翻转,我们使用ToTensor实例将一批图像转换为深度学习框架所要求的格式,即形状为(批量大小,通道数,高度,宽度)的32位浮点数,取值范围为0~1。

train_augs = torchvision.transforms.Compose([
     torchvision.transforms.RandomHorizontalFlip(),
     torchvision.transforms.ToTensor()])

test_augs = torchvision.transforms.Compose([
     torchvision.transforms.ToTensor()])

PyTorch数据集提供的transform参数应用图像增广来转化图像。

def load_cifar10(is_train, augs, batch_size):
    dataset = torchvision.datasets.CIFAR10(root="../data", train=is_train,
                                           transform=augs, download=True)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                    shuffle=is_train, num_workers=d2l.get_dataloader_workers())
    return dataloader

多CPU训练

#
def train_batch_ch13(net, X, y, loss, trainer, devices):
    """用多GPU进行小批量训练"""
    if isinstance(X, list):
        # 微调BERT中所需
        X = [x.to(devices[0]) for x in X]
    else:
        X = X.to(devices[0])
    y = y.to(devices[0])
    net.train()
    trainer.zero_grad()
    pred = net(X)
    l = loss(pred, y)
    l.sum().backward()
    trainer.step()
    train_loss_sum = l.sum()
    train_acc_sum = d2l.accuracy(pred, y)
    return train_loss_sum, train_acc_sum

#
def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
               devices=d2l.try_all_gpus()):
    """用多GPU进行模型训练"""
    timer, num_batches = d2l.Timer(), len(train_iter)
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
                            legend=['train loss', 'train acc', 'test acc'])
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    for epoch in range(num_epochs):
        # 4个维度:储存训练损失,训练准确度,实例数,特点数
        metric = d2l.Accumulator(4)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = train_batch_ch13(
                net, features, labels, loss, trainer, devices)
            metric.add(l, acc, labels.shape[0], labels.numel())
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0] / metric[2], metric[1] / metric[3],
                              None))
        test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    print(f'loss {metric[0] / metric[2]:.3f}, train acc '
          f'{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on '
          f'{str(devices)}')

可以定义train_with_data_aug函数,使用图像增广来训练模型。该函数获取所有的GPU,并使用Adam作为训练的优化算法,将图像增广应用于训练集,最后调用刚刚定义的用于训练和评估模型的train_ch13函数。

batch_size, devices, net = 256, d2l.try_all_gpus(), d2l.resnet18(10, 3)

def init_weights(m):
    if type(m) in [nn.Linear, nn.Conv2d]:
        nn.init.xavier_uniform_(m.weight)

net.apply(init_weights)

def train_with_data_aug(train_augs, test_augs, net, lr=0.001):
    train_iter = load_cifar10(True, train_augs, batch_size)
    test_iter = load_cifar10(False, test_augs, batch_size)
    loss = nn.CrossEntropyLoss(reduction="none")
    trainer = torch.optim.Adam(net.parameters(), lr=lr)
    train_ch13(net, train_iter, test_iter, loss, trainer, 10, devices)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/lazy.py:181: UserWarning: Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.
  warnings.warn('Lazy modules are a new feature under heavy development '
train_with_data_aug(train_augs, test_augs, net)
loss 0.223, train acc 0.921, test acc 0.846
1168.0 examples/sec on [device(type='cuda', index=0)]

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1596113.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

OpenCV4.9图像金字塔

目标 在本教程中&#xff0c;您将学习如何&#xff1a; 使用 OpenCV 函数 pyrUp()和 pyrDown()对给定图像进行下采样或上采样。 理论 注意 下面的解释属于 Bradski 和 Kaehler 的 Learning OpenCV 一书。 通常&#xff0c;我们需要将图像转换为与原始图像不同的大小。为此…

函数的参数命名和默认参数

在Kotlin中&#xff0c;函数可以有多个参数&#xff0c;记住参数的顺序或者仅靠位置理解他们的作用可能会很具有挑战性&#xff0c;特别是对于接受多个参数或者有相同类型参数的函数。命名参数通过允许开发者指定传递给函数的每个参数的名称来解决这个问题。 有一个用来展示用户…

了解 Vue 工程化开发中的组件通信

目录 1. 组件通信语法 1.1. 什么是组件通信&#xff1f; 1.2. 为什么要使用组件通信&#xff1f; 1.3. 组件之间有哪些关系&#xff08;组件关系分类&#xff09;&#xff1f; 1.4. 组件通信方案有哪几类 &#xff1f; 2. 父子通信流程图 3. 父传子 3.1. 父传子核心流程…

【C++成长记】C++入门 | 类和对象(中) |类的6个默认成员函数、构造函数、析构函数

&#x1f40c;博主主页&#xff1a;&#x1f40c;​倔强的大蜗牛&#x1f40c;​ &#x1f4da;专栏分类&#xff1a;C❤️感谢大家点赞&#x1f44d;收藏⭐评论✍️ 目录 一、类的6个默认成员函数 二、构造函数 1、概念 2、特性 三、析构函数 1、概念 2、特性 一、…

LabVIEW专栏六、LabVIEW项目

一、梗概 项目&#xff1a;后缀是.lvproj&#xff0c;在实际开发的过程中&#xff0c;一般是要用LabVIEW中的项目来管理代码&#xff0c;也就是说相关的VI或者外部文件&#xff0c;都要放在项目中来管理。 在LabVIEW项目中&#xff0c;是一个互相依赖的整体&#xff0c;可以包…

51-40 Align your Latents,基于LDM的高分辨率视频生成

由于数据工程、仿真测试工程&#xff0c;咱们不得不进入AIGC图片视频生成领域。兜兜转转&#xff0c;这一篇与智驾场景特别密切。23年4月&#xff0c;英伟达Nvidia联合几所大学发布了带文本条件融合、时空注意力的Video Latent Diffusion Models。提出一种基于LDM的高分辨率视…

【简单讲解如何安装与配置Composer】

&#x1f3a5;博主&#xff1a;程序员不想YY啊 &#x1f4ab;CSDN优质创作者&#xff0c;CSDN实力新星&#xff0c;CSDN博客专家 &#x1f917;点赞&#x1f388;收藏⭐再看&#x1f4ab;养成习惯 ✨希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出…

深入探索:Zookeeper+消息队列(kafka)集群

目录 前言 一、Zookeeper概述 1、Zookeeper概念 2、Zookeeper 特点 3、Zookeeper工作机制 4、Zookeeper 选举机制 4.1 第一次启动选举机制 4.2 非第一次启动选举机制 5、Zookeeper 数据结构 6、Zookeeper 应用场景 二、部署 Zookeeper 集群 1、环境部署 2、安装 z…

构建鸿蒙ACE静态库

搭建开发环境 根据说明文档下载鸿蒙全部代码&#xff0c;一般采取第四种方式获取最新代码(请保证代码为最新) 源码获取Windows下载编译环境 MinGW GCC 7.3.0版本 请添加环境变量IDE 可以使用两种 CLion和Qt,CLion不带有环境需要安装MinGW才可以开发,Qt自带MinGW环境&#xff0…

【Canvas与艺术】绘制磨砂黄铜材质Premium Quality徽章

【关键点】 渐变色的使用、斜纹的实现、底图的寻觅 【成果图】 ​​​​​​​ 【代码】 <!DOCTYPE html> <html lang"utf-8"> <meta http-equiv"Content-Type" content"text/html; charsetutf-8"/> <head><tit…

跑腿平台隐藏服务用法,搭建平台这些跑腿服务也能做!

跑腿场景竞争愈发激烈激烈 事实上&#xff0c;跑腿行业早已群狼环伺&#xff0c;尽管跑腿领域仍有很大的发展空间&#xff0c;但新晋玩家都普遍把目光投向了外卖配送这个细分领域&#xff0c;难免会增加后来者的市场拓展和发展难度。那么&#xff0c;在跑腿服务行业中还有哪些…

springboot上传模块到私服,再用pom引用下来

有时候要做一个公司的公共服务模块。不能说大家都直接把代码粘贴进去&#xff0c;因为会需要维护很多份&#xff1b;这样就剩下两个方式了。 方式一&#xff1a;自己独立部署一个公共服务的服务&#xff0c;全公司都调用&#xff0c;通过http、rpc或者grpc的方式&#xff0c;这…

Http响应报文介绍

所有HTTP消息(请求与响应)中都包含&#xff1a; 一个或几个单行显示的消息头(header)&#xff0c; 在消息头部分主要包含&#xff1a;响应行信息和响应头信息 一个强制空白行&#xff1b; 最后是响应消息主体&#xff1b; 以下是一个典型的HTTP响应: HTTP/1.1 200 OK -- 响…

【MATLAB源码-第15期】基于matlab的MSK的理论误码率与实际误码率BER对比仿真,采用差分编码IQ调制解调。

1、算法描述 在数字调制中&#xff0c;最小频移键控&#xff08;Minimum-Shift Keying&#xff0c;缩写&#xff1a;MSK&#xff09;是一种连续相位调制的频移键控方式&#xff0c;在1950年代末和1960年代产生。[1] 与偏移四相相移键控&#xff08;OQPSK&#xff09;类似&…

nodejs解析url参数

需要引入 url 模块&#xff1b; var http require(http); var url require(url);http.createServer(function (req, res) {res.writeHead(200, {Content-Type: text/plain});// 解析 url 参数var params url.parse(req.url, true).query;res.write("name: " par…

解决Git 不相关的分支合并

可以直接调到解决方案,接下来是原因分析和每步的解决方式 问题原因: 我之前在自己本机创建了一个初始化了Git仓库,后来有在另一个电脑初始化仓库,并没有clone自己在本机Git远程仓库地址,导致Git历史版本不相关 错误信息 From https://gitee.com/to-uphold-justice-for-other…

MES管理系统中生产物料管理的设计

在数字化工厂建设的浪潮中&#xff0c;MES管理系统作为执行层的核心管理系统&#xff0c;其重要性日益凸显。特别是在生产物料管理方面&#xff0c;MES管理系统不仅承担物料计划指令的接收&#xff0c;还负责物料派工及使用反馈的数据收集&#xff0c;其业务流程的设计对数字化…

【Leetcode】2923. 找到冠军 I

文章目录 题目思路代码复杂度分析时间复杂度空间复杂度 结果总结 题目 题目链接&#x1f517; 一场比赛中共有 n n n 支队伍&#xff0c;按从 0 0 0 到 n − 1 n - 1 n−1 编号。 给你一个下标从 0 0 0 开始、大小为 n ∗ n n * n n∗n 的二维布尔矩阵 g r i d grid gr…

改进的注意力机制的yolov8和UCMCTrackerDeepSort的多目标跟踪系统

基于yolov8和UCMCTracker/DeepSort的注意力机制多目标跟踪系统 本项目是一个强大的多目标跟踪系统&#xff0c;基于[yolov8]链接和[UCMCTracker/DeepSot]/链接构建。 &#x1f3af; 功能 多目标跟踪&#xff1a;可以实现对视频中的多目标进行跟踪。目标检测&#xff1a;可以实…