PIDNet(语义分割)排坑

news2024/12/28 5:21:05

PIDNet训练自己的数据集

    • 1. 前言
    • 2. 准备工作
    • 3. 配置环境
    • 4. 排坑过程
        • 4.1.1 configs增加了VOC文件夹 并在里面写了yaml参数文件
        • 4.1.2 加载VOC格式数据集的类
        • 4.1.3 train.py调试

1. 前言

paper小修时reviewer说baseline太老,所以对CVPR2023的PIDNet进行复现,用于下游任务的baseline。所以写一个记录说明完整的过程,同时要说的是:
没有从0开始训练,用的PIDNet作者提供的在ImageNet上训练的预训练权重来继续训练的。数据集我是用的是无人机航拍数据集AeroScapes,VOC格式。解决了两个问题:
[1] 基础模型只支持cityscapes和camvid数据集,现在我支持了VOC格式的分割数据集。
[2] 基础模型是用double GPUs来训练 其中涉及到很多需要修改才能适配单张GPU训练

2. 准备工作

代码下载:去github官网下载,地址:PIDNet-code
论文下载:CVPR可以直接看PDF,地址:PIDNet-paper

3. 配置环境

这里就不一一介绍了 能看到这里 说明大家已经是炼丹老师傅了 一个个看缺啥库就下啥 最好用虚拟环境。

4. 排坑过程

4.1.1 configs增加了VOC文件夹 并在里面写了yaml参数文件

参考了作者的yaml文件写的
【1】更改了DATASET部分 因为我的数据集(aeroscapes)放在了data目录下,训练和验证的的索引也变成了我现在的路径。
【2】 MODEL部分 我增加了作者提供的ImageNet上的预训练权重的路径 pretrained_models/imagenet/PIDNet_S_ImageNet.pth.tar,这个权重是small版本,还有medium和large版本。small版本的权重下载地址是:PIDNet_S_ImageNet.pth.tar,下载后直接放到pretrained_models\imagenet\PIDNet_S_ImageNet.pth.tar路径下即可。
【3】其他超参数 微微动了点
以上修改后的完整代码如下:

# name: pidnet_vai_aero.yaml
CUDNN:
  BENCHMARK: true
  DETERMINISTIC: false
  ENABLED: true
GPUS: 0
OUTPUT_DIR: 'output'
LOG_DIR: 'log'
WORKERS: 3
PRINT_FREQ: 10

DATASET:
  DATASET: voc
  ROOT: 'data/'
  TEST_SET: 'data/aeroscapes/ImageSets/Segmentation/val.txt'
  TRAIN_SET: 'data/aeroscapes/ImageSets/Segmentation/train.txt'
  NUM_CLASSES: 11
MODEL:
  NAME: pidnet_small
  NUM_OUTPUTS: 2
  PRETRAINED: "pretrained_models/imagenet/PIDNet_S_ImageNet.pth.tar"
LOSS:
  USE_OHEM: true
  OHEMTHRES: 0.9
  OHEMKEEP: 131072
  BALANCE_WEIGHTS: [0.4, 1.0]
  SB_WEIGHTS: 1.0
TRAIN:
  IMAGE_SIZE:
  - 960
  - 720
  BASE_SIZE: 960
  BATCH_SIZE_PER_GPU: 6
  SHUFFLE: true
  BEGIN_EPOCH: 0
  END_EPOCH: 200
  RESUME: false
  OPTIMIZER: sgd
  LR: 0.005
  WD: 0.0005
  MOMENTUM: 0.9
  NESTEROV: false
  FLIP: true
  MULTI_SCALE: true
  IGNORE_LABEL: 255
  SCALE_FACTOR: 16
TEST:
  IMAGE_SIZE:
  - 960
  - 720
  BASE_SIZE: 960
  BATCH_SIZE_PER_GPU: 1
  FLIP_TEST: false
  MULTI_SCALE: false
  MODEL_FILE: ''
  OUTPUT_INDEX: 1

4.1.2 加载VOC格式数据集的类

