【SSD 代码精读】之 数据增强(Data Augmentation)

news2024/11/24 12:54:11

SSD 数据增强

  • 前言
    • 1、Compose
    • 2、SSDCropping
    • 3、Resize
    • 4、ColorJitter
    • 5、ToTensor
    • 6、RandomHorizontalFlip
    • 7、Normalization
    • 8、AssignGTtoDefaultBox


前言

原论文
在这里插入图片描述
根据原论文,我们需要处理的有以下:

data_transform = {
    "train": transforms.Compose([transforms.SSDCropping(),
                                 transforms.Resize(),
                                 transforms.ColorJitter(),
                                 transforms.ToTensor(),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.Normalization(),
                                 transforms.AssignGTtoDefaultBox()]),
    "val": transforms.Compose([transforms.Resize(),
                               transforms.ToTensor(),
                               transforms.Normalization()])
}

因为 torchvision.transforms 默认只处理图像,而我们在做图像翻转的时候,需要连 ground truth box 的坐标一并翻转。 所以我们需要重写 torchvision.transforms 那一套的操作。

(mac系统下,只要按住 command 键,再点击 torchvision.transforms ,就可以查看源码,在源码上修修改改就可以)


1、Compose

输入输出 带上 target

class Compose(object):
    """组合多个transform函数"""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target=None):
        for trans in self.transforms:
            image, target = trans(image, target)
        return image, target

2、SSDCropping

目的:从图像中裁剪出一部分,删除不在其中的 gt box 和 label,对于在其中的 gt box做相应的坐标调整。

To make the model more robust to various input object sizes and shapes, each training image is randomly sampled by one of the following options:

  • Use the entire original input image.
  • Sample a patch so that the minimum jaccard overlap with the objects is 0.1, 0.3, 0.5, 0.7, or 0.9
  • Randomly sample a patch.
    \

The size of each sampled patch is [0.1, 1] of the original image size, and the aspect ratio is between 0.5 and 2.
We keep the overlapped part of the ground truth box if the center of it is in the sampled patch.

相关说明:
gt box 的坐标,在这之前 已经被处理为了 0 ~ 1 的相对位置
在这里插入图片描述

pseudo code

''' 
图像尺寸:(1, 1)。 因为图像尺寸不一样,这里的计算按照比例, 后面的计算也都是按照图片比例进行计算

mode = (None, (0.1, None), (0.3, None), (0.5, None), (0.7, None), (0.9, None), (None, None))
mode,表示iou的阈值,其中:
     --  None 表示:不做裁剪
     --  (0.1, None), ... ,(0.9, None)表示: (min_iou, max_iou) 
     --  (None, None)  表示 无上限和下限,也就是iou的范围属于 [0, 1] 都可以

target 是一个字典,其中 包括 gt_box 的坐标, 及对应的 label
'''


while True:
    1、随机挑选一个 mode
    if mode is None, 不做随机裁剪处理
        return image, target       
    else:
        min_iou = mode[0],max_iou = mode[1]    (None 表示无上限)

    for _ in range(5):
        2、创建一个 crop_box: 宽和高的范围都在 (0.3, 1.0)之间,需要保证crop_box的四个角都落在原图中, 且保证宽高比例在0.5-2之间
		3、取图像的 gt_box 坐标
		
		# 判断这个crop_box是不是能用的 条件一 : iou 要满足条件
        4、计算 gt_box 和 crop_box 的 iou
        if 有 iou 不在 (min_iou, max_iou) 范围之间:
            continue
            
        # 判断这个crop_box是不是能用的 条件二: 中心坐标要满足条件
        5、计算 gt_box 的中心坐标
        if 所有 gt_box 的中心都没落在 crop_box 中
            continue

		# 已经确定 crop_box 可用,做相关的坐标处理
        6、筛选出 中心坐标落在 crop_box 中的 gt_box, 及对应的 labels
        7、修改 gt_box 坐标, 防止出现越界的情况: 如果超出 crop_box 的边界,就截断到 crop_box 的边界
        8、重新计算 crop_box 的坐标, 并在 原图 中截取出来, 记为 croped_image
        9、重新计算 gt_box 在 croped_image 中的坐标位置, 记录 new_gt_box

        return croped_image, new_gt_box

