YoloV8修改分类(Classify)的前处理(记录)

news2024/9/20 5:37:00

修改原因

  • yolo自带的分类前处理对于长方形的数据不够友好,存在特征丢失等问题
  • 修改后虽然解决了这个问题但是局部特征也会丢失因为会下采样程度多于自带的,总之具体哪种好不同数据应该表现不同
  • 我的数据中大量长宽比很大的数据所以尝试修改自带的前处理,以保证理论上的合理性。
修改过程
  1. yolo中自带的分类前处理和检测有一些差异

调试推理代码发现ultralytics/models/yolo/classify/predict.py中对图像进行前处理的操作主要是self.transforms

def preprocess(self, img):
        """Converts input image to model-compatible data type."""
    if not isinstance(img, torch.Tensor):
        is_legacy_transform = any(
            self._legacy_transform_name in str(transform) for transform in self.transforms.transforms
        )
        if is_legacy_transform:  # to handle legacy transforms
            img = torch.stack([self.transforms(im) for im in img], dim=0)
        else:
            # import ipdb;ipdb.set_trace()
            img = torch.stack(
                [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
            )
    img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
    return img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32

通过调试打印self.transforms得到

Compose(
    Resize(size=96, interpolation=bilinear, max_size=None, antialias=True)
    CenterCrop(size=(96, 96))
    ToTensor()
    Normalize(mean=tensor([0., 0., 0.]), std=tensor([1., 1., 1.]))
)

假设我设置的imgsz为96,从这里简单的解读可以理解为先进行resize然后进行中心裁切保证输入尺寸为96x96

具体的查看哪里可以修改前处理,首先发现在ultralytics/engine/predictor.py中
def setup_source(self, source):
 """Sets up source and inference mode."""
  self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2)  # check image size
  # import ipdb; ipdb.set_trace()
  self.transforms = (
      getattr(
          self.model.model,
          "transforms",
          classify_transforms(self.imgsz[0], crop_fraction=self.args.crop_fraction), #dujiang
      )
      if self.args.task == "classify"
      else None
  )

可以发现self.transforms主要调用的是classify_transforms方法

进一步我们在ultralytics/data/augment.py中找到classify_transforms的实现
if scale_size[0] == scale_size[1]:
      # Simple case, use torchvision built-in Resize with the shortest edge mode (scalar size arg)
      tfl = [T.Resize(scale_size[0], interpolation=getattr(T.InterpolationMode, interpolation))]
  else:
      # Resize the shortest edge to matching target dim for non-square target
      tfl = [T.Resize(scale_size)]
  tfl.extend(
      [
          T.CenterCrop(size),
          T.ToTensor(),
          T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
      ]
  )

发现和我们的设想基本一致,查看代码逻辑首先是针对正方形图像会将图像缩放到指定的高度,同时保持长宽比,确保较短的一边正好等于目标尺寸,非正方形图片将短边resize到指定大小,长边此时可能是超出的,所以 T.CenterCrop(size)进行中心裁切确保尺寸是我们指定的

针对上面的分析可能问题就很明显了,如果处理的图像是长宽比非常不均匀的图像,那么中心裁切会导致丢失大量信息,我参考了检测的方法,决定将分类的预处理修改为填充而不是裁切

  • 首先确定思想,我想做的是根据长边resize到指定尺寸并且保证长宽比,短边会不足,刚好与原本的代码逻辑相反
  • 然后短边不足的地方进行填充保证短边也达到指定尺寸(填充yolo好像一般是144,这里我也选择144)
  • 具体实现如下
  1. 添加两个类分别实现resizepadding
class ResizeLongestSide:
    def __init__(self, size, interpolation):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        # 获取图像的当前尺寸
        width, height = img.size
        # 计算缩放比例
        if width > height:
            new_width = self.size
            new_height = int(self.size * height / width)
        else:
            new_height = self.size
            new_width = int(self.size * width / height)
        # 按长边缩放
        return img.resize((new_width, new_height), Image.BILINEAR)

class PadToSquare:
    def __init__(self, size, fill=(114)):
        self.size = size
        self.fill = fill

    def __call__(self, img):
        # 获取当前尺寸
        width, height = img.size
        # 计算需要填充的大小
        delta_w = self.size - width
        delta_h = self.size - height
        padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
        # 填充图像
        return F.pad(img, padding, fill=self.fill, padding_mode='constant')
  1. 调用上面的类进行实现