打开datasets/init.py文件 加上

from .voc_dataloader import VOC as VOC

然后创建一个voc_dataloader.py,这个过程中仿照了作者cityscapes.py中的类,实现一样的初始化参数来匹配接口。这个py文件的代码如下:

import os

import cv2
import numpy as np
import torch
from PIL import Image

from .base_dataset import BaseDataset

class VOC(BaseDataset):
    # 数据集类的构造函数
    def __init__(self,
                 root,
                 list_path,
                 num_classes=11,
                 multi_scale=True, 
                 flip=True, 
                 ignore_label=255, 
                 base_size=2048, 
                 crop_size=(512, 1024),
                 scale_factor=16,
                 mean=[0.452, 0.502, 0.434],
                 std=[0.196, 0.161, 0.179],
                 bd_dilate_size=4):
        super(VOC, self).__init__(ignore_label, base_size,
                                  crop_size, scale_factor, mean, std)
        # 用构造函数的参数初始化类的成员变量
        self.root = root
        self.list_path = list_path
        self.num_classes = num_classes
        self.multi_scale = multi_scale
        self.flip = flip
        self.bd_dilate_size = bd_dilate_size
        self.ignore_label = ignore_label
        
        # 读取图像ID列表
        with open(self.list_path, 'r') as f:
            self.image_ids = [line.strip() for line in f.readlines()]

        self.files = []
        for image_id in self.image_ids:
            image_file = os.path.join(self.root, "aeroscapes/JPEGImages", image_id + '.jpg')
            label_file = os.path.join(self.root, "aeroscapes/SegmentationClass", image_id + '.png')
            self.files.append({
                "image": image_file,
                "label": label_file,
                "name": image_id
            })

        self.class_weights = torch.FloatTensor([0.80906685, 1.01004548, 
                                                1.15333424, 1.0154087,  
                                                1.20380376, 1.23027661,
                                                1.11751722, 0.98967911, 
                                                0.88035226, 0.79071721, 
                                                0.79979855]).cuda()

    def __len__(self):
        return len(self.files)
    
    
    def __getitem__(self, index):
        item = self.files[index]
        name = item["name"]
        image = cv2.imread(item["image"], cv2.IMREAD_COLOR)
        size = image.shape

        if 'test' in self.list_path:
            image = self.input_transform(image)
            image = image.transpose((2, 0, 1))

            return image.copy(), np.array(size), name

        label = cv2.imread(item["label"], cv2.IMREAD_GRAYSCALE)


        label[label == 255] = self.ignore_label
        label[label >= self.num_classes] = self.ignore_label
        label[label < 0] = self.ignore_label
        # label = torch.from_numpy(label).long()


        image, label, edge = self.gen_sample(image, label,
                                             self.multi_scale, self.flip, edge_size=self.bd_dilate_size)

        return image.copy(), label.copy(), edge.copy(), np.array(size), name

这里可以看到作者提前计算了数据集的mean,std和class_weights,所以我们也要计算出来并替换进去。根据我提供的代码来进行计算得到结果并替换进去,计算的代码是:

import os
import cv2
import numpy as np
from tqdm import tqdm

def compute_mean_std(image_paths):
    # 初始化
    channel_sum = np.zeros(3)
    channel_squared_sum = np.zeros(3)
    num_pixels = 0

    for img_path in tqdm(image_paths):
        img = cv2.imread(img_path)  # BGR 格式
        img = img / 255.0  # 归一化到 [0, 1]
        h, w, c = img.shape
        num_pixels += h * w

        # 累加每个通道的像素值
        channel_sum += np.sum(np.sum(img, axis=0), axis=0)

        # 累加每个通道的像素值的平方
        channel_squared_sum += np.sum(np.sum(np.square(img), axis=0), axis=0)

    # 计算均值
    mean = channel_sum / num_pixels

    # 计算标准差
    std = np.sqrt(channel_squared_sum / num_pixels - np.square(mean))

    return mean, std


