pytorch零基础实现语义分割项目(四)——模型训练与预测

news2024/12/24 3:36:51

模型训练与预测

  • 项目列表
  • 前言
  • 损失函数
    • one_hot
    • Dice Loss
    • Focal Loss
  • 模型参数与训练
  • 预测

项目列表

语义分割项目(一)——数据概况及预处理

语义分割项目(二)——标签转换与数据加载

语义分割项目(三)——语义分割模型(U-net和deeplavb3+)

语义分割项目(四)——模型训练与预测


前言

在本系列的前几篇文章中我们介绍了数据与模型,在本篇中我们将数据与模型相结合进行模型训练与预测。


损失函数

在正式构建损失函数之前我们首先介绍一下Dice loss,与其他分类任务不同的是,语义分割不仅要针对单个像素的分类,还包括像素所处位置的回归,对于像素的分类我们可以直接采用交叉熵去尽可能的与标签回归达到分类的效果(在这里我使用的改进后的交叉熵——Focal loss),而对于像素所处位置的损失我们以下面的公式来表示:
D i c e l o s s = 1 − 2 ∗ ∣ l a b e l ∩ t a r g e t ∣ ∣ l a b e l ∣ + ∣ t a r g e t ∣ Dice loss = 1 - \frac{2 *|label \cap target|}{|label|+|target|} Diceloss=1label+target2labeltarget
也就是1减去标签像素位置与预测像素位置的交集的二倍与标签总像素位置之和加上预测像素位置之和。

one_hot

为了求像素位置,我们需要对于标签进行one hot编码,即有像素为1,没有像素为0

def one_hot(target, num_classes=6, device='cuda'):
    b, h, w = target.size()
    hot = torch.zeros((num_classes, b, h, w)).to(device)
    for i in range (num_classes):
        idx = (target==i)
        hot[i, idx] = 1.0
    
    return hot.permute((1, 2, 3, 0))

Dice Loss

def Dice_loss(inputs, target):
    inputs_hot = one_hot(inputs.argmax(dim=1))
    target_hot = one_hot(target)
    inter = (inputs_hot * target_hot).sum(dim=3)
    unin = inputs_hot.sum(dim=3) + target_hot.sum(dim=3)
    scores = 2 * inter / unin
    
    dice_loss = 1 - scores.mean()
    return dice_loss

Focal Loss

def Focal_Loss(inputs, target, alpha=0.5, gamma=2):
    n, c, h, w = inputs.size()
    nt, ht, wt = target.size()
  
    if h != ht and w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)

    temp_inputs = inputs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
    temp_target = target.view(-1)

    logpt  = -nn.CrossEntropyLoss(reduction='none')(temp_inputs, temp_target)
    pt = torch.exp(logpt)
    if alpha is not None:
        logpt *= alpha
    loss = -((1 - pt) ** gamma) * logpt
    loss = loss.mean()
    return loss
def loss(inputs, target):
    return Focal_Loss(inputs, target) + Dice_loss(inputs, target)

模型参数与训练

def train(net, epochs, train_iter, test_iter, device, loss, optimizer):
    print("device in : ", device)
    net = net.to(device)

    for epoch in range(epochs):
        net.train()

        train_loss = 0
        train_num = 0
        with tqdm(range(len(train_iter)), ncols=100, colour='red',
                  desc="train epoch {}/{}".format(epoch + 1, num_epochs)) as pbar:
            for i, (X, y) in enumerate(train_iter):
                optimizer.zero_grad()
                X, y = X.to(device), y.to(device)
                y_hat = net(X)
                l = loss(y_hat, y)
                l.backward()
                optimizer.step()
                train_loss += l.detach()
                train_num += 1
                pbar.set_postfix({'loss': "{:.6f}".format(train_loss / train_num)})
                pbar.update(1)

        net.eval()
        test_loss = 0
        test_num = 0
        with tqdm(range(len(test_iter)), ncols=100, colour='blue',
                  desc="test epoch {}/{}".format(epoch + 1, num_epochs)) as pbar:
            for X, y in test_iter:
                X, y = X.to(device), y.to(device)
                y_hat = net(X)
                with torch.no_grad():
                    l = loss(y_hat, y)
                    test_loss += l.detach()
                    test_num += 1
                    pbar.set_postfix({'loss': "{:.6f}".format(test_loss / test_num)})
                    pbar.update(1)