def classify_transforms(
    size=96,
    mean=DEFAULT_MEAN,
    std=DEFAULT_STD,
    interpolation="BILINEAR",
    crop_fraction: float = DEFAULT_CROP_FRACTION,
    padding_color=(114, 114, 114),  # 默认填充为灰色
):

    import torchvision.transforms as T
    import torch
    from torchvision.transforms import functional as F

    # import ipdb;ipdb.set_trace()
    tfl = [
        # T.ClassifyLetterBox(size),
        ResizeLongestSide(size, interpolation=getattr(T.InterpolationMode, interpolation)),  # 按长边缩放
        PadToSquare(size, fill=padding_color),  # 填充至正方形
        T.ToTensor(),
        T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
    ]

    return T.Compose(tfl)
  1. 想要训练前先确定自己修改是否符合预期进行如下测试
Examples:
    >>> from ultralytics.data.augment import LetterBox, classify_transforms, classify_transforms_with_padding
    >>> from PIL import Image
    >>> transforms = classify_transforms_with_padding(size=96)
    >>> img = Image.open('bus.jpg')  3ch img_rgb = Image.merge('RGB', (img, img, img))
    >>> transformed_img = transforms(img)
    >>>import torchvision.transforms as T
    >>>DEFAULT_MEAN = (0.0, 0.0, 0.0)
    >>>DEFAULT_STD = (1.0, 1.0, 1.0)
    >>>import torch
    >>>def save_transformed_image(transformed_img, save_path="transformed_image.png"):
    # 定义反向变换,将张量转换回 PIL 图像
    unnormalize = T.Normalize(
        mean=[-m / s for m, s in zip(DEFAULT_MEAN, DEFAULT_STD)],
        std=[1 / s for s in DEFAULT_STD]
    )
    img_tensor = unnormalize(transformed_img)
    img_tensor = torch.clamp(img_tensor, 0, 1)
    to_pil = T.ToPILImage()
    img_pil = to_pil(img_tensor)
    img_pil.save(save_path)
    print(f"Image saved at {save_path}")
    >>>save_transformed_image(transformed_img, save_path="transformed_image.png")
  1. 效果图
    请添加图片描述
    请添加图片描述
  2. ok,效果预期一致,接下来可以训练了,之前对于矩形的图像会有裁切现在使用padding解决了。但是具体效果还得看结果。
  3. 补充一下修改一定要把类和方法分开,即不要在方法中定义类,这样会导致训练出错
总结中间遇到问题参考这里解决

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2123350.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

怎么做静态码一物一码?批量制作静态码的简单方法

静态二维码是日常很常见的一种二维码类型,一会用来展示文本或者链接跳转等用途使用,比如在很多的物品包装上,扫描二维码就可以查看物品对应的商品编码,就是静态二维码应用的一种。那么静态二维码批量生成的方法是什么样的呢&#…

干货 | Selenium+chrome自动批量下载地理空间数据云影像

1.背景介绍 1.1地理空间数据云 由中国科学院计算机网络信息中心科学数据中心成立的地理空间数据云平台是常见的下载空间数据的平台之一。其提供了较为完善的公开数据,如LANDSAT系列数据,MODIS的标准产品及其合成产品,DEM数据(SR…

客户需求挖掘的三个步骤

本文将介绍客户需求挖掘的三个关键步骤,帮助企业更好地理解客户,并提供个性化的服务。通过分析客户需求,可以更好地满足客户期望,提升客户满意度和忠诚度。 前言 本文将介绍客户需求挖掘的三个关键步骤,帮助企业更好地…

ZooKeeper--分布式协调服务

文章目录 ZooKeeperzk的由来zk解决了什么问题 ZK工作原理ZK数据模型zk功能1.命名服务2.状态同步3.配置中心4.集群管理 zk部署单机启动zk验证zk zk集群集群角色选举过程1.节点角色状态2.选举ID3.具体过程4.心跳机制5.ZAB协议 ZooKeeper 选举示例1.第一轮投票:2.节点收…

Flutter学习之一搭建开发环境

Flutter学习之一:搭建ununtu系统开发环境 一.背景 随着企业发展跟环境的变化,目前大前端开发越来越火,在国内应该是一个趋势;个人的技术栈主要还是在原生安卓开发上;长江后浪推前浪,如果不及时学习新知识&#xff0c…

中文文本分类详解及与机器学习算法对比

一.文本分类 文本分类旨在对文本集按照一定的分类体系或标准进行自动分类标记,属于一种基于分类体系的自动分类。文本分类最早可以追溯到上世纪50年代,那时主要通过专家定义规则来进行文本分类;80年代出现了利用知识工程建立的专家系统&…

首届云原生编程挑战赛总决赛冠军比赛攻略_greydog.队

关联比赛: 首届云原生编程挑战赛【复赛】实现一个 Serverless 计算服务调度系统 一、初赛赛道一(实现一个分布式统计和过滤的链路追踪) 赛题分析 1、数据来源 采集自分布式系统中的多个节点上的调用链数据,每个节点一份数据文件。数据格式…

系统架构师考试学习笔记第四篇——架构设计实践知识(21)安全架构设计理论与实践

本章考点: 第21课时主要学习信息系统中安全架构设计的理论和工作中的实践。根据考试大纲,本课时知识点会涉及案例分析题和论文题(各占25分),而在历年考试中,综合知识选择题目中也有过诸多考查。本课时内容侧重于知识点记忆;,按照以往的出题规律,安全架构设计基础知识…

工具知识 | Linux常用命令

参考 linw7的github《鸟哥的Linux私房菜》 一.文件管理 1.文件查找:find2.文件拷贝:cp3.打包解包:tar 二.文本处理 1.(显示行号)查看文件:nl2.文本查找:grep3.排序:sort4.转换:tr5.切分文本&…

Web 基础——Apache

Event Worker 的升级版、把服务器进程和连接进行分析,基于异步 I/O 模型。 请求过来后进程并不处理请求,而是直接交由其它机制来处理,通过 epoll 机制来通知请求是否完成; 在这个过程中,进程本身一直处于空闲状态&am…

【目标检测数据集】铁轨表面缺损检测数据集4789张VOC+YOLO格式

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):4789 标注数量(xml文件个数):4789 标注数量(txt文件个数):4789 标注…