def compute_class_weights(label_paths, num_classes, ignore_label=255):
    class_counts = np.zeros(num_classes)

    for label_path in tqdm(label_paths):
        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        # 忽略被标记为 ignore_label 的像素
        label = label[label != ignore_label]

        # 统计每个类别的像素数量
        for i in range(num_classes):
            class_counts[i] += np.sum(label == i)

    # 确保像素数量不为零,防止取对数和除零错误
    epsilon = 1e-6
    pixel_count = class_counts + epsilon

    # 调用您的函数计算类别权重
    class_weights = get_weight(num_classes, pixel_count)

    return class_weights

def get_weight(class_num, pixel_count):
    W = 1 / np.log(pixel_count)
    W = class_num * W / np.sum(W)
    return W



# 获取数据集中的所有图像路径
image_dir = 'data/aeroscapes/JPEGImages'
image_paths = [os.path.join(image_dir, filename) for filename in os.listdir(image_dir) if filename.endswith('.jpg')]

# 获取数据集中的所有标签路径
label_dir = 'data/aeroscapes/SegmentationClass'
label_paths = [os.path.join(label_dir, filename) for filename in os.listdir(label_dir) if filename.endswith('.png')]

mean, std = compute_mean_std(image_paths)
print("Dataset Mean: ", mean)
print("Dataset Std: ", std)
num_classes = 11  # 根据你的数据集设置
class_weights = compute_class_weights(label_paths, num_classes)
print("Class Weights: ", class_weights)


得到结果:
在这里插入图片描述在这里插入图片描述

4.1.3 train.py调试

tools/train.py中的第一个函数是parse_args(),可以看到我们的参数配置文件现在是新的,所以把路径修改成现在的configs/VOC/pidnet_vai_aero.yaml,具体代码如下:

def parse_args():
    parser = argparse.ArgumentParser(description='Train segmentation network')
    
    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        default="configs/VOC/pidnet_vai_aero.yaml",
                        type=str)
    parser.add_argument('--seed', type=int, default=304)    
    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)

    args = parser.parse_args()
    update_config(config, args)

    return args

调试main函数 给这个parse_args()函数加上断点发现没有问题。step by step执行后发现,代码中有计算当前gpu数量的功能,代码如下:

gpus = list(config.GPUS)
    if torch.cuda.device_count() != len(gpus):
        print("The gpu numbers do not match!")
        return 0

由于我们现在只用一个GPU,所以去刚刚的pidnet_vai_aero.yaml文件中先把GPUS的参数从(0,1)改成0,回到train.py把这个计算gpu数量的代码修改成

gpus = config.GPUS
    if torch.cuda.device_count() != gpus+1:
        print("The gpu numbers do not match!")
        return 0

(tips:当然可以删除这一段代码。另外:由于现在GPUS的值是0,当中的list对于int形会报错 后面也要删除list。len测量这个list的长度也会报错,把后面出现的len()函数都删了 )。
继续执行代码,发现batch_size作者是设置了一颗GPU是多少 乘以 GPU数量 直接删除len(gpus)即可。后面出现len(gpus)也记得删除,不然报错。

# 原本的代码
# batch_size = config.TRAIN.BATCH_SIZE_PER_GPU * len(gpus)
# 现在的代码
batch_size = config.TRAIN.BATCH_SIZE_PER_GPU

继续执行发现刚刚的voc_dataloader没有给scale_factor传入参数 所以直接在voc_dataloader.py给它初始化为16。继续执行发现没有问题,在后面有一段代码

model = FullModel(model, sem_criterion, bd_criterion)
model = nn.DataParallel(model, device_ids=gpus).cuda()
#改成了
model = FullModel(model, sem_criterion, bd_criterion).cuda()

不用并行了 本人就用一张卡跑。测试没问题,继续执行。发现本地numpy版本新一些,继承了BaseDataset类,会出现np.int报错 所以把base_dataset.py中的np.int全部改成了int后错误消失。至此,成功run了。直接python tools/train.py试一试。发现可以训练
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2228986.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Google Recaptcha V2 简单使用

