生动理解深度学习精度提升利器——测试时增强(TTA)

news2025/2/25 18:33:31

测试时增强(Test-Time Augmentation,TTA)是一种在深度学习模型的测试阶段应用数据增强的技术手段。它是通过对测试样本进行多次随机变换或扰动,产生多个增强的样本,并使用这些样本进行预测的多数投票或平均来得出最终预测结果。

为了直观理解TTA执行的过程,这里我绘制了流程示意图如下所示:

TTA的过程如下:

  1. 数据增强:

    • 在测试时,对每个测试样本应用随机的变换或扰动操作,生成多个增强样本。
    • 常用的数据增强操作包括随机翻转、随机旋转、随机裁剪、随机缩放等。这些操作可以增加样本的多样性,模拟真实世界中的不确定性和变化。
  2. 多次预测:

    • 使用训练好的模型对生成的增强样本进行多次预测。
    • 对于每个增强样本,都会得到一个预测结果。
  3. 预测结果集成:

    • 对多次预测的结果进行集成,常用的集成方式有多数投票和平均。
    • 对于分类任务,多数投票即选择预测结果中出现次数最多的类别作为最终的预测类别。对于回归任务,平均即将多次预测结果进行平均。

接下来针对性地对比分析下使用TTA带来的优点和缺点:

优点:

  • 提高鲁棒性:通过应用数据增强,TTA可以增加样本的多样性和泛化能力,提高模型在面对未见过的输入分布和未知变化时的鲁棒性。
  • 提高准确性:通过多次预测和集成,TTA可以减少预测结果的随机性和偶然误差,提高最终预测结果的稳定性和准确性。
  • 模型评估和排名:TTA可以改变模型预测的不确定性,使得模型评估更可靠,能够更好地对不同模型进行性能排名。

缺点:

  • 计算开销:生成和预测多个增强样本会增加计算量。特别是在大型模型和复杂任务中,可能导致推理时间的显著增加,限制了TTA的实际应用。
  • 可能造成过拟合:对于已包含在训练数据中的变换或扰动,如果在测试时反复应用,可能会导致模型对这些特定样本的过拟合,从而影响模型的泛化能力。

TTA是一种常用的技术手段,通过应用数据增强和集成预测结果,可以提高深度学习模型在测试阶段的性能和鲁棒性。然而,TTA的应用需要平衡计算开销和预测准确性,并谨慎处理可能导致模型过拟合的问题。根据具体任务和需求,可以灵活选择合适的增强操作和集成策略来使用TTA。

下面是demo代码实现,如下所示:

import numpy as np
import torch
import torchvision.transforms as transforms

def test_time_augmentation(model, image, n_augmentations):
    # 定义数据增强的变换
    transform = transforms.Compose([
        transforms.ToTensor(),
        # 在此添加你需要的任何其他数据增强操作
    ])

    # 存储多次预测结果的列表
    predictions = []

    # 对图像应用多次增强和预测
    for _ in range(n_augmentations):
        augmented_image = transform(image)
        augmented_image = augmented_image.unsqueeze(0)  # 增加一个维度作为批次
        with torch.no_grad():
            # 切换模型为评估模式,确保不执行梯度计算
            model.eval()
            # 使用增强的图像进行预测
            output = model(augmented_image)
            _, predicted = torch.max(output.data, 1)
            predictions.append(predicted.item())

    # 执行多数投票并返回最终预测结果
    final_prediction = np.bincount(predictions).argmax()

    return final_prediction

在前文鸟类细粒度识别项目实验中测试发现,应用TTA技术后,对应的评估指标上有明显的涨点,但是很明显地可以发现:在整个测试过程中资源消耗增加明显,且耗时显著增长,这也是TTA无法避免的劣势,在对精度要求较高的场景下可以有限考虑引入TTA,但是对于计算时耗要求较高的场景则不推荐使用TTA。

开源社区里面也有一些优秀的实现,这里推荐一个,地址在这里,如下所示:

目前有将近1k的star量,还是蛮不错的。

安装方法如下所示:

pip安装:
pip install ttach


源码安装:
pip install git+https://github.com/qubvel/ttach
        Input
             |           # input batch of images 
        / / /|\ \ \      # apply augmentations (flips, rotation, scale, etc.)
       | | | | | | |     # pass augmented batches through model
       | | | | | | |     # reverse transformations for each batch of masks/labels
        \ \ \ / / /      # merge predictions (mean, max, gmean, etc.)
             |           # output batch of masks/labels
           Output

目前支持分割、分类、关键点检测三种任务,实例使用如下所示:

Segmentation model wrapping [docstring]:
import ttach as tta
tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode='mean')