batch_size  = 2
crop_size   = (600, 600) # 裁剪大小
model_choice = 'U-net'  # 可选U-net、deeplabv3+
in_channels = 3 # 输入图像通道
out_channels = 6 # 输出标签类别
num_epochs = 25 # 总轮次
lr = 5e-6
wd = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
train_iter, test_iter = load_data_voc(batch_size, crop_size, data_dir='dataset')
if model_choice == 'U-net':
    net = U_net()
elif model_choice == 'deeplabv3+':
    net = deeplabv3(3, 6)

if model_choice == 'U-net':
    model_path = os.path.join('model_weights', 'u-net-vgg16.pth')
elif model_choice == 'deeplabv3+':
    model_path = os.path.join('model_weights', 'Semantic-deeplabv3.pth')

net.load_state_dict(torch.load(model_path))

trainer = torch.optim.SGD(net.parameters(), lr=lr, weight_decay=wd)
train(net, num_epochs, train_iter, test_iter, device='cuda', loss=loss, optimizer=trainer)
torch.save(net.state_dict(), model_path)

预测

在预测前我们需要进行一些额外处理,比如将数值标签转换成RGB图像标签,我们在本篇中使用label2image实现

import os
from math import ceil
import torch
import torchvision
from torchvision import io
from utils.dataLoader import load_data_voc
from utils.dataConvert import loadColorMap
from utils.model import U_net, deeplabv3
from torchvision import transforms
import matplotlib.pyplot as plt


def label2image(pred, device):
    VOC_COLORMAP = loadColorMap()
    colormap = torch.tensor(VOC_COLORMAP, device=device)
    X = pred.long()
    return colormap[X, :]


def predict(net, device, img, means, stds):
    trans = torchvision.transforms.Normalize(
        mean=means, std=stds)
    X = trans(img / 255).unsqueeze(0)
    pred = net(X.to(device)).argmax(dim=1)
    return pred.reshape(pred.shape[1], pred.shape[2])


def read_voc_images(data_dir, is_train=True):
    images = []
    labels = []
    if is_train:
        with open(os.path.join(data_dir, 'train.txt')) as f:
            lst = [name.strip() for name in f.readlines()]

    else:
        with open(os.path.join(data_dir, 'test.txt')) as f:
            lst = [name.strip() for name in f.readlines()]

    for name in lst:
        image = io.read_image(os.path.join(data_dir, 'images', '{:03d}.jpg'.format(int(name))))
        label = io.read_image(os.path.join(data_dir, 'labels', '{:03d}.png'.format(int(name))))
        images.append(image)
        labels.append(label)

    return images, labels


def plotPredictAns(imgs):
    length = len(imgs)

    for i, img in enumerate(imgs):
        plt.subplot(ceil(length / 3), 3, i+1)
        plt.imshow(img)
        plt.xticks([])
        plt.yticks([])
        if i == 0:
            plt.title("original images")

        if i == 1:
            plt.title("predict label")

        if i == 2:
            plt.title("true label")

    plt.show()

if __name__ == '__main__':
    voc_dir = './dataset/'
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    means = [0.4813, 0.4844, 0.4919]
    stds = [0.2467, 0.2478, 0.2542]
    test_images, test_labels = read_voc_images(voc_dir, False)
    n = 4
    imgs = []
    batch_size = 2
    crop_size = (600, 600)  # 裁剪大小
    _, test_iter = load_data_voc(batch_size, crop_size, data_dir='dataset')
    model_choice = 'U-net'

    if model_choice == 'U-net':
        net = U_net()
    elif model_choice == 'deeplabv3+':
        net = deeplabv3(3, 6)

    if model_choice == 'U-net':
        model_path = os.path.join('model_weights', 'u-net-vgg16.pth')
    elif model_choice == 'deeplabv3+':
        model_path = os.path.join('model_weights', 'Semantic-deeplabv3.pth')
    net.load_state_dict(torch.load(model_path))
    net = net.to(device)
    for i in range(n):
        crop_rect = (0, 0, 600, 600)
        X = torchvision.transforms.functional.crop(test_images[i], *crop_rect)
        pred = label2image(predict(net, device, X, means, stds), device)
        imgs += [X.permute(1, 2, 0), pred.cpu(),
                 torchvision.transforms.functional.crop(test_labels[i], *crop_rect).permute(1, 2, 0)]

    plotPredictAns(imgs)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/358259.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

