Ubuntu20运行SegNeXt代码提取道路水体(四)——成功解决训练与推理自己的数据集iou为0的问题!!

news2024/9/24 3:24:41

在我的这篇博文里
Ubuntu20运行SegNeXt代码提取道路水体(三)——SegNeXt训练与推理自己的数据集
经过一系列配置后
iou算出来是0
经过多次尝试后
终于让我试出来了正确配置方法!

具体的配置细节请查看这篇文章

1、在mmseg/datasets下面对数据集进行初始定义

我新建了一个myroaddata.py文件
 里面的内容是:

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import mmcv
import numpy as np
from PIL import Image

from .builder import DATASETS
from .custom import CustomDataset


@DATASETS.register_module()
class MyRoadData(CustomDataset):
    
    CLASSES = ('background','road')

    PALETTE = [[0,0,0],[255, 255, 255]]

    def __init__(self, **kwargs):
    	super(MyRoadData, self).__init__(img_suffix='_sat.tif', seg_map_suffix='_mask.png', 
                     **kwargs)
    	assert osp.exists(self.img_dir)

2、修改mmseg/datasets/目录下的_init_.py

 把我的自定义数据集加到原_init_.py中

# Copyright (c) OpenMMLab. All rights reserved.
from .ade import ADE20KDataset
from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
from .chase_db1 import ChaseDB1Dataset
from .cityscapes import CityscapesDataset
from .coco_stuff import COCOStuffDataset
from .custom import CustomDataset
from .dark_zurich import DarkZurichDataset
from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset,
                               RepeatDataset)
from .drive import DRIVEDataset
from .hrf import HRFDataset
from .isaid import iSAIDDataset
from .isprs import ISPRSDataset
from .loveda import LoveDADataset
from .night_driving import NightDrivingDataset
from .pascal_context import PascalContextDataset, PascalContextDataset59
from .potsdam import PotsdamDataset
from .stare import STAREDataset
from .voc import PascalVOCDataset
from .myroaddata import MyRoadData

__all__ = [
    'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
    'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset',
    'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
    'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
    'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset',
    'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset',
    'iSAIDDataset', 'ISPRSDataset', 'PotsdamDataset','MyRoadData'
]

3、在configs/base/datasets下面对数据加载进行定义

 我新建了一个myroad.py

里面的内容为

# dataset settings
dataset_type = 'MyRoadData'
data_root = 'data/MyRoadData'
img_norm_cfg = dict(
    mean=[0.5947, 0.5815, 0.5625], std=[0.1173, 0.1169, 0.1157], to_rgb=True)
img_scale = (512, 512)
crop_size = (256, 256)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=img_scale,
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ])
]

data = dict(
    samples_per_gpu=4,
    workers_per_gpu=8,
    train=dict(
        type='RepeatDataset',
        times=40000,
        dataset=dict(
            type=dataset_type,
            data_root=data_root,
            img_dir='images/training',
            ann_dir='annotations/training',
            pipeline=train_pipeline)),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline))


4、在configs/下面选择你需要的模型参数进行修改

在configs/下面选择你需要的模型参数进行修改 以pspnet为例子,在configs/pspnet/下新建一个文件pspnet_r50-d8_512x1024_40k_myroaddata.py

_base_ = [
    '../_base_/models/pspnet_r50-d8.py', '../_base_/datasets/myroad.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py'
]

5、修改configs/base/models/下面的pspnet_r50-d8.py

# model settings
norm_cfg = dict(type='BN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='open-mmlab://resnet50_v1c',
    backbone=dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True),
    decode_head=dict(
        type='PSPHead',
        in_channels=2048,
        in_index=3,
        channels=512,
        pool_scales=(1, 2, 3, 6),
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=1024,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))

6、返回tools/train.py进行训练

python tools/train.py configs/pspnet/pspnet_r50-d8_512x1024_40k_myroaddata.py
就可以跑啦

结果图:

自定义数据集格式配置 

 在data文件夹下新建一个MyRoadData文件夹,存放数据

再次新建俩个文件夹

annotation和images下面新建training和validation文件夹

 annotation-training下放训练标签

annotation-validation放预测标签

同理

images-training放训练原图

images-validation下放预测原图

 

1、图片格式要求为8位深度

注意,如果是24位的图片要全部转成8位!!!!

不然会报错

 

转换代码如下

# -*- coding: utf-8 -*-
"""
Created on Wed Oct 4 16:50:20 2022

@author:Laney_Midory
csdn:Laney_Midory
"""
import cv2
import os

