Pytorch语义分割(1)-----加载数据

news2025/1/19 3:13:33

一、数据标注

(1)使用labelme来进行分割标注,标注完之后会得到一个json,数据格式如下:

二、获取数据信息

        读取json文件来得到标签信息,读取jpg文件获取图像。在语义分割中用到的数据无非就是原始图片(image)和标注后得到的mask图片,所以在读取数据的时候只要返回图片和标签信息就OK 了。

    def get_label(self, img_path, labelme_json_path, img_size):
        # h, w 是图片的宽高
        mask_array = np.zeros([self.num_classes, h, w, 1], dtype=np.uint8)
        with open(labelme_json_path, "r") as f:
            json_data = json.load(f)
        shapes = json_data["shapes"]
        for shape in shapes:
            category = shape["label"]
            # 获取类别的索引
            category_idx = self.category_types.index(category)
            points = shape["points"]
            points_array = np.array(points, dtype=np.int32)
            temp = mask_array[category_idx, ...]   # 获取出mask_array对应的类别层
            # 将标注的坐标点连接成一个区域,并将区域内的值填为255
            mask_array[category_idx, ...] = cv2.fillPoly(temp, [points_array], 255)
            # 可以将每一层的输出来看mask图
        # 交换维度
        mask_array = np.transpose(mask_array, (1, 2, 0, 3)).squeeze(axis=-1)
        # 将mask转为tensor
        mask_tensor = ut.i2t(mask_array, False)
        return mask_tensor

 如果一张图片只有一个类别,那其他类的mask图就是黑的,效果展示如下:

完整的读取数据代码如下:dataloader_labelme.py

import torch
import os
import numpy as np
from torch.utils.data import Dataset
from utils_func import seg_utils as ut
import cv2
from torchvision.transforms.functional import rotate as tensor_rotate
from torchvision.transforms.functional import vflip, hflip
from torchvision.transforms.functional import adjust_brightness
import random
import base64
import json
import os
import os.path as osp

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class train_data(Dataset):
    def __init__(self, image_folder, img_size, category_types):
        self.image_folder = image_folder
        self.img_size = img_size
        self.category_types = category_types
        self.num_classes = len(category_types)
        self.data_list = self.generate_mask()
        # print(self.data_list)

        self.img_and_mask = []
        for idx, data in enumerate(self.data_list):
            img_tensor, mask_tensor = self.get_image_and_label(data[0], data[1], self.img_size)
            self.img_and_mask.append([img_tensor, mask_tensor])

    def __len__(self):
        return len(self.img_and_mask)

    def __getitem__(self, index):
        img_tensor = self.img_and_mask[index][0].to(device)
        mask_tensor = self.img_and_mask[index][1].to(device)

        # 如果有数据增强就在这里处理
        return img_tensor, mask_tensor

    def data_augment(self, img_tensor, mask_tensor, aug_flag):
        if aug_flag[0] == 0:
            angel = random.choice(aug_flag[1])
            img_tensor = tensor_rotate(img_tensor, int(angel))
            mask_tensor = tensor_rotate(mask_tensor, int(angel))
        elif aug_flag[0] == 1:
            factor = aug_flag[1]
            img_tensor = adjust_brightness(img_tensor, factor)
        elif aug_flag[0] == 2:
            flip_type = random.choice(aug_flag[1])
            if flip_type == 1:
                img_tensor = vflip(img_tensor)
                mask_tensor = vflip(mask_tensor)
            else:
                img_tensor = hflip(img_tensor)
                mask_tensor = hflip(mask_tensor)

        return img_tensor, mask_tensor

    def generate_mask(self):
        data_lists = []
        for file_name in os.listdir(self.image_folder):
            if file_name.endswith("json"):
                json_path = os.path.join(self.image_folder, file_name)
                img_path = osp.join(self.image_folder, "%s.jpg" % file_name.split(".")[0])
                data_lists.append([img_path, json_path])
        return data_lists

    def get_image_and_label(self, img_path, labelme_json_path, img_size):
        # print("==================================================")
        # print(img_path)
        # print("==================================================")

        img = ut.p2i(img_path)
        h, w = img.shape[:2]
        img = cv2.resize(img, (img_size[0], img_size[1]))
        # cv2.imwrite(r"C:\Users\HJ\Desktop\test\%s.jpg"%img_path.split(".")[0][-7:], img)
        img_tensor = ut.i2t(img)

        mask_array = np.zeros([self.num_classes, h, w, 1], dtype=np.uint8)
        with open(labelme_json_path, "r") as f:
            json_data = json.load(f)
        shapes = json_data["shapes"]
        for shape in shapes:
            category = shape["label"]
            category_idx = self.category_types.index(category)
            points = shape["points"]
            points_array = np.array(points, dtype=np.int32)
            temp = mask_array[category_idx, ...]
            mask_array[category_idx, ...] = cv2.fillPoly(temp, [points_array], 255)
        mask_array = np.transpose(mask_array, (1, 2, 0, 3)).squeeze(axis=-1)
        mask_array = cv2.resize(mask_array, (self.img_size[0], self.img_size[1])).astype(np.uint8)

        mask_tensor = ut.i2t(mask_array, False)
        return img_tensor, mask_tensor