代码

# This function is from https://github.com/chauhan-utk/ssd.DomainAdaptation.
class SSDCropping(object):
    """
    根据原文,对图像进行裁剪,该方法应放在ToTensor前
    Cropping for SSD, according to original paper
    Choose between following 3 conditions:
    1. Preserve the original image
    2. Random crop minimum IoU is among 0.1, 0.3, 0.5, 0.7, 0.9
    3. Random crop
    Reference to https://github.com/chauhan-utk/src.DomainAdaptation
    """
    def __init__(self):
        self.sample_options = (
            # Do nothing
            None,
            # min IoU, max IoU
            (0.1, None),
            (0.3, None),
            (0.5, None),
            (0.7, None),
            (0.9, None),
            # no IoU requirements
            (None, None),
        )
        self.dboxes = dboxes300_coco()

    def __call__(self, image, target):
        # Ensure always return cropped image
        while True:
            mode = random.choice(self.sample_options)
            if mode is None:  # 不做随机裁剪处理
                return image, target

            htot, wtot = target['height_width']

            min_iou, max_iou = mode
            min_iou = float('-inf') if min_iou is None else min_iou
            max_iou = float('+inf') if max_iou is None else max_iou

            # Implementation use 5 iteration to find possible candidate
            for _ in range(5):
                # 0.3*0.3 approx. 0.1
                w = random.uniform(0.3, 1.0)
                h = random.uniform(0.3, 1.0)

                if w/h < 0.5 or w/h > 2:  # 保证宽高比例在0.5-2之间
                    continue

                # left 0 ~ wtot - w, top 0 ~ htot - h
                left = random.uniform(0, 1.0 - w)
                top = random.uniform(0, 1.0 - h)

                right = left + w
                bottom = top + h

                # boxes的坐标是在0-1之间的
                bboxes = target["boxes"]
                ious = calc_iou_tensor(bboxes, torch.tensor([[left, top, right, bottom]]))

                # tailor all the bboxes and return
                # all(): Returns True if all elements in the tensor are True, False otherwise.
                if not ((ious > min_iou) & (ious < max_iou)).all():
                    continue

                # discard any bboxes whose center not in the cropped image
                xc = 0.5 * (bboxes[:, 0] + bboxes[:, 2])
                yc = 0.5 * (bboxes[:, 1] + bboxes[:, 3])

                # 查找所有的gt box的中心点有没有在采样patch中的
                masks = (xc > left) & (xc < right) & (yc > top) & (yc < bottom)

                # if no such boxes, continue searching again
                # 如果所有的gt box的中心点都不在采样的patch中,则重新找
                if not masks.any():
                    continue

                # 修改采样patch中的所有gt box的坐标(防止出现越界的情况)
                bboxes[bboxes[:, 0] < left, 0] = left
                bboxes[bboxes[:, 1] < top, 1] = top
                bboxes[bboxes[:, 2] > right, 2] = right
                bboxes[bboxes[:, 3] > bottom, 3] = bottom

                # 虑除不在采样patch中的gt box
                bboxes = bboxes[masks, :]
                # 获取在采样patch中的gt box的标签
                labels = target['labels']
                labels = labels[masks]

                # 裁剪patch
                left_idx = int(left * wtot)
                top_idx = int(top * htot)
                right_idx = int(right * wtot)
                bottom_idx = int(bottom * htot)
                image = image.crop((left_idx, top_idx, right_idx, bottom_idx))

                # 调整裁剪后的bboxes坐标信息
                bboxes[:, 0] = (bboxes[:, 0] - left) / w
                bboxes[:, 1] = (bboxes[:, 1] - top) / h
                bboxes[:, 2] = (bboxes[:, 2] - left) / w
                bboxes[:, 3] = (bboxes[:, 3] - top) / h

                # 更新crop后的gt box坐标信息以及标签信息
                target['boxes'] = bboxes
                target['labels'] = labels

                return image, target

