3.1.数据增广

news2024/11/14 12:09:38

数据增广

​ 以图片为例,在不同的灯光,色温,以及灯光反射的影响下,对识别可能会造成很大影响。这时候我们希望样本有更多的多样性,则可以在语言里面加入各种不同的背景噪音,或者改变图片的颜色和形状

在这里插入图片描述

1.常见的增强方法

1.1 翻转和裁剪

​ 左右翻转通常不会改变对象的类别(上下翻转可能会),我们可以使用transforms模块来创建RandomFlipLeftRight实例,就有50%的概率左右翻转。使用RandomFlipTopBottom实例,使图像各有50%的几率向上或向下翻转。

在这里插入图片描述

​ 切割可使用RandomResizedCrop来实现:

shape_aug = torchvision.transforms.RandomResizedCrop(
    (200, 200), scale=(0.1, 1), ratio=(0.5, 2))
apply(img, shape_aug)

​ 意为随机裁剪一个面积为原始面积10%到100%的区域,该区域的宽高比从0.5~2之间随机取值,然后,区域的宽度和高度都被缩放到200像素。

1.2 改变颜色

​ 另一种增强方法时改变颜色,有四个方面:亮度,对比度,饱和度和色调。

apply(img, torchvision.transforms.ColorJitter(
    brightness=0.5, contrast=0, saturation=0, hue=0))

​ 四个参数分别为亮度,对比度,饱和度和色调,值的含义是在原始值的对应区间波动,比如亮度取0.5,即在原始值的50%到150%之间。可同时设置。

在这里插入图片描述

1.3 其他方法

​ 显然可以结合起来使用,然后也有锐化,高斯模糊,加黑块等处理方法:
http://github.com/aleju/imgaug

2.使用数据增广来训练

遇到的问题:

​ 直接使用d2l.resnet18()会遇到问题,自己重写一遍就能正确运行了(用的A卡,device用的torch_directml)

import torch
import torchvision
from torch import nn
from d2l import torch as d2l
from torch.nn import functional as F
import torch_directml

d2l.set_figsize()
img = d2l.Image.open('../img/cat1.jpg')
d2l.plt.imshow(img)

device = torch_directml.device()


# 作用多少次,生成两行四列
def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):
    Y = [aug(img) for _ in range(num_rows * num_cols)]
    d2l.show_images(Y, num_rows, num_cols, scale=scale)


apply(img, torchvision.transforms.RandomHorizontalFlip())  # 左右
apply(img, torchvision.transforms.RandomVerticalFlip())  # 上下

shape_aug = torchvision.transforms.RandomResizedCrop(
    (200, 200), scale=(0.1, 1), ratio=(0.5, 2))
apply(img, shape_aug)  # 裁剪

color_aug = torchvision.transforms.ColorJitter(
    brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
# 亮度,对比度,饱和度,色调
apply(img, color_aug)

# 结合多种方法
augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(), color_aug, shape_aug])
apply(img, augs)

# 使用CIFAR10数据集,彩色图
all_images = torchvision.datasets.CIFAR10(train=True, root="../data",
                                          download=True)
d2l.show_images([all_images[i][0] for i in range(32)], 4, 8, scale=0.8);

train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor()])  # 先做变换。再totensor转变成张量进行训练

test_augs = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()])  # 测试的直接转换就行,不用增广


def load_cifar10(is_train, augs, batch_size):
    dataset = torchvision.datasets.CIFAR10(root="../data", train=is_train,
                                           transform=augs, download=True)
    # transform 就是在读的时候就做增广
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                             shuffle=is_train, num_workers=d2l.get_dataloader_workers())
    return dataloader


# @save
def train_batch_ch13(net, X, y, loss, trainer, devices):
    """用多GPU进行小批量训练"""
    if isinstance(X, list):
        # 微调BERT中所需
        X = [x.to(devices) for x in X]
    else:
        X = X.to(devices)
    y = y.to(devices)
    net.train()
    trainer.zero_grad()
    pred = net(X)
    l = loss(pred, y)
    l.sum().backward()
    trainer.step()
    train_loss_sum = l.sum()
    train_acc_sum = d2l.accuracy(pred, y)
    return train_loss_sum, train_acc_sum


