【GridMask】《GridMask Data Augmentation》

news2024/11/25 15:23:13

在这里插入图片描述

arXiv-2020


文章目录

  • 1 Background and Motivation
  • 2 Related Work
  • 3 Advantages / Contributions
  • 4 GridMask
  • 5 Experiments
    • 5.1 Image Classification
    • 5.2 Object Detection on COCO Dataset
    • 5.3 Semantic Segmentation on Cityscapes
    • 5.4 Expand Grid as Regularization
  • 6 Conclusion(own)


1 Background and Motivation

数据增广方法可以有效的缓解模型的过拟合

现有的数据增广方法可以大致分成如下3类

  • spatial transformation(random scale, crop, flip and random rotation)
  • color distortion( brightness, hue)
  • information dropping(random erasing, cutout,HaS)

好的 information dropping 数据增广方法要 achieve reasonable balance between deletion and reserving of regional information on the images

删太多,把数据变成了噪声

删太少,目标没啥变化,失去了增广的意义

在这里插入图片描述
本文,作者提出GridMask,deletes uniformly distributed areas and finally forms a grid shape,在多个任务的公开数据集上效果均有提升
在这里插入图片描述

2 Related Work

  • spatial transformation(random scale, crop, flip and random rotation)
  • color distortion( brightness, hue)
  • information dropping(random erasing, cutout,HaS)

3 Advantages / Contributions

提出 GridMask structured data augmentation 方法,在公开的分类、目标检测、分割的benchmark 上比 baseline 好

4 GridMask

在这里插入图片描述
作用形式
x ~ = x × M \widetilde{x}= x \times M x =x×M

其中 x ∈ R H × W × C x \in \mathbb{R}^{H \times W \times C} xRH×W×C 为 输入图像, x ~ ∈ R H × W × C \widetilde{x} \in \mathbb{R}^{H \times W \times C} x RH×W×C 为增广后的图像, M ∈ { 0 , 1 } H × W M \in \{0,1\}^{H \times W} M{0,1}H×W 为 binary mask that stores pixels to be removed,0 的话表示挡住,1 的话表示保留

形成 M M M 的话有 4 个超参数 ( r , d , δ x , δ y ) (r, d, \delta_x, \delta_y) (r,d,δx,δy)

在这里插入图片描述
1)Choice of r r r

r r r is the ratio of the shorter gray edge in a unit,determines the keep ratio of an input image,值介于 0~1 之间

the keep ratio k k k of a given mask M M M as

k = s u m ( M ) H × W k = \frac{sum(M)}{H \times W} k=H×Wsum(M)

r r r k k k 的关系是

k = 1 − ( 1 − r ) 2 = 2 r − r 2 k = 1-(1-r)^2 = 2r-r^2 k=1(1r)2=2rr2

r r r 的值小于1, r r r k k k 正相关

k k k 越大,灰色区域越多,遮挡越少
k k k 越小,黑色区域越多,遮挡越多

2)Choice of d d d

d d d is the length of one unit

一个 unit 内(橙色虚线框),灰色区域的长度为 l = r × d l = r \times d l=r×d

d = r a n d o m ( d m i n , d m a x ) d = random(d_{min}, d_{max}) d=random(dmin,dmax)

在这里插入图片描述
这么画歧义更合适

3)Choice of δ x \delta_x δx and δ y \delta_y δy

δ x \delta_x δx and δ y \delta_y δy are the distances between the first intact unit and boundary of the image. can shift the mask

δ x ( δ y ) = r a n d o m ( 0 , d − 1 ) \delta_x(\delta_y) = random(0, d-1) δx(δy)=random(0,d1)

4)Statistics of Unsuccessful Cases
在这里插入图片描述
99 percent of an object is removed or reserved, we call it a failure case

GridMask has lower chance to yield failure cases than Cutout and HaS

5)The Scheme to Use GridMask

increase the probability of GridMask linearly with the training epochs until an upper bound P is achieved.

中间的概率用 p p p 表示,后续实验中有涉及到

5 Experiments