winserver服务器硬盘满了怎么清理? 服务器硬盘空间不足清理方法

本文主要介绍我在维护windows server服务器期间总结的一些磁盘清理方式。如对您有所帮助,不甚荣幸。 文章目录一、C盘清理1. System32的日志文件2. IIS的日志文件3. .Net Framework的缓存文件4. 清理其他不必要文件5. 虚拟内存从c盘移到其他硬盘二、其他软件清理1. …

【离散数学】4. 图论

1.数理逻辑 2. 集合论 3. 代数系统 4. 图论 图&#xff1a;点边边与点的映射函数 连通性与判别 欧拉图与哈密尔顿图 二分图和平面图与欧拉公式 树及生成树 单源点最短路径&#xff1a;Dijkstra算法 对偶图 4. 图论 4.1 图的基本概念 4.1.1 图 一个图G是一个三重组 <V(G),E…

【LeetCode】No.232. 用栈实现队列 -- Java Version

题目链接&#xff1a;https://leetcode.cn/problems/implement-queue-using-stacks/ 1. 题目介绍&#xff08;232. 用栈实现队列&#xff09; 请你仅使用两个栈实现先入先出队列。队列应当支持一般队列支持的所有操作&#xff08;push、pop、peek、empty&#xff09;&#xff…

两年外包生涯做完,感觉自己废了一半....

先说一下自己的情况。大专生&#xff0c;17年通过校招进入湖南某软件公司&#xff0c;干了接近2年的点点点&#xff0c;今年年上旬&#xff0c;感觉自己不能够在这样下去了&#xff0c;长时间呆在一个舒适的环境会让一个人堕落&#xff01;而我已经在一个企业干了五年的功能测试…

慕了没?3年经验,3轮技术面+1轮HR面,拿下字节30k*16薪offer

前段时间有个朋友出去面试&#xff0c;这次他面试目标比较清晰&#xff0c;面的都是业务量大、业务比较核心的部门。前前后后去了不少公司&#xff0c;几家大厂里&#xff0c;他说给他印象最深的是字节3轮技术面1轮HR面&#xff0c;他最终拿到了30k*16薪的offer。第一轮主要考察…

MyBatis-Plus详细讲解(整合spring Boot)

哈喽&#xff0c;大家好&#xff0c;今天带大家了解的是MyBatis-Plus&#xff08;简称 MP&#xff09;&#xff0c;是一个 MyBatis 的增强工具&#xff0c;在 MyBatis 的基础上只做增强不做改变&#xff0c;为简化开发、提高效率而生。首先说一下MyBatis-Plus的愿景是什么&…

十五.程序环境和预处理

文章目录一.程序翻译环境和执行环境1.ANSI C 标准2.程序的翻译环境和执行环境二.程序编译和链接1.翻译环境2.编译本身的几个阶段3.运行环境三.预处理1.预定义符号2.#define&#xff08;1&#xff09;#define定义标识符&#xff08;2&#xff09;#define定义宏&#xff08;3&…

【Linux】——基础开发工具和vim编辑器的基本使用方法

目录 Linux 软件包管理器 yum Linux编辑器-vim使用 1.vim的基本概念 2. vim的基本操作 3. vim正常模式命令集 4. vim末行模式命令集 如何配置vim Linux 软件包管理器 yum yum是Linux下的一个下载软件的软件 对于yum&#xff0c;现阶段只需要会使用yum的三板斧就…

【linux】——gcc/g++,make/makefile的简单使用

目录 1.gcc的基本使用 2.Linux下的静态库和动态库的理解 3.Linux项目自动化构建工具——make/makefile 1.gcc的基本使用 gcc是专门用来编译c语言的 g是专门用来编译c的&#xff0c;但是g也能够用来编译c语言 预处理&#xff08;进行宏替换&#xff09; 预处理功能主要包括宏…

Idea无法识别SpringBoot配置文件

SpringBoot的配置文件 application.properties > application.yml > application.yaml 配置文件间的加载优先级 properties&#xff08;最高&#xff09;> yml > yaml&#xff08;最低&#xff09;不同配置文件中相同配置按照加载优先级相互覆盖&#xff0c;不同配…