# @save 多GPU的
def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
               device):
    """用多GPU进行模型训练"""
    timer, num_batches = d2l.Timer(), len(train_iter)
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
                            legend=['train loss', 'train acc', 'test acc'])
    # net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    for epoch in range(num_epochs):
        # 4个维度:储存训练损失,训练准确度,实例数,特点数
        metric = d2l.Accumulator(4)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = train_batch_ch13(
                net, features, labels, loss, trainer, device)
            metric.add(l, acc, labels.shape[0], labels.numel())
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0] / metric[2], metric[1] / metric[3],
                              None))
        test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    print(f'loss {metric[0] / metric[2]:.3f}, train acc '
          f'{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on '
          f'{str(devices)}')


def init_weights(m):
    if type(m) in [nn.Linear, nn.Conv2d]:
        nn.init.xavier_uniform_(m.weight)


def resnet18(num_classes, in_channels=3):
    class Residual(nn.Module):
        def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):
            super(Residual, self).__init__()
            self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, stride=strides)
            self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
            if use_1x1conv:
                self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides)
            else:
                self.conv3 = None
            self.bn1 = nn.BatchNorm2d(num_channels)
            self.bn2 = nn.BatchNorm2d(num_channels)

        def forward(self, X):
            Y = F.relu(self.bn1(self.conv1(X)))
            Y = self.bn2(self.conv2(Y))
            if self.conv3:
                X = self.conv3(X)
            return F.relu(Y + X)

    def resnet_block(input_channels, num_channels, num_residuals, first_block=False):
        blk = []
        for i in range(num_residuals):
            if i == 0 and not first_block:
                blk.append(Residual(input_channels, num_channels, use_1x1conv=True, strides=2))
            else:
                blk.append(Residual(num_channels, num_channels))
        return blk

    b1 = nn.Sequential(nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3),
                       nn.BatchNorm2d(64), nn.ReLU(),
                       nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

    b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
    b3 = nn.Sequential(*resnet_block(64, 128, 2))
    b4 = nn.Sequential(*resnet_block(128, 256, 2))
    b5 = nn.Sequential(*resnet_block(256, 512, 2))

    net = nn.Sequential(b1, b2, b3, b4, b5,
                        nn.AdaptiveAvgPool2d((1, 1)),
                        nn.Flatten(), nn.Linear(512, num_classes))
    return net


# batch_size, devices, net = 256, d2l.try_all_gpus(), d2l.resnet18(10, 3).to(device)
batch_size, devices, net = 256, d2l.try_all_gpus(), resnet18(10, 3)

net.apply(init_weights)
net.to(device)


def init_weights(m):
    if type(m) in [nn.Linear, nn.Conv2d]:
        nn.init.xavier_uniform_(m.weight)


net.apply(init_weights)


def train_with_data_aug(train_augs, test_augs, net, lr=0.001):
    train_iter = load_cifar10(True, train_augs, batch_size)
    test_iter = load_cifar10(False, test_augs, batch_size)
    loss = nn.CrossEntropyLoss(reduction="none")
    trainer = torch.optim.Adam(net.parameters(), lr=lr)
    train_ch13(net, train_iter, test_iter, loss, trainer, 10, device)


train_with_data_aug(train_augs, test_augs, net)
d2l.plt.show()

在这里插入图片描述

​ 不知道为什么,效果不是很好,明明用的都是resnet18,难道d2l的resnet18有什么优化?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1957459.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【2024蓝桥杯/C++/B组/进制】

题目 代码 #include <bits/stdc.h> using namespace std;// 定义一个字符串 str&#xff0c;其内容为 "8100178706957568" string str "8100178706957568";// 函数 check 用于检查传入的字符串是否全部由数字组成 bool check(const string& st…

Java(二十七)---二叉搜索树以及Map和Set