3、Resize

因为 target 中的 gt box 的坐标已经被处理为了 在图像中的比例坐标,所以 Resize 中不用对 target 做处理。

class Resize(object):
    """对图像进行resize处理,该方法应放在ToTensor前"""
    def __init__(self, size=(300, 300)):
        self.resize = t.Resize(size)

    def __call__(self, image, target):
        image = self.resize(image)
        return image, target

4、ColorJitter

class ColorJitter(object):
    """对图像颜色信息进行随机调整,该方法应放在ToTensor前"""
    def __init__(self, brightness=0.125, contrast=0.5, saturation=0.5, hue=0.05):
        self.trans = t.ColorJitter(brightness, contrast, saturation, hue)

    def __call__(self, image, target):
        image = self.trans(image)
        return image, target

5、ToTensor

做了如下 3 个事情:

  • 将 nump.ndarray 或 PIL.Image 转为 tensor,数据类型为 torch.FloatTensor
  • 把灰度范围从0-255 变换到 0-1之间,其将每一个像素值归一化到 [0,1],其归一化方法比较简单,直接除以255即可
  • 将shape 由 (H,W, C) 转为shape为 (C, H, W)
class ToTensor(object):
    """将PIL图像转为Tensor"""
    def __call__(self, image, target):
        image = F.to_tensor(image).contiguous()
        return image, target

6、RandomHorizontalFlip

最重要的就是这里,将 gt box 一并做了翻转

class RandomHorizontalFlip(object):
    """随机水平翻转图像以及bboxes,该方法应放在ToTensor后"""
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            # height, width = image.shape[-2:]
            image = image.flip(-1)  # 水平翻转图片
            bbox = target["boxes"]
            # bbox: xmin, ymin, xmax, ymax
            # bbox[:, [0, 2]] = width - bbox[:, [2, 0]]  # 翻转对应bbox坐标信息
            bbox[:, [0, 2]] = 1.0 - bbox[:, [2, 0]]  # 翻转对应bbox坐标信息
            target["boxes"] = bbox
        return image, target

7、Normalization

为什么不在自己的数据集上计算均值和方差,而是简单的使用 ImageNet 数据集的均值和方差呢?

(很多地方都是这么直接使用的)我理解的是 ImageNet 是一个超大型数据集,在其上计算得出的均值和方差,应该就是绝大部分图像所服从的分布了,是满足需求的,而且自己计算自己数据集的均值和方差的话,耗时耗资源。

class Normalization(object):
    """对图像标准化处理,该方法应放在ToTensor后"""
    def __init__(self, mean=None, std=None):
        if mean is None:
            mean = [0.485, 0.456, 0.406]
        if std is None:
            std = [0.229, 0.224, 0.225]
        self.normalize = t.Normalize(mean=mean, std=std)

    def __call__(self, image, target):
        image = self.normalize(image)
        return image, target

8、AssignGTtoDefaultBox

这里的作用是生成 default box ,我们令起一片文章细说。

class AssignGTtoDefaultBox(object):
    """将DefaultBox与GT进行匹配"""
    def __init__(self):
        self.default_box = dboxes300_coco()
        self.encoder = Encoder(self.default_box)

    def __call__(self, image, target):
        boxes = target['boxes']
        labels = target["labels"]
        # bboxes_out (Tensor 8732 x 4), labels_out (Tensor 8732)
        bboxes_out, labels_out = self.encoder.encode(boxes, labels)
        target['boxes'] = bboxes_out
        target['labels'] = labels_out

        return image, target

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/337105.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue2.x中使用vuex

Vuex是什么&#xff1f; Vuex是一个专门为Vue.js应用程序开发的状态管理模式。它采用集中式存储来管理应用程序中所有组件的状态&#xff0c;并以相应的规则保证状态以一种可预测的方式发生变化。Vuex也被集成到了Vue的官方调试工具vue-devtools中&#xff0c;提供了诸如零配置…