if __name__ == '__main__':
    img_folder = r"D:\finish_code\SegmentationProject\datasets\data2"
    img_size1 = [256, 512]
    category_types = ["background", "person", "car", "road"]

    t = train_data(img_folder, img_size1, category_types)
    t.__getitem__(1)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1801940.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Redis-sentinel(哨兵模式)的搭建步骤及相关知识

1、什么是redis-sentinel,和redis主从复制相比,它具有什么优势 1.1、redis主从复制 Redis主从复制是一种用于数据冗余和可伸缩性的机制,它将一台Redis服务器的数据复制到其他Redis服务器。在这种模式下,数据会实时地从一个主节点…

SwiftUI六组合复杂用户界面

代码下载 应用的首页是一个纵向滚动的地标类别列表,每一个类别内部是一个横向滑动列表。随后将构建应用的页面导航,这个过程中可以学习到如果组合各种视图,并让它们适配不同的设备尺寸和设备方向。 下载起步项目并跟着本篇教程一步步实践&a…

wx 生命周期

以下内容你不需要立马完全弄明白,不过以后它会有帮助。 下图说明了页面 Page 实例的生命周期。

记录jenkins pipeline ,git+maven+sonarqube+打包镜像上传到阿里云镜像仓库

1、阶段视图: 2、准备工作 所需工具与插件 jdk:可以存在多版本 maven:可以存在多版本 sonar-scanner 凭证令牌 gitlab:credentialsId sonarqube:配置在sonarqube208服务中 3、jenkinsfile pipeline {agent anystages {stage(从…

GSS7000卫星导航模拟器结合RTKLIB 接收NTRIP网络RTCM数据以输出RS232

本文聚焦,使用GSS7000仿真GNSS NTRIP,利用开源工具RTKLIB 作为NTRIP Client 接受GSS7000仿真的RTCM数据, 并通过STRSVR将收到的RTCM数据通过USB-RS232数据线吐出,并转给DUT,让其获得RTK -FIXED 固定解。 废话不多说&a…

微信小程序 导航navigation-bar

属性类型默认值必填说明最低版本titlestring否导航条标题2.9.0loadingbooleanfalse否是否在导航条显示 loading 加载提示2.9.0front-colorstring否导航条前景颜色值,包括按钮、标题、状态栏的颜色,仅支持 #ffffff 和 #0000002.9.0background-colorstring…

如何提高网站收录?

GSI服务就是专门干这个的,这个服务用的是光算科技自己研发的GPC爬虫池系统。这个系统通过建立一个庞大的站群和复杂的链接结构,来吸引谷歌的爬虫。这样一来,你的网站就能更频繁地被谷歌的爬虫访问,从而提高被收录的机会。 说到效…

Python语言读取图像

import cv2 import numpy as np width 640 # 图像宽度height 480 # 图像高度channels 3 # 颜色通道数imgEmpty np.empty((height, width, channels), np.uint8) # 创建空白数组imgBlack np.zeros((height, width, channels), np.uint8) # 创建黑色图像 RGB0imgWhite …

全自动饲料机械成套设备:养殖好帮手

全自动饲料机械成套设备是一套能够自动完成饲料生产全过程的机械设备。从原料的粉碎、混合、制粒,到成品的包装、储存,再到生产过程的监控与管理,全部实现自动化操作。减轻了人工劳动强度,提高了生产效率,同时也保证了…

指针在C/C++中的魔力:一级指针与二级指针

什么是指针? 指针是一个变量,它的值是另一个变量的地址。在C/C中,指针是一个强大的工具,可以让我们直接操作内存地址。指针的主要用途包括动态内存分配、数组和字符串处理、函数参数传递等。 一级指针 一级指针(也称为…

Prometheus+Altermanager实现钉钉告警

PrometheusAltermanager实现钉钉告警 Prometheus和Altermanager的安装这里就不赘述了,我之前的文章有写到 不记得的小伙伴可以去看看Prometheus和Altermanager的安装使用 直接开始上操作 下载钉钉并打开,先创建一个接收告警信息的钉钉群 添加一个自定…

数据结构【二叉树——堆】

二叉树——堆 1.二叉树的概念与性质二叉树的概念特殊的二叉树 2.二叉树的性质3.二叉树的存储结构顺序结构链式结构 4.堆堆的概念堆接口的实现(默认为大堆)堆的结构堆的初始化堆的销毁栈的插入堆的删除取堆顶数据堆的元素个数堆的判空 完整代码Heap.hHeap…

ArcGIS for js 4.x 加载图层

二维&#xff1a; 1、创建vue项目 npm create vitelatest 2、安装ArcGIS JS API依赖包 npm install arcgis/core 3、引入ArcGIS API for JavaScript模块 <script setup> import "arcgis/core/assets/esri/themes/light/main.css"; import Map from arcgis…

计网期末复习指南(五):运输层(可靠传输原理、TCP协议、UDP协议、端口)

前言&#xff1a;本系列文章旨在通过TCP/IP协议簇自下而上的梳理大致的知识点&#xff0c;从计算机网络体系结构出发到应用层&#xff0c;每一个协议层通过一篇文章进行总结&#xff0c;本系列正在持续更新中... 计网期末复习指南&#xff08;一&#xff09;&#xff1a;计算机…

【Go语言精进之路】构建高效Go程序:零值可用、使用复合字面值作为初值构造器

&#x1f525; 个人主页&#xff1a;空白诗 文章目录 引言一、深入理解并利用零值提升代码质量1.1 深入Go类型零值原理1.2 零值可用性的实践与优势1.2.1 切片(Slice)的零值与动态扩展1.2.2 Map的零值与安全访问1.2.3 函数参数与零值 二、使用复合字面值作为初值构造器2.1 结构体…

KT1404A语音芯片USB连电脑,win7正常识别WIN10无法识别USB设备

一、简介 KT1404A语音芯片画的板子&#xff0c;USB连接电脑&#xff0c;win7可以正常识别到U盘&#xff0c;WIN10提示无法识别USB设备&#xff08;获取设备描述符失败&#xff09;&#xff0c;这是什么问题 问题 首先&#xff0c;这款芯片已经出货非常非常多了&#xff0c;所…

【已有项目版】uniapp项目发版pda -- Android Studio

必备资料清单&#xff1a; 构建完成的app项目 在HBuilderX开发的uniapp项目 .keystore文件 文章目录 1. 安装Android Studio&#xff1a;https://developer.android.google.cn/studio?hlzh-cn2. 安装Android 离线SDK&#xff1a;https://nativesupport.dcloud.net.cn/AppDocs…

vs2013 - 打包

文章目录 vs2013 - 打包概述installshield2013limitededitionMicrosoft Visual Studio 2013 Installer Projects选择哪种来打包? 笔记VS2013打包和VS2019打包的区别打包工程选择view打包工程中单击工程名称节点&#xff0c;就可以在属性框中看到要改的属性(e.g. 默认是x86, 要…

全面分析找不到msvcr120.dll,无法继续执行程序问题

在计算机使用过程中&#xff0c;我们可能会遇到一些错误提示&#xff0c;其中“找不到msvcr120.dll”就是常见的一种。那么&#xff0c;找不到msvcr120.dll是什么意思呢&#xff1f; 一&#xff0c;msvcr120.dll文件概述 msvcr120.dll 是 Microsoft Visual C Redistributable …