Classification model wrapping [docstring]:
tta_model = tta.ClassificationTTAWrapper(model, tta.aliases.five_crop_transform())


Keypoints model wrapping [docstring]:
tta_model = tta.KeypointsTTAWrapper(model, tta.aliases.flip_transform(), scaled=True)
data transforms 实例实现如下所示:
# defined 2 * 2 * 3 * 3 = 36 augmentations !
transforms = tta.Compose(
    [
        tta.HorizontalFlip(),
        tta.Rotate90(angles=[0, 180]),
        tta.Scale(scales=[1, 2, 4]),
        tta.Multiply(factors=[0.9, 1, 1.1]),        
    ]
)

tta_model = tta.SegmentationTTAWrapper(model, transforms)

Custom model (multi-input / multi-output)实现如下所示:

# Example how to process ONE batch on images with TTA
# Here `image`/`mask` are 4D tensors (B, C, H, W), `label` is 2D tensor (B, N)

for transformer in transforms: # custom transforms or e.g. tta.aliases.d4_transform() 
    
    # augment image
    augmented_image = transformer.augment_image(image)
    
    # pass to model
    model_output = model(augmented_image, another_input_data)
    
    # reverse augmentation for mask and label
    deaug_mask = transformer.deaugment_mask(model_output['mask'])
    deaug_label = transformer.deaugment_label(model_output['label'])
    
    # save results
    labels.append(deaug_mask)
    masks.append(deaug_label)
    
# reduce results as you want, e.g mean/max/min
label = mean(labels)
mask = mean(masks)

Transforms详情如下所示:

TransformParametersValues
HorizontalFlip--
VerticalFlip--
Rotate90anglesList[0, 90, 180, 270]
Scalescales
interpolation
List[float]
"nearest"/"linear"
Resizesizes
original_size
interpolation
List[Tuple[int, int]]
Tuple[int,int]
"nearest"/"linear"
AddvaluesList[float]
MultiplyfactorsList[float]
FiveCropscrop_height
crop_width
int
int

支持的结果融合方法如下:

mean
gmean (geometric mean)
sum
max
min
tsharpen (temperature sharpen with t=0.5)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/990840.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

OpenCV(二十九):图像腐蚀

1.图像腐蚀原理 腐蚀操作的原理是将一个结构元素(也称为核或模板)在图像上滑动,并将其与图像中对应位置的像素进行比较。如果结构元素的所有像素与图像中对应位置的像素都匹配,那么该位置的像素值保持不变。如果结构元素的任何一个…

【软考】系统集成项目管理工程师(三)信息系统集成专业技术知识③

一、云计算 1、定义 通过互联网来提供大型计算能力和动态易扩展的虚拟化资源;云是网络、互联网的一种比喻说法。是一种大集中的服务模式。 2、特点 (1)超大规模(2)虚拟化(3)高可扩展性&…

Unity UGUI(二)核心组件

Unity Canvas相关知识学习 文章目录 Unity Canvas相关知识学习1. Canvas:1.1 Render Mode1.2 多个Canvas的显示顺序 2.Canvas Scaler:屏幕分辨率自适应2.1 UI Scale Mode 3. EventSystem4. Standalone Input Module5. Graphic Raycaster:图形…

创邻科技图数据库课程走进一流高校

《图数据库原理和实践》 正式开课! 最近,浙江大学计算机学院新开了一门名为 《图数据库原理和实践》 的新课程,该课程由创邻科技和浙江大学联合推出,吸引了许多学生踊跃参与! 曾为浙大学子的创邻科技CTO周研博士作为…

「Java开发指南」在MyEclipse中的Spring开发(一)

MyEclipse v2023.1.2离线版下载(Q技术交流:742336981) 1. 什么是Spring? 在MyEclipse中引入Spring比大多数框架更难,因为它不是一种单一用途的技术。Spring被认为是Java软件开发在几乎每个领域都有最佳实践的巨大框架&#xff0…

canvas绘制渐变色三角形金字塔

项目需求:需要绘制渐变色三角形金字塔,并用折线添加标识 (其实所有直接用图片放上去也行,但是ui没切图,我也懒得找她要,正好也没啥事,直接自己用代码绘制算了,总结一句就是闲的) 最终效果如下图: (以上没用任何图片,都是代码绘制的) 在网上找了,有用canvas绘…

sql server 查询某个字段是否有值 返回bool类型

sql server 查询某个字段是否有值 返回bool类型,true 或 false SELECT ColumnCode,CONVERT(BIT,CASE WHEN LEN(ColumnCode) > 0 THEN 1 ELSE 0 END) AS HasValue FROM dbo.TF_LessonCatalog