Datasets

  • ImageNet
  • COCO
  • Cityscapes

5.1 Image Classification

1)ImageNet
在这里插入图片描述
比 Cutout 和 HaS 更好,It is because we handle the aforementioned failure cases better

Benefit to CNN
在这里插入图片描述
focus on large important regions

2)CIFAR10
在这里插入图片描述
Combined with AutoAugment, we achieve SOTA result on these models.

3)Ablation Study

(1)Hyperparameter r r r
在这里插入图片描述

r 越大,mask 1 越多,遮挡的越少,说明数据比较复杂

r 越小,mask 1 越少,遮挡的越多,说明数据比较简单

we should keep more information on complex datasets to avoid under-fitting, and delete more on simple datasets to reduce over-fitting

(2)Hyperparameter d d d
在这里插入图片描述

the diversity of d can increase robustness of the network

(3)Variations of GridMask

reversed GridMask:keep what we drop in GridMask, and drop what we keep in GridMask

在这里插入图片描述
效果不错,也印证了 GridMask 有很好的 balance between deletion and reserving

random GridMask:drop a block in every unit with a certain probability of p u p_u pu.

在这里插入图片描述

p u p_u pu 越大,越贴近原始 GridMask

效果不行

5.2 Object Detection on COCO Dataset

在这里插入图片描述
不加 GridMask,training epochs 越多,过拟合越严重,加了以后,训练久一点, 精度还有上升空间

5.3 Semantic Segmentation on Cityscapes

在这里插入图片描述

5.4 Expand Grid as Regularization

联合 GridMask 和 Mixup,ImageNet 上 SOTA在这里插入图片描述

6 Conclusion(own)

GridMask Data Augmentation
在这里插入图片描述


代码实现,考虑了旋转增广,所以 mask 生成的时候是在以原图对角线为边长的情况下生成的,最后取原图区域
https://github.com/dvlab-research/GridMask/blob/master/imagenet_grid/utils/grid.py

在这里插入图片描述

import torch
import numpy as np
import math
import PIL.Image as Image
import torchvision.transforms as T
import matplotlib.pyplot as plt