免费使用通配符域名证书

文章目录前言一、手动安装acme.sh操作1、安装acme.sh2、使用dns api自动续签二、宝塔自动操作【推荐】总结前言 之前个人站点一般都是使用阿里云免费单域名证书&#xff0c;虽然好用但是只有一年有效&#xff0c;到期只能手动重新申请&#xff0c;并且每次弄个子域名出来就要重…

【C++】类和对象练习——日期类的实现

文章目录前言1. 日期的合法性判断2. 日期天数&#xff08;/&#xff09;2.1 和的重载2.2 对于两者复用的讨论3. 前置和后置重载4. 日期-天数&#xff08;-/-&#xff09;5. 前置- -和后置- -的重载6. 日期-日期7. 流插入<<重载8. 流提取>>重载9. 总结10. 源码展示前…

JavaScript - 函数

文章目录一、箭头函数二、函数名三、理解参数3.1 箭头函数中的参数四、没有重载五、默认参数值5.1 默认参数作用域与暂时性死区六、参数扩展与收集6.1 扩展参数6.2 收集参数七、函数声明与函数表达式八、函数作为值九、函数内部9.1 arguments9.2 this9.3 caller9.4 new.target十…

关于机器人状态估计(12)-VIO/VSLAM的稀疏与稠密

VIO三相性与世界观室内ALL IN ONE 首先以此链接先对近期工作的视频做个正经的引流&#xff0c;完成得这么好的效果&#xff0c;仅仅是因为知乎限流1分钟以内的视频&#xff0c;导致整个浏览量不到300&#xff0c;让人非常不爽。 这套系统已经完成了&#xff0c;很快将正式发布…

总是跳转到国内版(cn.bing.com)?New Bing使用全攻略

你是否想要使用强大的&#xff08;被削后大嘘&#xff09;New Bing&#xff1f; 你是否已经获得了New Bing的使用资格&#xff1f; 你是否在访问www.bing.com/new时提示页面不存在&#xff1f; 你是否在访问www.bing.com时总是重定向到cn.bing.com而使用不了New Bing? New Bi…

C++——C++11第二篇

目录 可变参数模板 lambda表达式 lambda表达式语法 捕获列表说明 可变参数模板 可变参数&#xff1a;可以有0到n个参数&#xff0c;如之前学过的 Printf C11的新特性可变参数模板能够让您创建可以接受可变参数的函数模板和类模板 模板参数包 // Args是一个模板参数包&…

Python3 pip

Python3 pip pip 是 Python 包管理工具&#xff0c;该工具提供了对 Python 包的查找、下载、安装、卸载的功能。 软件包也可以在 https://pypi.org/ 中找到。 目前最新的 Python 版本已经预装了 pip。 注意&#xff1a;Python 2.7.9 或 Python 3.4 以上版本都自带 pip 工具…

IM 即时通讯实战:环信Web IM极速集成

前置技能 Node.js 环境已搭建。npm 包管理工具的基本使用。Vue2 或者 Vue3 框架基本掌握或使用。 学习目标 项目中集成 IM 即时通讯实战利用环信 IM Web SDK 快速实现在 Vue.js 中发送出一条 Hello World! 一、了解环信 IM 什么是环信 IM&#xff1f; 环信即时通讯为开发者…

深度学习神经网络基础知识(一) 模型选择、欠拟合和过拟合

专栏&#xff1a;神经网络复现目录 深度学习神经网络基础知识(一) 本文讲述神经网络基础知识&#xff0c;具体细节讲述前向传播&#xff0c;反向传播和计算图&#xff0c;同时讲解神经网络优化方法&#xff1a;权重衰减&#xff0c;Dropout等方法&#xff0c;最后进行Kaggle实…

机器学习算法原理之k近邻 / KNN

文章目录k近邻 / KNN主要思想模型要素距离度量分类决策规则kd树主要思想kd树的构建kd树的搜索总结归纳k近邻 / KNN 主要思想 假定给定一个训练数据集&#xff0c;其中实例标签已定&#xff0c;当输入新的实例时&#xff0c;可以根据其最近的 kkk 个训练实例的标签&#xff0c…