最新的版本是v3&#xff0c;但是一直习惯用v2&#xff0c;就记录一下v2 的简单用法&#xff0c;以免将来忘记了 首先在这里注册你域名&#xff0c;如果是本机可以直接直接填 localhost 或127.0.0.1 https://www.google.com/recaptcha/about/ 这是列子 网站密钥&#xff1a;是…

autMan奥特曼机器人-内置Redis

autMan内置了redis服务&#xff0c;有的脚本运行需要redis支持 几个注意事项&#xff1a; 启用redis服务后要重启autMan生效&#xff0c;关闭一样的道理。启用redis服务后会增加约200M的内存占用多个autMan的redis服务可以组成集群redis服务

五、快速入门K8s之Pod容器的生命周期

一、容器的初始化init ⭐️ init c &#xff1a; init contariner 初始化容器&#xff0c;只是用来初始化&#xff0c;初始化完成就会死亡可以大于的等于一也可以没有&#xff0c;每个init只有在前一个init c执行完成后才可以执行下一个、init容器总是运行到成功完成为止&#…

sqoop问题汇总记录

此篇博客仅记录在使用sqoop时遇到的各种问题。持续更新&#xff0c;有问题评论区一起探讨&#xff0c;写得有不足之处见谅。 Oracle_to_hive 1. main ERROR Could not register mbeans java.security.AccessControlException: access denied ("javax.management.MBeanTr…

C++对象模型:Function 语意学

Member 的各种调用方式 Nonstatic Member Function 使用C时&#xff0c;成员函数和非成员函数在性能上应该是等价的。当设计类时&#xff0c;我们不应该因为担心效率问题而避免使用成员函数。 实现&#xff1a;编译器会将成员函数转换为一个带有额外this指针参数的非成员函数…

二叉树中的深搜 算法专题

