【python-Unet】计算机视觉~舌象舌头图片分割~机器学习(三)

news2025/1/10 2:56:27

返回至系列文章导航博客

1 简介

舌体分割是舌诊检测的基础,唯有做到准确分割舌体才能保证后续训练以及预测的准确性。此部分真正的任务是在用户上传的图像中准确寻找到属于舌头的像素点。舌体分割属于生物医学图像分割领域。分割效果如下:
在这里插入图片描述

2 数据集介绍

舌象数据集包含舌象原图以及分割完成的二元图,共979*2张,示例图片如下:
在这里插入图片描述

数据集+源代码获取途径:
闲鱼链接

【闲鱼】https://m.tb.cn/h.UHsoI2k?tk=UdxzdPyLXyQ CZ3457 「我在闲鱼发布了【舌象数据集,详情见csdn!http://t.csdn.cn】」
点击链接直接打开

3 模型介绍

U-Net是一个优秀的语义分割模型,在中e诊中U-Net共三部分,分别是主干特征提取部分、加强特征提取部分、预测部分。利用主干特征提取部分获得5个初步有效的特征层,之后通过加强特征提取部分对上述获取到的5个有效特征层进行上采样并进行特征融合。最终获得了一个结合所有特征的有效特征层,并利用最终有效特征层对像素点进行预测,找到属于舌体的像素点。具体操作详情如下图所示:
在这里插入图片描述

进行标注后利用PyTorch框架构建U-Net模型抓取舌象图像特征,预测舌象图像标签。为对模型进行评价,在训练中计算每次循环的平均损失率。最终每张图的损失了约为2%左右。具体的平均损失率变化如下图:
在这里插入图片描述
训练共历时4天,共979张标记图像,最终平均预测损失率约为2%。模型预测,即舌体分割的效果非常理想,在此展示当损失率为40%与损失率为2%时的分割结果示例,示例如下图所示:
(1)损失率为40%时分割结果图
在这里插入图片描述
(2)损失率为2%时分割结果图
在这里插入图片描述
根据模型预测结果对属于舌体的像素点进行匹配提取,将不属于舌体的部分以墨绿色进行填充,最终的舌体分割效果图如下:
在这里插入图片描述

4 代码实现细节

4.1 相关文件介绍

在这里插入图片描述
notedata文件夹中有分割标注图片、ordata文件夹中有原始图片、params文件夹中有训练模型文件、result文件夹中有测试样例图片、train_image文件夹中有训练过程图片。

4.2 utils.py

工具类:由于数据集中各个图片的大小是不一样的,为了保障后续工作可以顺利进行,这里应该定义一个工具类将图片可以等比例缩放至256*256(可以改看自己需求)。

from PIL import Image

def keep_image_size_open(path, size=(256, 256)):
    img = Image.open(path)
    temp = max(img.size)
    mask = Image.new('RGB', (temp, temp), (0,0,0))
    mask.paste(img, (0,0))
    mask = mask.resize(size)
    return mask

4.3 data.py

这里主要是将数据集中标签图片与原图进行匹配合并~具体步骤代码注释中有详解!

import os
from torch.utils.data import Dataset
from utils import *
from torchvision import transforms
transform = transforms.Compose([
    transforms.ToTensor()
    ])

class MyDataset(Dataset):
    def __init__(self, path):   #拿到标签文件夹中图片的名字
        self.path = path
        self.name = os.listdir(os.path.join(path, 'notedata'))
        
    def __len__(self):          #计算标签文件中文件名的数量
        return len(self.name)
    
    def __getitem__(self, index):   #将标签文件夹中的文件名在原图文件夹中进行匹配(由于标签是png的格式而原图是jpg所以需要进行一个转化)
        segment_name = self.name[index] #XX.png
        segment_path = os.path.join(self.path, 'notedata', segment_name)
        image_path = os.path.join(self.path, 'ordata', segment_name.replace('png', 'jpg')) #png与jpg进行转化
        
        segment_image = keep_image_size_open(segment_path)  #等比例缩放
        image = keep_image_size_open(image_path)            #等比例缩放
        
        return transform(image), transform(segment_image)