Java:使用Java功能确保应用程序安全的方法

与添加新功能一样重要的是&#xff0c;应用程序开发人员需要开始更加重视他们设计的应用程序的安全性。毕竟&#xff0c;更多的应用程序功能意味着更多的数据驻留在应用程序中。如果没有适当的安全控制&#xff0c;这些数据很容易被入侵者窃取。Java是目前世界上最安全、最流行…

如何去阅读源码,我总结了18条心法

在聊如何去阅读源码之前&#xff0c;先来简单说一下为什么要去阅读源码&#xff0c;大致可分为以下几点原因&#xff1a;最直接的原因&#xff0c;就是面试需要&#xff0c;面试喜欢问源码&#xff0c;读完源码才可以跟面试官battle提升自己的编程水平&#xff0c;学习编程思想…

不要慌,我们谈一谈如何用好 ChatGPT

别人贪婪时我恐惧&#xff0c;别人恐惧时我贪婪。 ——巴菲特 ChatGPT 火了&#xff0c;技术领域的社交媒体、自媒体几乎被 ChatGPT 刷屏&#xff0c;这些内容当中最让人惶恐不安的是我们是否会被 AI 取代之类的文章。 比如以下几个文章标题&#xff1a; 《ChatGPT可能马上…

Transformer结构解读

咱们还是照图讨论&#xff0c;transformer结构图如下&#xff0c;本文主要讨论Encoder部分&#xff1a;图一一、首先说一下Encoder的输入部分&#xff1a;在NLP领域&#xff0c;个人理解&#xff0c;这个inputs就是我们的句子分词之后的词语&#xff0c;比如“我&#xff0c;喜…

符号让人疯狂

符号让人疯狂 判断背了个LV符号的包就想可能有钱 趣讲大白话&#xff1a;人是通过符号区分生活的 聪明人想想&#xff1a;能超越或摆脱符号依赖吗&#xff1f; *********** 信息社会加速符号的传递和创造 我们已经被各种信息传递的符号淹没 信息符号的筛选成了人的主要工作 再…

GRB非隔离系列宽电压输入负高电压输出 电压控制型

特点● 效率高达70%以上● 1*2英寸标准封装● 单电压负输出● 价格低● 电压控制,输出电压随控制电压变化线性变化● 工作温度: -40℃~85℃● 阻燃封装&#xff0c;满足UL94-V0 要求● 温度特性好● 可直接焊在PCB 上应用GRB 系列模块电源是一种DC-DC升压变换器。该模块电源的输…

十、Linux文件 - fread函数讲解

目录 1.fread函数讲解 2.fread函数实战 1.fread函数讲解 从文件中读入数据到指定的地址中 函数原型&#xff1a; size_t fread(void*buff , size_t size, size_t count , FILE* stream) /* * description :对已打开的流进行数据读取 * param ‐ ptr &#xff1a;指向 数据块的…

好用的电脑备份软件推荐

现在几乎每个人都有一台电脑&#xff0c;上面存储着大量的数据&#xff0c;比如宝贵的照片、视频、工作文档等等。但电脑也随时存在许多威胁&#xff0c;比如病毒、Windows 更新错误、死机黑屏、驱动程序问题、系统崩溃等。为防止任何数据丢失&#xff0c;你需要一个专业的电脑…

Oracle数据库故障处理-单块读hang存储异常导致hang死,数据库大量的db file seq read等待(p1 p2无反映)

1 故障描述 2023年1月27日下午接到业务反馈数据库存在大量的锁表阻塞信息&#xff0c;并且业务的页面以及数据库的一些查询均处于阻塞状态&#xff0c;简单的查询sql也需要查询很长时间且未返回结果,数据库hang状态。 问题现象2 1 数据库进程无法杀除。 2 操作系统进程使用…

也许你应该学学 postman了