二叉树中的深搜 一. 计算布尔二叉树的值 计算布尔二叉树的值 class Solution {public boolean evaluateTree(TreeNode root) {if(root.left null) return root.val 0? false: true;boolean left evaluateTree(root.left);boolean right evaluateTree(root.right);return…

【Linux】环境ChatGLM-4-9B 模型部署

一、模型介绍 GLM-4-9B 是智谱 AI 推出的最新一代预训练模型 GLM-4 系列中的开源版本。 在语义、数学、推理、代码和知识等多方面的数据集测评中&#xff0c; GLM-4-9B 及其人类偏好对齐的版本 GLM-4-9B-Chat 均表现出超越 Llama-3-8B 的卓越性能。除了能进行多轮对话&#xf…

深入理解Java 线程并发编排工具: 概述和应用场景

目录 前言概述1. CountDownLatch2. CyclicBarrier3. Semaphore&#xff08;信号量)4. Condition 案例CountDownLatch-马拉松场景CyclicBarrier-马拉松场景Semaphore-公交车占座场景Condition-线程等待唤醒场景 前言 在 Java 的 java.util.concurrent (JUC) 包中&#xff0c;提…

C++初阶(八)--内存管理

目录 引入&#xff1a; 一、C中的内存布局 1.内存区域 2.示例变量存储位置说明 二、C语言中动态内存管理 三、C内存管理方式 1.new/delete操作内置类型 2.new和delete操作自定义类型 四、operator new与operator delete函数&#xff08;重要点进行讲解&#xff09; …

架构的本质之 MVC 架构

前言 程序员习惯的编程方式就是三步曲。 所以&#xff0c;为了不至于让一个类撑到爆&#x1f4a5;&#xff0c;需要把黄色的对象、绿色的方法、红色的接口&#xff0c;都分配到不同的包结构下。这就是你编码人生中所接触到的第一个解耦操作。 分层框架 MVC 是一种非常常见且常…

Node学习记录-child_process 子进程

来自&#xff1a;https://juejin.cn/post/7277045020422930488 child_process用于处理CPU密集型应用&#xff0c;Nodejs创建子进程有7个API&#xff0c;其中带Async的是同步API,不带的是异步API child_process.exec(command[, options][, callback]) command:要运行的命令&am…

NVR批量管理软件/平台EasyNVR多个NVR同时管理支持对接阿里云、腾讯云、天翼云、亚马逊S3云存储

随着云计算技术的日益成熟&#xff0c;越来越多的企业开始将其业务迁移到云端&#xff0c;以享受更为灵活、高效且经济的服务模式。在视频监控领域&#xff0c;云存储因其强大的数据处理能力和弹性扩展性&#xff0c;成为视频数据存储的理想选择。NVR批量管理软件/平台EasyNVR&…

2024年编程语言排行榜:技术世界的新星与常青树

随着技术的不断进步&#xff0c;编程语言的流行度也在不断变化。今天&#xff0c;就让我们一起来看看2024年的编程语言排行榜&#xff0c;探索哪些语言在技术世界中占据了主导地位。 1. Python&#xff1a;稳居榜首 Python以其在人工智能、数据科学、网络开发等多个领域的广泛…

MFC工控项目实例二十八模拟量信号每秒采集100次

采用两个多媒体定时器&#xff0c;一个0.1秒计时,另一个用来对模拟量信号采集每秒100次.。 1、在SEAL_PRESSUREDlg.h中添加代码 class CSEAL_PRESSUREDlg : public CDialog { public:CSEAL_PRESSUREDlg(CWnd* pParent NULL); // standard constructor&#xff0e;&#xff0e…

基于MoviNet检测视频中危险暴力行为

项目源码获取方式见文章末尾&#xff01; 600多个深度学习项目资料&#xff0c;快来加入社群一起学习吧。 《------往期经典推荐------》 项目名称 1.【Faster & Mask R-CNN模型实现啤酒瓶瑕疵检测】 2.【卫星图像道路检测DeepLabV3Plus模型】 3.【GAN模型实现二次元头像生…

ArcGIS003:ArcMap常用操作0-50例动图演示

摘要&#xff1a;本文以动图形式介绍了ArcMap软件的基本操作&#xff0c;包括快捷方式创建、管理许可服务、操作界面元素&#xff08;如内容列表、目录树、搜索窗口、工具箱、Python窗口、模型构建器窗口等&#xff09;的打开与关闭、位置调整及合并&#xff0c;设置默认工作目…

NVR批量管理软件/平台EasyNVR多个NVR同时管理支持视频投放在电视墙上

在当今智能化、数字化的时代&#xff0c;视频监控已经成为各行各业不可或缺的一部分&#xff0c;无论是公共安全、交通管理、企业监控还是智慧城市建设&#xff0c;都离不开高效、稳定的视频监控系统的支持。而在这些应用场景中&#xff0c;将监控视频实时投放到大屏幕电视墙上…

asp.net core 跨域配置不起作用的原因

1、中间件配置跨域的顺序不对 中间件顺序配置对了基本上就能解决大部分问题中间件顺序配置对了基本上就能解决大部分问题 附上官网简单的启用跨域的代码 var MyAllowSpecificOrigins "_myAllowSpecificOrigins";var builder WebApplication.CreateBuilder(args);…

Linux 命令解释器-shell

概念 shell &#xff1a;壳&#xff0c;命令解释器&#xff0c;负责解析用户输入的命令 分类&#xff1a; 内置命令 (shell 内置 ) &#xff0c; shell 为了完成自我管理和基本的管理&#xff0c;不同的 shell 内置不同的命令&#xff0c;但是大 部分都差不多 外置命令&…

【开源免费】基于SpringBoot+Vue.JS网上超市系统(JAVA毕业设计)

本文项目编号 T 037 &#xff0c;文末自助获取源码 \color{red}{T037&#xff0c;文末自助获取源码} T037&#xff0c;文末自助获取源码 目录 一、系统介绍二、演示录屏三、启动教程四、功能截图五、文案资料5.1 选题背景5.2 国内外研究现状5.3 可行性分析 六、核心代码6.1 查…