if __name__ == "__main__":
    data = MyDataset("E:/ITEM_TIME/project/UNET/")
    print(data[0][0].shape)
    print(data[0][1].shape)

在这里插入图片描述
可见数据集已经规整!

4.4 net.py

Unet网络的编写!
在这里插入图片描述

from torch import nn
import torch
from torch.nn import functional as F


class Conv_Block(nn.Module):   #卷积
    def __init__(self, in_channel, out_channel):
        super(Conv_Block, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, 1, 1, padding_mode='reflect', 
                      bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(0.3),
            nn.LeakyReLU(),
            nn.Conv2d(out_channel, out_channel, 3, 1, 1, padding_mode='reflect', 
                      bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(0.3),
            nn.LeakyReLU()
            )
        
    def forward(self, x):
        return self.layer(x)
    
    
class DownSample(nn.Module):    #下采样
    def __init__(self, channel):
        super(DownSample, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(channel, channel,3,2,1,padding_mode='reflect',
                      bias=False),
            nn.BatchNorm2d(channel),
            nn.LeakyReLU()
            
            )
        
    def forward(self,x):
        return self.layer(x)
    
    
class UpSample(nn.Module):   #上采样(最邻近插值法)
    def __init__(self, channel):
        super(UpSample, self).__init__()
        self.layer = nn.Conv2d(channel, channel//2,1,1)
        
    def forward(self,x, feature_map):
        up = F.interpolate(x, scale_factor=2, mode='nearest')
        out = self.layer(up)
        return torch.cat((out,feature_map),dim=1)
    
    
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.c1=Conv_Block(3,64)
        self.d1=DownSample(64)
        self.c2=Conv_Block(64, 128)
        self.d2=DownSample(128)
        self.c3=Conv_Block(128,256)
        self.d3=DownSample(256)
        self.c4=Conv_Block(256,512)
        self.d4=DownSample(512)
        self.c5=Conv_Block(512,1024)
        self.u1=UpSample(1024)
        self.c6=Conv_Block(1024,512)
        self.u2=UpSample(512)
        self.c7=Conv_Block(512,256)
        self.u3=UpSample(256)
        self.c8=Conv_Block(256,128)
        self.u4=UpSample(128)
        self.c9=Conv_Block(128,64)
        
        self.out = nn.Conv2d(64,3,3,1,1)
        self.Th = nn.Sigmoid()

       
        
    def forward(self,x):
        R1 = self.c1(x)
        R2 = self.c2(self.d1(R1))
        R3 = self.c3(self.d2(R2))
        R4 = self.c4(self.d3(R3))
        R5 = self.c5(self.d4(R4))
        
        O1 = self.c6(self.u1(R5,R4))
        O2 = self.c7(self.u2(O1,R3))
        O3 = self.c8(self.u3(O2,R2))
        O4 = self.c9(self.u4(O3,R1))
        
        return self.Th(self.out(O4))
    
if __name__ == "__main__":
    x = torch.randn(2, 3, 256, 256)
    net  = UNet()
    print(net(x).shape)
         

在这里插入图片描述
结果匹配说明没问题~

4.5 train.py

训练代码~

from torch import nn
from torch import optim
import torch
from data import *
from net import *
from torchvision.utils import save_image
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
weight_path = 'params/unet.pth'
data_path = 'E:/ITEM_TIME/project/UNET/'
save_path = 'train_image'

if __name__ == "__main__":
    
    dic = []###
    
    data_loader = DataLoader(MyDataset(data_path),batch_size=3,shuffle=True)  #batch_size用3/4都可以看电脑性能
    net = UNet().to(device)
    if os.path.exists(weight_path):
        net.load_state_dict(torch.load(weight_path))
        print('success load weight')
    else:
        print('not success load weight')
        
    opt = optim.Adam(net.parameters())
    loss_fun = nn.BCELoss()
    
    epoch = 1
    while True:
        avg = []###
        for i, (image,segment_image) in enumerate(data_loader):
            image,segment_image = image.to(device),segment_image.to(device)
            
            out_image = net(image)
            train_loss = loss_fun(out_image, segment_image)
            
            opt.zero_grad()
            train_loss.backward()
            opt.step()
            
            if i%5 == 0:
                print('{}-{}-train_loss===>>{}'.format(epoch,i,train_loss.item()))
                
            if i%50 == 0:
                torch.save(net.state_dict(), weight_path)
            #为方便看效果将原图、标签图、训练图进行拼接
            _image = image[0]
            _segment_image = segment_image[0]
            _out_image = out_image[0]
            
            img = torch.stack([_image,_segment_image,_out_image],dim=0)
            save_image(img, f'{save_path}/{i}.jpg')
            
            avg.append(float(train_loss.item()))###
            
        
        
        loss_avg = sum(avg)/len(avg)
        
        dic.append(loss_avg)
        
        epoch += 1
    print(dic)
    

在这里插入图片描述
可见代码成功运行~上面的损失率是在训练4天后的效果,刚开始肯定很大很差,需要有耐心!

4.6 test.py

测试代码,对图片进行智能分割~

from net import *
from utils import keep_image_size_open
import os
import torch
from data import *
from torchvision.utils import save_image
from PIL import Image
import numpy as np

net = UNet().cpu()  #或者放在cuda上

weights = 'params/unet.pth'  #导入网络

if os.path.exists(weights):
    net.load_state_dict(torch.load(weights))
    print('success')
else:
    print('no loading')
    
_input = 'xxxx.jpg'  #导入测试图片

img = keep_image_size_open(_input)


img_data = transform(img)
print(img_data.shape)

img_data = torch.unsqueeze(img_data, dim=0)

print(img_data)
out = net(img_data)

save_image(out, 'result/result.jpg')
save_image(img_data, 'result/orininal.jpg')

print(out)

#E:\ITEM_TIME\UNET\ordata\4292.jpg

img_after = Image.open(r"result\result.jpg")
img_before = Image.open(r"result\orininal.jpg")
#img.show()
img_after_array = np.array(img_after)#把图像转成数组格式img = np.asarray(image)
img_before_array = np.array(img_before)

shape_after = img_after_array.shape
shape_before = img_before_array.shape

print(shape_after,shape_before)

#将分隔好的图片进行对应像素点还原,即将黑白分隔图转化为有颜色的提取图

if shape_after == shape_before:
    height = shape_after[0]
    width = shape_after[1]
    dst = np.zeros((height,width,3))
    for h in range(0,height):
        for w in range (0,width):
            (b1,g1,r1) = img_after_array[h,w]
            (b2,g2,r2) = img_before_array[h,w]
            
            if (b1, g1, r1) <= (90, 90, 90): 
                img_before_array[h, w] = (144,238,144) 
            dst[h,w] = img_before_array[h,w]
    img2 = Image.fromarray(np.uint8(dst))
    img2.save(r"result\blend.png","png")

else:
    print("失败!")

结果展示:
(1)原图(orininal.jpg):
在这里插入图片描述
(2)模型分割图(result.jpg):
在这里插入图片描述
(3)对应像素点还原图(blend.png):就是将(2)中的图白色的部分用原图像素点填充,黑色的部分用绿色填充
在这里插入图片描述

至此,舌体分割完成!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/424336.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一些二叉树相关面试题

文章目录1. 对折2. 判断是否是平衡二叉树3. 判断是否是搜索二叉树4. 二叉树的直径5. 寻找最大二叉搜索树6. 用递归套路判断是否是完全二叉树7. 派对的最大快乐值1. 对折 这个大家可以自己用纸对折一下&#xff0c;我这里就简单的说一下&#xff1a; 这是我们第一次对折的情况。…

使用PHP做个图片防盗链(全网详解)

概念&#xff1a; 防盗链是一种防范网络图片、视频等资源被他人盗链&#xff08;直接在其它网站使用&#xff09;的技术。在网站上添加防盗链功能可以防止其他网站恶意盗取自己网站的图片等内容&#xff0c;减少带宽消耗和保护网站内容安全。通常实现防盗链的方式是在网站服务…

js 使用 Array.from 快速生成0~5,步进值为0.1的数组

一、我们平常用的比较多的方法是for循环生成 let data[] for(let i0;i < 51;i){data.push(i/10) }二、用Array.from生成 先来认识一下我们今天的主角&#xff01;&#xff01;&#xff01; 1、释义 Array.from() 方法从一个类似数组或可迭代对象创建一个新的&#xff0c…

SpringMVC(三):请求流程处理

一、引言&#xff1a; 如下是我画的一个简单的SpringMVC的请求流程图&#xff0c;接下来会通过请求流程图去进行源码分析。 [1 ] 当我们客户端发送请求时&#xff0c;Servlet会进行请求的解析&#xff0c;然后交给DispatcherServlet进行统一分发。[2] DispatcherServlet会根…

北京君正案例:超能面板PRO采用4英寸IPS超清多彩屏,值不值得买?

清晨&#xff0c;窗帘自动拉开&#xff0c;悦耳音乐缓缓响起&#xff0c;面包机、咖啡机自动工作&#xff0c;开启新一天。离家时&#xff0c;一键关掉所有灯光和家电&#xff0c;节能安全&#xff0c;手机上便可查看家里设备状态&#xff0c;不用担心门没锁、灯没关等问题。下…

ClickHouse入门详解

ClickHouse基础部分详解一、ClickHouse简介二、ClickHouse单机版安装2.1、ClickHouse安装前准备环境2.2、ClickHouse单机安装三、ClickHouse数据类型四、ClickHouse的表引擎一、ClickHouse简介 对于其他乱起八糟的简介&#xff0c;我就不写了&#xff0c;只写干货. ClickHouse总…

Lumen6 /laravel 框架路由请求实现token验证

版本 Lumen6.0 中文文档&#xff1a;https://learnku.com/docs/lumen/5.7/cache/2411 实现功能效果 1、使用缓存存储用户token 2、从请求头head 中获取用户token 3、返回指定的认证失败结构体 4、对指定的接口路由做身份验证 第一步&#xff1a;解除注释 注意&#xff1…

QML控件--Container

文章目录一、控件基本信息二、控件说明三、属性成员四、成员函数一、控件基本信息 Import Statement: import QtQuick.Controls 2.14 Since: Qt 5.7 Inherits: Control Inherited By: DialogButtonBox, MenuBar, SwipeView, and TabBar 二、控件说明 Container&#xff08;容…

网络安全之从原理看懂 XSS

01、XSS 的原理和分类 跨站脚本攻击 XSS(Cross Site Scripting)&#xff0c;为了不和层叠样式表(Cascading Style Sheets&#xff0c;CSS)的缩写混淆 故将跨站脚本攻击缩写为 XSS&#xff0c;恶意攻击者往 Web 页面里插入恶意 Script 代码&#xff0c;当用户浏览该页面时&…

【产品设计】删除确认文案,猛男落泪

使用各种系统时&#xff0c;都有各种删除操作&#xff0c;用户在删除时&#xff0c;很少关注文案写了什么&#xff0c;但这个文案往往让产品经理们殚精竭虑。怎么样才能写出合格的删除确认文案呢&#xff1f; 使用各种系统的时候&#xff0c;都有各种删除操作&#xff0c;作为用…

substrate中打印调试信息的多种方式详解

目录1. 获取substrate-node-template代码2. 添加一个用于测试的pallet至依赖到pallets目录3. log方式来输出信息3.1 将log依赖添到cargo.toml文件3.2 log-test/src/lib.rs修改call方法3.3 polkadot.js.调用测试函数do_something_log_test4. printable trait方式来输出信息4.1 首…

在 Rainbond 上使用在线知识库系统zyplayer-doc

zyplayer-doc 是一款适合企业和个人使用的WIKI知识库管理工具&#xff0c;提供在线化的知识库管理功能&#xff0c;专为私有化部署而设计&#xff0c;最大程度上保证企业或个人的数据安全&#xff0c;可以完全以内网的方式来部署使用它。 当然也可以将其作为企业产品的说明文档…

2023“认证杯”数学中国数学建模赛题浅析

2023年认证杯”数学中国数学建模如期开赛&#xff0c;本次比赛与妈杯&#xff0c;泰迪杯时间有点冲突。因此&#xff0c;个人精力有限&#xff0c;有些不可避免地错误欢迎大家指出。为了大家更方便的选题&#xff0c;我将为大家对四道题目进行简要的解析&#xff0c;以方便大家…

4.redis-主从复制

01-主从复制概述 单机redis存在的问题 ①硬盘故障, 导致数据丢失, redis不好用;②内存容量受限制, 单台服务器内存扩展有上限, 内存数据超过该上线, 该如何处理. 解决方案 1个master可以有多个slave, 1个slave只能对应1个master master: 主机, 可以读也可以写, 主要负责写. …

MySQL运维24-SHOW ENGINE INNODB STATUS解析

文章目录1、SHOW ENGINE INNODB STATUS概述2、信号量&#xff08;Semaphores&#xff09;2.1、信号量信息示例2.2、信号量信息说明2.3、知识点&#xff1a;CPU自旋(SPIN)2.3、信号量中的OS WAIT ARRAY INFO3、死锁3.1、死锁信息示例3.2、死锁信息说明4、外键冲突4.1、外键冲突信…

elementUI实现selecttree自定义下拉框树形组件

elementUI有select组件也有tree组件&#xff0c;但是就是没有下拉框和tree组件的结合体&#xff0c;那么这次我们就自定义一个。 效果图 引入组件 <select-tree ref"selectTree" treeChange"treeChangeFun" :dataArray"orgList" :value"…

【网络安全】文件上传漏洞及中国蚁剑安装

文件上传漏洞描述中国蚁剑安装1. 官网下载源码和加载器2.解压至同一目录并3.安装4.可能会出现的错误文件上传过程必要条件代码示例dvwa靶场攻击示例1.书写一句话密码进行上传2. 拼接上传地址3.使用中国蚁剑链接webshell前端js绕过方式服务端校验请求头中content-type黑名单绕过…

《花雕学AI》22:一种让AI模拟虚拟角色方法,足以更多创造力的ChatGPT角色扮演

一、什么是ChatGPT的角色扮演&#xff1f; ChatGPT是一种基于GPT-3模型的人机对话技术&#xff0c;它可以实现自然语言和计算机之间的交互。ChatGPT的角色扮演指的是让模型扮演一个虚构的人物&#xff0c;与用户进行设定好的对话。 例如&#xff0c;您可以让ChatGPT扮演一个关…

一文打通锁升级(偏向锁,轻量级锁,重量级锁)

前置知识&#xff1a;synchronized 在JavaSE1.6以前&#xff0c;synchronized都被称为重量级锁。但是在JavaSE1.6的时候&#xff0c;对synchronized进行了优化&#xff0c;引入了偏向锁和轻量级锁&#xff0c;以及锁的存储结构和升级过程&#xff0c;减少了获取锁和释放锁的性能…

Hbase1.1:Hbase官网、Hbase定义、Habse结构、Hbase依赖框架、Hbase整合框架

这里写自定义目录标题Hbase官网Hbase特点&#xff1a;大Hbase定义Habse结构Hbase依赖框架hadoopHbase整合框架PhoenixHiveHbase官网 Hbase官网地址 HBase是Hadoop database&#xff0c;一个分布式、可扩展的大数据存储。 当您需要对大数据进行随机、实时读/写访问时&#xf…