import glob
import shutil

import matplotlib.pyplot as plt
import numpy as np

from PIL import Image

import torch
import torch.nn as nn
import torch.utils.data as data
from torch.autograd import Variable as V

import pickle

from time import time



os.environ["CUDA_VISIBLE_DEVICES"] = '0'  # 指定第一块GPU可用

# config.gpu_options.per_process_gpu_memory_fraction = 0.7  # 程序最多只能占用指定gpu50%的显存,服务器上注释掉这句

Image.MAX_IMAGE_PIXELS = None

tar = "/home/wangtianni/SegNeXt-main/SegNeXt-main/data/data/MyRoadData/annotations/training/"
print('将24位深度转换为8位')
mask_names = filter(lambda x: x.find('png')!=-1, os.listdir(tar))
#trainlist = list(map(lambda x: x[:-8], imagelist))


#new_path = "C:/Users/Administrator/Desktop/white/"  # 目标文件夹


for file in mask_names:

    path = tar + file.strip()
    if not os.path.exists(path):
        continue;  
    img = Image.open(tar+file)#读取系统的内照片 
   
    img2 = img.convert('P')
   # print(train_path+'\\'+base_name[0]+'_mask.png')

    img2.save(path)
   
    #img2.save(new_path +path2 + "_mask.png")
    print("Finish deep change!")
   

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1143533.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用baostock获取上市公司情况

起因是有个不知道什么专业的同学问了我一题 cs: import baostock as bs import pandas as pd import datetime 日线指标参数包括:date,code,open,high,low,close,preclose,volume,amount,adjustflag,turn,tradestatus,pctChg,peTTM,pbMRQ,psTTM,pcfNcfTTM,isST 周…

CompletableFuture的理解

CompletableFuture 是 JDK1.8 里面引入的一个基于事件驱动的异步回调类。简单来说,就是当使用异步线程去执行一个任务的时候,我们希望在任务结束以后触发一个后续的动作。而 CompletableFuture 就可以实现这个功能。 业务问题 举个简单的例子&#xff0…

最长公共子序列(LCS)与最长上升子序列(LIS)问题的相互转换

在此只做直观理解,不做严格证明 参考:LCS 问题与 LIS 问题的相互关系,以及 LIS 问题的最优解证明 LCS转LIS LCS转LIS只能对特殊情况适用。即当LCS中两个数组有一个不存在重复元素的情况下才能进行转换。 我们以一个例子进行说明&#xff0c…

leetCode 76. 最小覆盖子串 + 滑动窗口 + 哈希Hash

我的往期文章:此题的其他解法,感兴趣的话可以移步看一下: leetCode 76. 最小覆盖子串 滑动窗口 图解(详细)-CSDN博客https://blog.csdn.net/weixin_41987016/article/details/134042115?spm1001.2014.3001.5501 力…

XJ+Nreal 高精度地图+Nreal眼镜SDK到发布APK至眼镜中

仅支持Anroid平台 Nreal套装自带的计算单元,其实也是⼀个没有显示器的Android设备 新建unity⼯程,将⼯程切换Android平台。 正在上传…重新上传取消正在上传…重新上传取消 Cloud XDK Unity User Manual for Nreal ARGlasses 该XDK是针对 NReal AR 眼镜…

mysql源码安装

Linux环境 1、mysql下载地址:https://dev.mysql.com/downloads/mysql/5.7.html#downloads 下载参考: 2、把下载的 MySQL 压缩包上传到 Linux 服务器 3、解压mysql-5.7.39-linux-glibc2.12-x86_64.tar.gz tar -zxvf mysql-5.7.39-linux-glibc2.12-x86…

【UE】Rider编辑器错误 .NET SDK 的版本 6.0300 至少需要 MBuild 的 17.00 版本,当前可用的 MSuld 版本

异常:.NET SDK 的版本 6.0300 至少需要 MBuild 的 17.00 版本,当前可用的 MSuld 版本为 16.1.2.50704请在 global.json 中指定的 .NET SDK 更为需要当前可用的 MSBuld 版本的版本 解决 切换当前使用的MBuild版本 File->Settings…打开设置窗口 找到…

基于web和mysql的图书管理系统

系统分为用户端和管理员端 用户端功能如下 登陆注册数据一栏个人信息修改个人信息修改密码图书查询借阅信息借阅状态安全退出 管理员端功能如下 登录个人信息修改个人信息修改密码读者管理书籍管理借阅管理借阅状态安全退出 用户 管理员 源码下载地址 支持:远程…