只有IP地址没有域名怎么实现HTTPS访问?

🔐 实现IP地址HTTPS访问 🌐 确认公网IP地址 公网IP:确保你拥有一个公网IP地址,或者内网映射公网,这是实现HTTPS访问的前提。 📝 选择证书颁发机构(CA) 选择CA:选择一个…

【Qt】Qt音频

Qt 音频 在 Qt 中,⾳频主要是通过 QSound 类来实现。但是需要注意的是 QSound 类只⽀持播放 wav 格式的⾳频⽂件。也就是说如果想要添加⾳频效果,那么⾸先需要将 ⾮wav格式 的⾳频⽂件转换为 wav 格式。 【注意】使⽤ QSound 类时,需要添加模…

【C#Mutex】 initiallyOwned错误引起的缺陷

临界区只能对同一个进程的不同线程同步,互斥量可以跨进程同步。典型应用场景:两个exe会操作同一个注册表项。 错误代码 封装类 public class CMutexHelp : IDisposable {public CMutexHelp(){s_mutex.WaitOne();} private static Mutex s_mutex …

深度学习-目标检测(二)Fast R-CNN

一:Fast R-CNN Fast R-CNN 是一篇由Ross Girshick 在 2015 年发表的论文,题为 “Fast R-CNN”。这篇论文旨在解决目标检测领域中的一些问题,特别是传统目标检测方法中存在的速度和准确性之间的矛盾。 论文摘要:本文提出了一种基于…

关于tomcat如何设置自启动的设置

希望文章能给到你启发和灵感~ 如果觉得文章对你有帮助的话,点赞 关注 收藏 支持一下博主吧~ 阅读指南 开篇说明一、基础环境说明1.1 硬件环境1.2 软件环境 二、Windows 下的设置服务自启2.1 服务的注册2.2 开启自启 三、MacOS下设置服务自启…

ROS CDK魔法书:建立你的游戏王国(Python篇)

引言 在虚拟游戏的世界里,数字化的乐趣如同流动的音符,谱写着无数玩家的共同回忆。而在这片充满创意与冒险的乐园中,您的使命就是将独特的游戏体验与丰富的技术知识相结合,打造出令人难以忘怀的作品。当面对如何实现这一宏伟蓝图…

【数据结构】4——树和森林

数据结构——4树和森林 笔记 文章目录 数据结构——4树和森林树的存储结构双亲表示法孩子链表孩子兄弟表示法(二叉树表示法、二叉链表表示法) 树与二叉树转换森林和二叉树转化森林转二叉树二叉树转森林 树和森林的遍历树先根后根层次 森林 树的存储结构…

使用nvm工具实现多个nodejs版本的维护和切换

NodeJS的升级比较快,在开发中要使用最新的版本,必须经常升级,但对于一些老项目可能又要使用低版本的NodeJS,虽然可以在系统中同时安装多个NodeJS的版本,然后通过修改环境变量的方式实现切换,但这种方法太麻…

断点回归模型

断点回归(Regression Discontinuity Design, RDD)是一种准实验设计方法,用于评估政策或其他干预措施的效果。这种方法利用了一个清晰的阈值或“断点”,在这个阈值上,处理状态(例如是否接受某种干预&#xf…