class Grid(object):
    def __init__(self, d1=96, d2=224, rotate=1, ratio=0.5, mode=1, prob=1.):
        self.d1 = d1
        self.d2 = d2
        self.rotate = rotate
        self.ratio = ratio # r
        self.mode = mode # reversed?
        self.st_prob = self.prob = prob # p

    def set_prob(self, epoch, max_epoch):
        self.prob = self.st_prob * min(1, epoch / max_epoch)

    def forward(self, img):
        if np.random.rand() > self.prob:
            return img
        h = img.size(1)
        w = img.size(2)

        # 1.5 * h, 1.5 * w works fine with the squared images
        # But with rectangular input, the mask might not be able to recover back to the input image shape
        # A square mask with edge length equal to the diagnoal of the input image 
        # will be able to cover all the image spot after the rotation. This is also the minimum square.
        hh = math.ceil((math.sqrt(h * h + w * w)))

        d = np.random.randint(self.d1, self.d2)
        # d = self.d

        # maybe use ceil? but i guess no big difference
        self.l = math.ceil(d * self.ratio)

        mask = np.ones((hh, hh), np.float32)
        st_h = np.random.randint(d)  # delta y
        st_w = np.random.randint(d)  # delta x
        for i in range(-1, hh // d + 1):
            s = d * i + st_h
            t = s + self.l
            s = max(min(s, hh), 0)
            t = max(min(t, hh), 0)
            mask[s:t, :] *= 0
        for i in range(-1, hh // d + 1):
            s = d * i + st_w
            t = s + self.l
            s = max(min(s, hh), 0)
            t = max(min(t, hh), 0)
            mask[:, s:t] *= 0
        r = np.random.randint(self.rotate)
        mask = Image.fromarray(np.uint8(mask))
        mask = mask.rotate(r)
        mask = np.asarray(mask)
        mask = mask[(hh - h) // 2:(hh - h) // 2 + h, (hh - w) // 2:(hh - w) // 2 + w] # 这里结合原理图方便看懂一些

        mask = torch.from_numpy(mask).float().cuda()
        if self.mode == 1:
            mask = 1 - mask

        mask = mask.expand_as(img)
        img = img.cuda() * mask

        return img


if __name__ == "__main__":
    image = Image.open("2.jpg").convert("RGB")
    tr = T.Compose([
        T.Resize((224,224)),
        T.ToTensor()
    ])
    x = tr(image)
    gridmask_image = Grid(d1=64, d2=96).forward(x)
    print(gridmask_image.shape)
    # print(gridmask_image.shape())
    fig, axs = plt.subplots(1,2)
    to_plot = lambda x: x.permute(1,2,0).cpu().numpy()
    axs[0].imshow(to_plot(x))
    axs[1].imshow(to_plot(gridmask_image))
    plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/13802.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MongoDB之完整入门知识(集合操作、文档基本CRUD操作、文档分页查询、索引等相关命令)

MongoDB完整入门知识一、相关概念1、简介2、体系结构3、安装网址二、MongoDB基本常用命令1、Shell连接(mongo命令)2、选择和创建数据库2.1 选择和创建数据库的语法格式(如果数据库不存在,则自动创建)2.2 查看有权限查看…

SpringBoot与Loki的那些事

因为网上好多都没有通过Loki的API自己实现对日志监控系统,所以我就下定决心自己出一版关于loki与springboot的博文供大家参考,这个可以说是比较实用,很适合中小型企业。因此我酝酿了挺久了,对于loki的研究也比较久,希望…

论文精读《OFT: Orthographic Feature Transform for Monocular 3D Object Detection》

OFT: Orthographic Feature Transform for Monocular 3D Object Detection 文章目录OFT: Orthographic Feature Transform for Monocular 3D Object Detection论文精读摘要(Abstract)1. 介绍(Introduction)2. 相关工作&#xff08…

给开源项目做一个漂亮简洁的版本迭代更新图,生成固定链接复制到介绍中、公众号菜单链接中、博客中等

背景 开源项目的版本迭代与更新经常需要更新迭代文档,但是readme.md没有比较美观一点的效果,所以文本分享一种第三方的方式:用TexSpire的免费在线文档分享功能,手机、PC、Pad都可以适配。 效果预览 使用 第一步:创…

浅谈 async/await 和生成器

浅谈 async/await async/await 是ES8规范新增的,使得以同步方式写的代码异步运行不再是白日梦,进一步让代码逻辑更加清晰。 为什么新增 async/await 下面有这样一个需求:有两个请求,请求 1 的结果是请求 2 的参数,所…

机器学习6——EM算法与高斯混合模型GMM

前置内容 Jensen不等式 高斯混合模型 多元高斯模型 拉格朗日乘子法 主要内容 EM算法(Expectation-Maximization),期望-最大化。 用于保证收敛到MLE(最大似然估计)。主要用于求解包含隐变量的混合模型,主要…

R生成三线表

R生成三线表table1包测试数据生成制作三线表Tableone包加载包 探索数据类型数据整理分类构建Table函数Tableone函数细节主要基于table1 或tableone 包table1包 测试数据生成 datadata.frame(性别sample(c("男","女"), 1000,replaceT),年龄round(rnorm(10…

2021年认证杯SPSSPRO杯数学建模A题(第一阶段)医学图像的配准全过程文档及程序

2021年认证杯SPSSPRO杯数学建模 A题 医学图像的配准 原题再现: 图像的配准是图像处理领域中的一个典型问题和技术难点,其目的在于比较或融合同一对象在不同条件下获取的图像。例如为了更好地综合多种信息来辨识不同组织或病变,医生可能使用…

5年自动化测试,终于进字节跳动了,年薪30w其实也并非触不可及

一些碎碎念 什么都做了,和什么都没做其实是一样的,走出“瞎忙活”的安乐窝,才是避开弯路的最佳路径。希望我的经历能帮助到有需要的朋友。 在测试行业已经混了5个年头了,以前经常听到开发对我说,天天的点点点有意思没…

java计算机毕业设计springboot+vue+elementUI永加乡精准扶贫信息管理系统

项目介绍 系统设计的主要意义在于,一方面,对于网站来讲,系统上线后可以带来很大的便利性,精准扶贫网站管理属于非常细致的管理模式,要求数据量大,计算机管理可以提高精确性,更为便利的就是信息…

NF-κB 信号通路调节细胞因子转录

NF-κB 大家族哺乳动物 NF-κB 家族由五种成员组成:RelA/p65、c-Rel、RelB、p50 (NF-κB1) 和 p52 (NF-κB2),它们可以形成各种异源二聚体或者同源二聚体 (如常见 p50/RelA 异源二聚体),并通过与启动子的 κB 位点结合来激活大量基因。所有 N…

Mysql常用函数

Mysql常用函数 字段拼接(concat) CONCAT() 函数用于将多个字符串连接成一个字符串 格式: select CONCAT(str1,str2,…) from table_name; #查询商品表,返回一列:商品名称(价格)。 SELECT concat(prod_name,(,prod…

【论文阅读】Weakly Supervised Semantic Segmentation using Out-of-Distribution Data

一篇弱监督分割领域的论文,发表在CVPR2022上: 论文标题: Weakly Supervised Semantic Segmentation using Out-of-Distribution Data 作者信息: 代码地址: https://github.com/naver-ai/w-ood Abstract 作者认为…

专精特新小巨人的申报条件

专精特新企业分为市级专精特新、省级专精特新和国家级专精特新。 在2018年,开展了国家第一批专精特新“小巨人” 企业申报工作。为了引导中小企业积极走“专精特新”发展之路,加快新旧动能转 换步伐,提升自主创新能力、加快转型升级&#xf…

软考的网络工程师对就业有用吗?

考证只是一个结果,学会技能才是最重要的。 视工作而言,软考中级网络工程师的性价比还是非常高的,对于从事同类的技术人员,基础扎实一般可以裸考通过。 含金量嘛,计算机专业可以以考代凭,毕竟证书是人社部和…

安装谷歌服务框架2022最新版本22.45.15失败

在这里(谷歌play服务框架下载安装安卓版-谷歌服务框架2022最新版本(Google Play 服务)下载22.45.15官方手机版-蜻蜓手游网 (qt6.com)http://www.qt6.com/XiaZai/155507.html)下载了谷歌服务框架(Google Play 服务),其应用信息为: 包名:com.g…

Mutated 源代码解析 client (一)

Mutated , a C project https://github.com/scslab/mutated usage Main function in the client directory, mutated_synthetic.cc Line 14 parse the user arguments, such as “-h, -w, -c” parse_synthetic is implemented in client\opts_synthetic.cc Here, use th…

Dive into TensorFlow系列(3)- 揭开Tensor的神秘面纱

TensorFlow计算图是由op和tensor组成,那么tensor一般都用来代表什么呢?显然,像模型的输入数据、网络权重、输入数据经op处理后的输出结果都需要用张量或特殊张量进行表达。既然tensor在TensorFlow体系架构中如此重要,因此本文将带…

Redis通用命令和key的层级结构

目录 1 Redis数据结构介绍 2 Redis 通用命令 3 Redis命令-Key的层级结构 1 Redis数据结构介绍 Redis是一个key-value的数据库,key一般是String类型,不过value的类型多种多样: value的数据类型共有8种,前面5中为基本数据类型&a…

5000立方米球罐设计

目 录 摘 要 I Abstract II 1 文献综述 1 1.1 课题研究的工程背景及理论、实际意义 1 1.2 球罐用钢 1 1.2.1 球罐用钢基本要求分析 1 1.2.2 国内外球罐的常用钢种 2 1.2.3 几种典型球罐用钢的优劣对比 2 1.3 球罐设计 3 1.3.1 球罐设计的执行标准及法规 3 1.3.2 球壳结构 4 1.3…