Java(一)安装并使用 java(Windows)

安装并使用java 前言一、初识Java1.Java的安装1.1下载JDK1.2JDK安装与使用1.2.1安装1.2.2 IDEA(编译器)使用 2.Java运行编程逻辑(重要后面要用)总结 前言 学习很重要,复习也很重要,对于编程语言的复习更为…

初试小程序轮播组件

文章目录 一、轮播组件(一)swiper组件1、功能描述2、属性说明 (二)swiper-item组件1、功能描述2、属性说明 二、案例演示(一)运行效果(二)实现步骤1、创建小程序项目2、准备图片素材…

Azure + React + ASP.NET Core 项目笔记一:项目环境搭建(二)

有意义的标题 pnpm 安装umi4 脚手架搭建打包语句变更Visual Studio调试Azure 设置变更发布 pnpm 安装 参考官网,或者直接使用npm安装 npm install -g pnpmumi4 脚手架搭建 我这里用的umi4,官网已附上 这里需要把clientapp清空,之后 cd Cl…

构建普适通用的企业网络安全体系框架

在当今数字化时代,网络安全已成为企业保护信息资产和业务运行的重要任务。恶意攻击、数据泄露、网络病毒等威胁不断演进,给企业和个人带来了巨大风险。为了应对这一挑战,许多企业已经采取了一系列网络安全措施,如制定了网络安全政…

Nacos:Spring Cloud Alibaba服务注册与配置中心

Nacos 英文全称为 Dynamic Naming and Configuration Service,是一个由阿里巴巴团队使用 Java 语言开发的开源项目。 Nacos 是一个更易于帮助构建云原生应用的动态服务发现、配置和服务管理平台(参考自 Nacos 官网)。 Nacos 的命名是由 3 部…

系统架构设计师(第二版)学习笔记----嵌入式系统及软件

【原文链接】系统架构设计师(第二版)学习笔记----嵌入式系统及软件 文章目录 一、嵌入式系统1.1 嵌入式系统的组成1.2 嵌入式系统的特点1.3 嵌入式系统的分类 二、嵌入式软件2.1 嵌入式系统软件分层2.2 嵌入式软件的主要特点 三、安全攸关软件的安全性设…

ubuntu20.04 Supervisor 开机自启动脚本一文配置

前言: 最近发现一种非常好的开机启动服务方式,不光可以开机自启动,而且还可以进行开机节点的进程守护,这样大大确保了线程的稳定情况,这种服务甚至可以守护开机的进程,所以比之前设置 rc.local 开机自启动脚本一文配置节点好出很多,它甚至可以使用网页登录监管我开机自启…

RabbitMQ: 死信队列

一、在客户端创建方式 1.创建死信交换机 2.创建类生产者队列 3.创建死信队列 其实就是一个普通的队列,绑定号私信交换机,不给ttl,给上匹配的路由,等待交换机发送消息。 二、springboot实现创建类生产者队列 1.在消费者里的…

OpenText EnCase Endpoint Security 识别潜在的网络安全威胁并快速消灭威胁

如今,敏感数据丢失和 IT 系统中断是各类组织面临的最大危机。网络攻击频率不断攀升、修复成本日益增加以及响应时间延长都加剧了数据丢失的隐患。 OpenText EnCase Endpoint Security 的高效体现在能够加速检测恶意活动,并在其导致不可挽回的损失或丢失敏…

SpringSecurity OAuth2 配置 token有效时长

1.这种方式配置之后,并没有生效 package com.enterprise.auth.config;import com.enterprise.auth.handler.OAuthServerWebResponseExceptionTranslator; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bea…

2023国赛数学建模B题思路代码 - 多波束测线问题

# 1 赛题 B 题 多波束测线问题 单波束测深是利用声波在水中的传播特性来测量水体深度的技术。声波在均匀介质中作匀 速直线传播, 在不同界面上产生反射, 利用这一原理,从测量船换能器垂直向海底发射声波信 号,并记录从声波发射到…

java文件命令行报错: 找不到或无法加载主类XXX报错及解决

前言 之前遇到过几次,后面稀里糊涂的解决了。今天详细记录一下,可能不全或有些错误,还请各位指正。 你要启动一个类的话首先要有类。 在这里,类有两种, 一个是带包名(package)的还有一个是没包…

lvs负载均衡、

四:LVS集群部署 lvs给nginx做负载均衡项目 218lvs yum -y install ipvsadm 设置VIP 定义策略 ipvsadm -C //清空现有规则 -A增加虚拟服务器记录 -D删除虚拟服务器记录 -L查看 150web-111 配置好网站服务器,测试所有RS [nginx-stable] namengin…