文章目录 前言1.二叉搜索树1.1.概念1.2.操作--- 插入1.3.操作---搜索1.4.操作---删除1.6.性能分析1.7 和 java 类集的关系 2.搜索2.1.概念和场景2.2.模型 3.Map的使用3.1.关于Map的说明3.2.Map.Entry<K,V>的说明3.3.Map中常用的方法3.4.TreeMap的使用案例 4.Set的使用4.1…

探索 Milvus 存储系统:如何评估和优化 Milvus 存储性能

欢迎来到探索 Milvus 系列。Milvus 是一款支持水平扩展和具备出色性能的开源向量数据库。Milvus 的核心是其强大的存储系统&#xff0c;是数据持久化和存储的关键基础。该系统包括几个关键组成部分&#xff1a;元数据存储&#xff08;meta storage&#xff09;、消息存储&#…

Vs2022+QT+Opencv 一些需要注意的地方

要在vs2022创建QT项目&#xff0c;先要安装一个插件Qt Visual Studio Tools&#xff0c;根据个人经验选择LEGACY Qt Visual Studio Tools好一些&#xff0c;看以下内容之前建议先在vs2022中配置好opencv&#xff0c;配置方式建议以属性表的形式保存在硬盘上。 设置QT路径 打开v…

数学建模--差值算法

目录 插值方法的种类 应用实例 编程实现 算法实现 拉格朗日插值算法 ​编辑 多项式差值算法 样条插值 牛顿插值算法 插值算法在数据预测中的最新应用和案例研究是什么&#xff1f; 如何比较不同插值方法&#xff08;如线性插值、多项式插值&#xff09;在实际工程问…

bjtu数据库课程设计--基于Spring Boot框架的门店点餐系统

一、安装与配置 1 安装与配置 下载IntelliJ IDEA&#xff0c;需要下载安装jdk1.8.0_152&#xff0c;安装tomcat-9.0.88&#xff0c;安装MySQL8.0数据库。安装成功后打开IntelliJ IDEA&#xff0c;使用 Spring Boot 2.6.13框架&#xff0c;服务器URL窗口使用start.aliyun.com&a…

AI副业玩法:开启你的智能赚钱之路

在这个数码时代&#xff0c;人工智能&#xff08;AI&#xff09;已经不仅仅是科技巨头的专利&#xff0c;它逐渐渗透到我们生活的方方面面。如今&#xff0c;越来越多的人开始利用AI技术进行副业尝试&#xff0c;既拓宽了收入来源&#xff0c;也提升了自我技能。那么&#xff0…

【前端 07】JavaScript中的数组对象

JavaScript中的数组对象 在JavaScript中&#xff0c;数组&#xff08;Array&#xff09;对象是一种非常基础且强大的数据结构&#xff0c;用于在单个变量中存储多个值。这些值可以是任何数据类型&#xff0c;包括数字、字符串、甚至是其他数组&#xff08;多维数组&#xff09…

实验2-4-2 求N分之一序列前N项和**注意小细节

//实验2-4-2 求N分之一序列前N项和//计算序列 1 1/2 1/3 ... 的前N项之和。#include<stdio.h> #include<math.h> int main(){int N;double sum0.0;scanf("%d",&N);for(int a1;a<N;a)sum(1.0/a);//这里必须是1.0 不可以是1&#xff01;&#x…

【管理咨询宝藏150】MBB咨询顾问的结构化PPT训练课程

【管理咨询宝藏150】MBB咨询顾问的结构化PPT训练课程 【格式】PDF版本 【关键词】MBB、麦肯锡、罗兰贝格、咨询顾问 【核心观点】 - 在项目的开始阶段你着手发展有效的金字塔式的演示文稿—我们的重点在最后两个步骤&#xff1b;我们用金字塔结构&#xff1a;通过把核心的信息…

使用nginx解决本地环境访问线上接口跨域问题