javaswing/gui+mysql的学生信息管理系统

使用了Java Swing作为前端界面的开发工具,而 MySQL 作为后端数据库管理系统。这个系统主要用于学生信息的管理,包括班级和学生的增删改查操作。 在系统开发过程中,首先设计了数据库表结构,包括班级表和学生表,并定义了…

损失函数总结(九):SoftMarginLoss、MultiLabelSoftMarginLoss

损失函数总结(九):SoftMarginLoss、MultiLabelSoftMarginLoss 1 引言2 损失函数2.1 SoftMarginLoss2.2 MultiLabelSoftMarginLoss 3 总结 1 引言 在前面的文章中已经介绍了介绍了一系列损失函数 (L1Loss、MSELoss、BCELoss、CrossEntropyLos…

NCCL后端

"NCCL" 代表 "NVIDIA Collective Communications Library","NVIDIA 集体通信库",它是一种由 NVIDIA 开发的用于高性能计算的通信库。NCCL 专门设计用于加速 GPU 群集之间的通信,以便在并行计算和深度学习等领域…

智能直播,助力新营销战场 !降本增效,新消费市场唾手可得

在当今竞争激烈的全球商业环境中,企业们迫切需要降低成本、提高效率,物联网(IoT:Internet of Things)的快速崛起为企业提供了全新的增长动力。它直接改变了人们的生活方式,其中最突出的表现就是网购&#x…

每日一题 2558. 从数量最多的堆取走礼物(简单,heapq)

怎么这么多天都是简单题,不多说了 class Solution:def pickGifts(self, gifts: List[int], k: int) -> int:gifts [-gift for gift in gifts]heapify(gifts)for i in range(k):heappush(gifts, -int(sqrt(-heappop(gifts))))return -sum(gifts)

2023MathorCup(妈妈杯) 数学建模挑战赛 解题思路

云顶数模最新解题思路免费分享~~ 2023妈妈杯数学建模A题B题思路,供大家参考~~ A题 B题

ReentrantLock 是如何实现锁公平和非公平性的 ?

公平和非公平 公平,指的是竞争锁资源的线程,严格按照请求顺序来分配锁。非公平,表示竞争锁资源的线程,允许插队来抢占锁资源。ReentrantLock 默认采用了非公平锁的策略来实现锁的竞争逻辑。 ReentrantLock ReentrantLock 内部使…

C程序设计(第五版)谭浩强

目录 目录 第1章程序设计和C语言 ​编辑​编辑​编辑​编辑​编辑​编辑​编辑​编辑​编辑​编辑​编辑​编辑​编辑​编辑第2章算法——程序的灵魂 ​编辑​编辑​编辑​编辑​编辑​编辑​编辑​编辑​编辑​编辑​编辑​编辑​编辑​编辑​编辑​编辑​编辑​编辑​编…

Java面向对象(进阶)-- this关键字的使用

文章目录 一、引子(1) this是什么?(2)什么时候使用this1.实例方法或构造器中使用当前对象的成员2. 同一个类中构造器互相调用 二、探讨(1)问题(2)解决 三、this关键字&am…

sql---慢查询和语句耗时

查看当前会话的所有的sql语句耗时情况 profile 开启 查询指定sql的各个阶段耗时 查看执行计划指令 Explain Explain select * from 表 Index 和 all 属于性能不太好 在不扫描得的情况下才可能为null,index表示使用了索引但是扫描了所有的索引&#xff…

Java 入门指南:使用 Docker 创建容器化 Spring Boot 应用程序

文章目录 步骤 1: 准备工作步骤 2: 克隆 Spring Boot 应用程序步骤 3: 创建 Dockerfile步骤 4: 构建 Docker 映像步骤 5: 运行容器步骤 6: 链接到本地数据库步骤 7: 使用 Docker Compose 运行多个容器步骤 8: 设置 CI/CD 管道结论 🎈个人主页:程序员 小侯…

《SpringBoot项目实战》第三篇—留下用户调用接口的痕迹

系列文章导航 第一篇—接口参数的一些弯弯绕绕 第二篇—接口用户上下文的设计与实现 第三篇—留下用户调用接口的痕迹 第四篇—接口的权限控制 第五篇—接口发生异常如何统一处理 本文参考项目源码地址:summo-springboot-interface-demo 前言 大家好!…