使用 最简单的方法就是直接在浏览器中复制 Copy as cURL &#xff0c;然后把数据导入 postman&#xff0c;然后 send &#xff0c;收工。 我们这里拿 知乎首页 举例 在对应的请求下复制 cURL 打开 postman &#xff0c; 点击左上角的 Import &#xff0c; 选择Paste Raw Tex…

如何使用逻辑分析仪,解析通信数据

如何使用逻辑分析仪&#xff0c;解析通信数据使用工具&#xff1a;逻辑分析仪&#xff08;几十块买的裸板&#xff09;&#xff0c;软件是&#xff1a;PulseView一、在开发或者移植某一个模块时&#xff0c;你可能遇到这样的问题&#xff1a;二、逻辑分析仪的使用使用工具&…

二级C语言操作例题(四十)

一、程序填空题 在此程序中&#xff0c;函数fun的功能是&#xff1a;在形参s所指字符串中寻找与参数c相同的字符&#xff0c;并在其后插入一个与之相同的字符&#xff0c;若找不到相同的字符则不做任何处理。 例如&#xff0c;若s所指字符串”baacda”&#xff0c;中c的字符为…

JavaWeb-JavaScript API

目录DOM获取元素事务操作操作元素获取/修改元素属性获取/修改表单元素属性实现一个全选效果&#xff0c;主要是操作input的checked属性获取/修改元素样式点击放大字体夜间模式(关灯开灯)操作节点新增节点删除节点案例-猜数字案例-表白墙DOM DOM 全称为 Document Object Model.…

【Spring6源码・MVC】请求处理流程源码解析

上一篇《【Spring6源码・MVC】初始化registry&#xff0c;完成url和controller的映射关系》我们知道&#xff0c;在IOC容器加载的同时&#xff0c;初始化了registry这个HashMap&#xff0c;这个HashMap中存放了请求路径和对应的方法。当我们请求进来&#xff0c;会通过这个regi…

合并两个有序链表-力扣21-java双百方案

一、题目描述将两个升序链表合并为一个新的 升序 链表并返回。新链表是通过拼接给定的两个链表的所有节点组成的。 示例 1&#xff1a;输入&#xff1a;l1 [1,2,4], l2 [1,3,4]输出&#xff1a;[1,1,2,3,4,4]示例 2&#xff1a;输入&#xff1a;l1 [], l2 []输出&#xff1…

C++中编译静态库与动态库

1.库的理解库就是写好的现有的&#xff0c;成熟的&#xff0c;可复用的代码。现实中每个程序都要依赖很多基础的底层库&#xff0c;不可能每个人的代码都从零开始&#xff0c;因此库的存在意义非同寻常。本质上来说库是一种可执行代码的二进制形式&#xff0c;是预编译代码的集…

【Vue3】element-plus中el-tree的递归处理赋值回显问题

目录一&#xff1a;先获取所有权限tree二&#xff1a;在获取所有该角色能有的权限tree三&#xff1a;递归处理勾选tree节点由于项目是从0-1开始构建的 rbac都需要重新构建对接 所以涉及到了权限管理和菜单管理 一级菜单包含多个二级菜单 若二级不全选&#xff0c;则一级显示 半…

scipy超几何函数

文章目录hyp2f1广义超几何函数其他超几何函数hyp2f1 当c不是0,−1,⋯0,-1,\cdots0,−1,⋯时&#xff0c;对于∣z∣<1|z|<1∣z∣<1&#xff0c;超几何函数可表示为 2F1(a;b;c;z)∑n0∞a(n)b(n)c(n)znn!_2F_1(a;b;c;z)\sum^\infty_{n0}\frac{a^{(n)}b^{(n)}}{c^{(n)}}\…

TOOM告诉你企业舆情监测的重要性,企业舆情监测的意义

企业舆情监测是一种有效的企业管理手段&#xff0c;能够帮助企业了解舆情信息&#xff0c;从而更好地管理企业、保护企业利益&#xff0c;TOOM告诉你企业舆情监测的重要性&#xff0c;企业舆情监测的意义。 一、企业舆情监测的重要性 声誉管理&#xff1a;通过对企业在线和离…