前言 前端项目开发过程中&#xff0c;经常会遇到各种各样的跨域问题。 虽然大部分时候&#xff0c;由脚手架自带的proxy功能即可解决问题&#xff0c;如webpack&#xff0c;vite等&#xff1b;但是若没有通过脚手架搭建项目&#xff0c;或者必须使用某些特殊规则转发时&#…

<数据集>手机识别数据集<目标检测>

数据集格式&#xff1a;VOCYOLO格式 图片数量&#xff1a;16172张 标注数量(xml文件个数)&#xff1a;16172 标注数量(txt文件个数)&#xff1a;16172 标注类别数&#xff1a;1 标注类别名称&#xff1a;[Phone] 使用标注工具&#xff1a;labelImg 标注规则&#xff1a;…

【QT】qt 文件操作

qt 文件 qt 文件1. Qt 文件概述2. 输入输出设备类3. 文件读写类4. 文件和目录信息类 qt 文件 1. Qt 文件概述 文件操作是应用程序必不可少的部分。Qt 作为⼀个通用开发库&#xff0c;提供了跨平台的文件操作能力。 Qt 提供了很多关于文件的类&#xff0c;通过这些类能够对文件…

上海有机所化学数据库:一站式科研资源

上海有机化学研究所是中国科学院上海分院的直属机构&#xff0c;主要从事有机化学、材料化学、生命科学等领域的基础研究和应用研究&#xff0c;化学专业数据库是该所承担建设的综合科技信息数据库的重要组成部分&#xff0c;服务于化学化工研究和开发的综合性信息系统&#xf…

Javaweb项目|springboot医院管理系统

收藏点赞不迷路 关注作者有好处 文末获取源码 一、系统展示 二、万字文档展示 基于springboot医院管理系统 开发语言&#xff1a;Java 数据库&#xff1a;MySQL 技术&#xff1a;SpringSpringMVCMyBatisVue 工具&#xff1a;IDEA/Ecilpse、Navicat、Maven 编号&#xff1a;…

【海贼王航海日志:前端技术探索】HTML你学会了吗?(一)

目录 1 -> HTML概念 2 -> HTML结构 2.1 -> 认识HTML标签 2.2 -> HTML文件基本结构 2.3 -> 标签层次结构 3 -> 快速生成代码框架 4 -> HTML常见标签 4.1 -> 注释标签 4.2 -> 标题标签 4.3 -> 段落标签 4.4 -> 换行标签 4.5 ->…

多线程-进阶2

博主主页: 码农派大星. 数据结构专栏:Java数据结构 数据库专栏:MySQL数据库 JavaEE专栏:JavaEE 关注博主带你了解更多数据结构知识 1.CAS 1.1CAS全称:Compare and swap 比较内存和cpu寄存器中的内容,如果发现相同,就进行交换(交换的是内存和另一个寄存器的内容) 一个内存的…

《学会 SpringBoot · 参数校验》

&#x1f4e2; 大家好&#xff0c;我是 【战神刘玉栋】&#xff0c;有10多年的研发经验&#xff0c;致力于前后端技术栈的知识沉淀和传播。 &#x1f497; &#x1f33b; CSDN入驻不久&#xff0c;希望大家多多支持&#xff0c;后续会继续提升文章质量&#xff0c;绝不滥竽充数…

连锁美业门店收银系统Java源码-如何设置门店仓库自提时间?博弈美业实操

1. 门店仓库自提时间&#xff0c;是当客户在小程序上购买实物商品时&#xff0c;预约上门提货的时间 2. 门店仓库自提时间&#xff0c;需要由各门店&#xff08;店主、店长、店员&#xff09;在PAD上进行设置 ▶ 操作路径&#xff1a; • 第一步&#xff1a; 进入【我的】页…

怎么进行图片压缩?对图片文件的大小进行压缩的四个方法介绍

怎么进行图片压缩&#xff1f;图片压缩是一种常见的技术&#xff0c;用于减小图像文件的大小&#xff0c;同时尽可能地保持图像的视觉质量和细节。这一过程不仅适用于个人用户想要节省存储空间或提高网页加载速度&#xff0c;也对于专业摄影师、网站设计师和应用程序开发者来说…