训练DeeplabV3+来分割车道线

news2025/1/16 20:52:43

本例我们训练DeepLabV3+语义分割模型来分割车道线。

ead64233424728797832a74e8d4cbab4.png

DeepLabV3+模型的原理有以下一些要点:

1,采用Encoder-Decoder架构。

2,Encoder使用类似Xception的结构作为backbone。

3,Encoder还使用ASPP(Atrous Spatial Pyramid Pooling),即空洞卷积空间金字塔池化,来实现不同尺度的特征融合,ASPP由4个不同rate的空洞卷积和一个全局池化组成。

4,Decoder再次使用跨层级的concat操作进行高低层次的特征融合。

#!pip install segmentation_models_pytorch
#!pip install albumentations
import torchkeras 

from argparse import Namespace

config = Namespace(
    img_size = 128, 
    lr = 1e-4,
    batch_size = 4,
)

一,准备数据

公众号算法美食屋后台回复关键词:torchkeras,获取本文notebook代码和车道线数据集下载链接。

from pathlib import Path
from PIL import Image
import numpy as np 
import torch 
from torch import nn 
from torch.utils.data import Dataset,DataLoader 
import os 
from torchkeras.data import resize_and_pad_image 
from torchkeras.plots import joint_imgs_col 

class MyDataset(Dataset):
    def __init__(self, img_files, img_size, transforms = None):
        self.__dict__.update(locals())
        
    def __len__(self) -> int:
        return len(self.img_files)

    def get(self, index):
        img_path = self.img_files[index]
        mask_path = img_path.replace('images','masks').replace('.jpg','.png')
        image = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')
        return image, mask
    
    def __getitem__(self, index):
        
        image,mask = self.get(index)
        
        image = resize_and_pad_image(image,self.img_size,self.img_size)
        mask = resize_and_pad_image(mask,self.img_size,self.img_size)
        
        image_arr = np.array(image, dtype=np.float32)/255.0
        
        mask_arr = np.array(mask,dtype=np.float32)
        mask_arr = np.where(mask_arr>100.0,1.0,0.0).astype(np.int64)
        

        sample = {
            "image": image_arr,
            "mask": mask_arr
        }
        
        if self.transforms is not None:
            sample = self.transforms(**sample)
            
        sample['mask'] = sample['mask'][None,...]

            
        return sample
    
    def show_sample(self, index):
        image, mask = self.get(index)
        image_result = joint_imgs_col(image,mask)
        return image_result
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

def get_train_transforms():
    return A.Compose(
        [
            A.OneOf([A.HorizontalFlip(p=0.5),A.VerticalFlip(p=0.5)]),
            ToTensorV2(p=1),
        ],
        p=1.0
    )

def get_val_transforms():
    return A.Compose(
        [
            ToTensorV2(p=1),
        ],
        p=1.0
    )
train_transforms=get_train_transforms()
val_transforms=get_val_transforms()

ds_train = MyDataset(train_imgs,img_size=config.img_size,transforms=train_transforms)
ds_val = MyDataset(val_imgs,img_size=config.img_size,transforms=val_transforms)

dl_train = DataLoader(ds_train,batch_size=config.batch_size)
dl_val = DataLoader(ds_val,batch_size=config.batch_size)
ds_train.show_sample(10)

13afdd40413852e18552320016680739.png

二,定义模型

import torch 

num_classes = 1
net = smp.DeepLabV3Plus(
    encoder_name="mobilenet_v2", # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights='imagenet',     # use `imagenet` pretrained weights for encoder initialization
    in_channels=3,                  # model input channels (1 for grayscale images, 3 for RGB, etc.)
    classes=num_classes,            # model output channels (number of classes in your dataset)
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

三,训练模型

下面使用我们的梦中情炉~torchkeras~来实现最优雅的训练循环。😋😋

from torchkeras import KerasModel 
from torch.nn import functional as F 

# 由于输入数据batch结构差异,需要重写StepRunner并覆盖
class StepRunner:
    def __init__(self, net, loss_fn, accelerator, stage = "train", metrics_dict = None, 
                 optimizer = None, lr_scheduler = None
                 ):
        self.net,self.loss_fn,self.metrics_dict,self.stage = net,loss_fn,metrics_dict,stage
        self.optimizer,self.lr_scheduler = optimizer,lr_scheduler
        self.accelerator = accelerator
        
        if self.stage=='train':
            self.net.train() 
        else:
            self.net.eval()
            
    
    def __call__(self, batch):
        features,labels = batch['image'],batch['mask'] 
        
        #loss
        preds = self.net(features)
        loss = self.loss_fn(preds,labels)

        #backward()
        if self.optimizer is not None and self.stage=="train":
            self.accelerator.backward(loss)
            self.optimizer.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()
            
        all_preds = self.accelerator.gather(preds)
        all_labels = self.accelerator.gather(labels)
        all_loss = self.accelerator.gather(loss).sum()
        
        #losses
        step_losses = {self.stage+"_loss":all_loss.item()}
        
        #metrics
        step_metrics = {self.stage+"_"+name:metric_fn(all_preds, all_labels).item() 
                        for name,metric_fn in self.metrics_dict.items()}
        
        if self.optimizer is not None and self.stage=="train":
            step_metrics['lr'] = self.optimizer.state_dict()['param_groups'][0]['lr']
            
        return step_losses,step_metrics

KerasModel.StepRunner = StepRunner
from torchkeras.metrics import IOU


class DiceLoss(nn.Module):
    def __init__(self,smooth=0.001,num_classes=1,weights = None):
        ...

    def forward(self, logits, targets):
        
        ...
        
    def compute_loss(self,preds,targets):
        ...
    
    
class MixedLoss(nn.Module):
    def __init__(self,bce_ratio=0.5):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()
        self.bce_ratio = bce_ratio
        
    def forward(self,logits,targets):
        bce_loss = self.bce(logits,targets.float())
        dice_loss = self.dice(logits,targets)
        total_loss = bce_loss*self.bce_ratio + dice_loss*(1-self.bce_ratio)
        return total_loss
optimizer = torch.optim.AdamW(net.parameters(), lr=config.lr)


lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer = optimizer,
    T_max=8,
    eta_min=0
)

metrics_dict = {'iou': IOU(num_classes=1)}

model = KerasModel(net,
                   loss_fn=MixedLoss(bce_ratio=0.5),
                   metrics_dict=metrics_dict,
                   optimizer=optimizer,
                   lr_scheduler = lr_scheduler
                  )
from torchkeras.kerascallbacks import WandbCallback

wandb_cb = WandbCallback(project='unet_lane',
                         config=config.__dict__,
                         name=None,
                         save_code=True,
                         save_ckpt=True)

dfhistory=model.fit(train_data=dl_train, 
                    val_data=dl_val, 
                    epochs=100, 
                    ckpt_path='checkpoint.pt',
                    patience=10, 
                    monitor="val_iou",
                    mode="max",
                    mixed_precision='no',
                    callbacks = [wandb_cb],
                    plot = True 
                   )

<<<<<< ⚡️ cuda is used >>>>>>

7bec7d9048d769278fba9fb625ec7365.png

================================================================================2023-05-21 20:45:27
Epoch 1 / 100

100%|████████████████████| 20/20 [00:03<00:00,  6.60it/s, lr=5e-5, train_iou=0.15, train_loss=0.873]
100%|██████████████████████████████████| 5/5 [00:00<00:00,  8.54it/s, val_iou=0.162, val_loss=0.836]
[0;31m<<<<<< reach best val_iou : 0.16249321401119232 >>>>>>[0m

================================================================================2023-05-21 20:45:30
Epoch 2 / 100

100%|███████████████████████| 20/20 [00:02<00:00,  7.24it/s, lr=0, train_iou=0.25, train_loss=0.836]
100%|██████████████████████████████████| 5/5 [00:00<00:00,  8.49it/s, val_iou=0.291, val_loss=0.821]
[0;31m<<<<<< reach best val_iou : 0.2905024290084839 >>>>>>[0m


================================================================================2023-05-21 20:51:06
Epoch 95 / 100

100%|███████████████████| 20/20 [00:02<00:00,  7.21it/s, lr=5e-5, train_iou=0.721, train_loss=0.187]
100%|██████████████████████████████████| 5/5 [00:00<00:00,  8.71it/s, val_iou=0.665, val_loss=0.249]

四,评估模型

metrics_dict = {'iou': IOU(num_classes=1,if_print=True)}

model = KerasModel(net,
                   loss_fn=MixedLoss(bce_ratio=0.5),
                   metrics_dict=metrics_dict,
                   optimizer=optimizer,
                   lr_scheduler = lr_scheduler
                  )
model.evaluate(dl_val)
100%|██████████████████████████████████| 5/5 [00:00<00:00,  8.91it/s, val_iou=0.667, val_loss=0.252]


global correct: 0.9912
IoU: ['0.9911', '0.3422']
mean IoU: 0.6667

五,使用模型

batch = next(iter(dl_val))

with torch.no_grad():
    model.eval()
    logits = model(batch["image"].cuda())
    
pr_masks = logits.sigmoid()
from matplotlib import pyplot as plt 
for image, gt_mask, pr_mask in zip(batch["image"], batch["mask"], pr_masks):
    plt.figure(figsize=(16, 10))

    plt.subplot(1, 3, 1)
    plt.imshow(image.numpy().transpose(1, 2, 0))  # convert CHW -> HWC
    plt.title("Image")
    plt.axis("off")

    plt.subplot(1, 3, 2)
    plt.imshow(gt_mask.numpy().squeeze()) 
    plt.title("Ground truth")
    plt.axis("off")

    plt.subplot(1, 3, 3)
    plt.imshow(pr_mask.cpu().numpy().squeeze()) 
    plt.title("Prediction")
    plt.axis("off")

    plt.show()

543b40a4a1126920496567338f4973ae.png

a360cb249f45048b509963d4a0e085e7.png

d451854564cd1d290049d8aa386f895c.png

a413800b49de3cbe8360f6d89906f0f1.png

六,保存模型

torch.save(model.net.state_dict(),'deeplab_v3_plus.pt')

公众号算法美食屋后台回复关键词:torchkeras,获取本文notebook代码和车道线数据集下载链接。

万水千山总是情,点个赞赞行不行?😋😋

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/575257.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

听听飞桨框架硬核贡献者如何玩转开源!

当仰望星空时&#xff0c;你在想什么&#xff1f;我在想象&#xff0c;未来可能是什么样子。从应用广泛的人工神经网络&#xff0c;到火遍全网的AIGC&#xff0c;创造新宇宙的人&#xff0c;相信永远看不到天花板。 在这些神奇的AI产品背后&#xff0c;有一个了不起的开源项目—…

滴滴时空供需系统的设计和演进

本篇文章分为&#xff1a; 1.背景介绍 2.系统框架的演进 2.1 旧系统框架的不足 2.2 新系统框架的优势 3.系统建设思考 3.1 存储治理 3.2 性能优化 3.3 研发提效&#xff1a;配置化能力升级 3.总结 1. 背景介绍 时空供需系统(SDS, supply and demand system)是为了满足滴滴网约车…

开箱即用的工具函数库xijs更新指南(v1.2.6)

xijs 是一款开箱即用的 js 业务工具库, 聚集于解决业务中遇到的常用函数逻辑问题, 帮助开发者更高效的开展业务开发. 接下来就和大家一起分享一下 v1.2.6 版本的更新内容以及后续的更新方向. 贡献者列表: 1. 计算变量内存calculateMemory 该模块主要由 zhengsixsix 贡献, 我们可…

leetcode练习(汇总插入区间)

文章目录 题目一&#xff1a;汇总区间题目二&#xff1a;插入区间 语言&#xff1a;python 工具&#xff1a;jupyuter 题目一&#xff1a;汇总区间 给定一个 无重复元素 的 有序 整数数组 nums 。 返回 恰好覆盖数组中所有数字 的 最小有序 区间范围列表 。也就是说&#xff0c…

“程序员,致敬!”

手机震动&#xff0c;提醒着我3年前参加研发的应用迎来了一次重大升级。我按下开源社区提供的合并请求按钮&#xff0c;与开源社区的朋友分享我对这个项目的改进。不久&#xff0c;一条消息提醒我合并请求已被其它社区成员审核通过。 这种远程协作、开源分享的方式是如今广泛存…

chatgpt赋能python:Python数值计算指南:为什么它是一种强大的工具

Python数值计算指南&#xff1a;为什么它是一种强大的工具 当谈到数值计算时&#xff0c;许多人所想到的编程语言都是MATLAB和R。然而&#xff0c;Python也在数值计算领域有着强大的地位。Python是一种令人难以置信的通用编程语言&#xff0c;它不仅为数据科学和机器学习提供了…

行人检测重识别yolov5+reid(跑通+界面设计)

行人检测重识别yolov5reid&#xff08;跑通界面设计&#xff09; 参考源代码: github 权重文件&#xff1a; 根据github上面的网盘进行权重下载&#xff1a; 检测&#xff1a;将 ReID_resnet50_ibn_a.pth放在person_search/weights文件下&#xff0c;yolov5s.pt放person_sear…

如何用海外代理辅助对接 ChatGPT

许多朋友问我有没有好用的海外代理。说实话&#xff0c;真的好用的并不多。 最近我了解到了一家还不错的海外代理&#xff0c;叫做 IPIDEA&#xff0c;我已经使用了一段时间了&#xff0c;觉得质量挺不错。 你可能知道&#xff0c;我最近在进行一些 ChatGPT 相关的研究&#xf…

DTW 2023:戴尔发力多云战略与边缘运营

近日&#xff0c;2023戴尔科技全球科技大会&#xff08;Dell Technologies World&#xff0c;简称DTW&#xff09;在美国拉斯维加斯如期而至。 作为戴尔科技集团一年一度的科技盛宴&#xff0c;本届DTW吸引了众多业界人士的关注。而作为本届大会的重头戏&#xff0c;戴尔科技集…

Spark学习笔记

1 spark简介 (1) spark是基于内存计算的分布式并行计算框架&#xff0c;如今已成为apache软件基金会最重要的三大分布式计算系统开源项目之一(Hadoop、Spark、Storm)。 (2) spark组件 (3) spark组件应用场景 Spark Streaming&#xff1a;提供流计算功能 Sparl SQL&#xff1…

Python实现循环的最快方式(for、while等速度对比)

众所周知&#xff0c;Python 不是一种执行效率较高的语言。此外在任何语言中&#xff0c;循环都是一种非常消耗时间的操作。假如任意一种简单的单步操作耗费的时间为 1 个单位&#xff0c;将此操作重复执行上万次&#xff0c;最终耗费的时间也将增长上万倍。 while 和 for 是 …

JavaCV - 图像暗通道去雾

一、效果图 二、实现原理 暗通道先验:首先说在绝大多数非天空的局部区域里,某一些像素总会有至少一个颜色通道具有很低的值,也就是说该区域光强是一个很小的值。所以给暗通道下了个数学定义,对于任何输入的图像J,其暗通道可以用下面的公式来表示:其中JC表示彩色图像每个…

SOFA Weekly|SOFAChannel#33 直播预告、Layotto 社区会议回顾与预告、社区本周贡献

SOFA WEEKLY | 每周精选 筛选每周精华问答&#xff0c;同步开源进展 欢迎留言互动&#xff5e; SOFAStack&#xff08;Scalable Open Financial Architecture Stack&#xff09;是蚂蚁集团自主研发的金融级云原生架构&#xff0c;包含了构建金融级云原生架构所需的各个组件&am…

阿里、腾讯、京东齐降价:云计算迎来新拐点

‍数据智能产业创新服务媒体 ——聚焦数智 改变商业 618源于京东创办日&#xff08;创办于2004年6月18日&#xff09;&#xff0c;发展至今&#xff0c;618已然成为中国两大最火爆的消费节点之一。每年618&#xff0c;京东都会推出覆盖全品类的优惠政策&#xff0c;并以严格的…

chatgpt赋能python:Python捕获所有异常

Python 捕获所有异常 Python是一种易用、高效的编程语言&#xff0c;广泛应用于Web开发、数据科学、人工智能等领域。在Python编程中&#xff0c;异常处理是一项重要的技能&#xff0c;因为程序总会出现各种异常情况&#xff0c;如输入错误、网络错误、程序崩溃等等。Python提…

矿井水除氟——高矿化度矿井水氟化物深度降解的技术方案

高矿化度矿井水是指含有高浓度溶解性矿物质的废水&#xff0c;通常指的是含有高浓度钠、钙、镁、铁、铝、钾等离子的废水。这些离子通常来自于废水所处的环境、工业或生产过程中使用的原材料和化学品。高矿化度的废水通常具有高盐度、高电导率、高硬度等特征&#xff0c;对环境…

(十七)ArcGIS 属性表生成GUID字段

ArcGIS 属性表生成GUID字段 目录 ArcGIS 属性表生成GUID字段 1.GUID概念2.GUID格式3. ArcGIS 属性表生成GUID字段3.1新建GUID字段3.2生成GUID字段 1.GUID概念 全局唯一标识符&#xff08;GUID&#xff0c;Globally Unique Identifier&#xff09;是一种由算法生成的二进制长度…

堆的实现+堆的应用(堆排序和Topk)

珍惜当下的一切&#xff0c;相信未来的一切都是美好的。 -- 丹尼尔迪凯托目录 一.堆的概念及结构 二.堆的各种函数的实现 1.结构体的内容 2.堆的初始化 3.堆的插入 4.堆的向上调整法 5.验证堆的向上调整法 6.堆顶的删除 7.堆的向下调整法 8.返回堆…

【Python】使用百度AI能力

知识目录 一、写在前面✨二、百度AI能力介绍三、植物识别四、总结撒花&#x1f60a; 一、写在前面✨ 大家好&#xff01;我是初心&#xff0c;希望我们一路走来能坚守初心&#xff01; 今天跟大家分享的文章是 Python调用百度AI能力进行植物识别。 &#xff0c;希望能帮助到大…

欧盟加密监管法案通过,美国急了?

万众期待的欧盟《加密资产市场监管法案》&#xff08;Markets in Crypto-Assets Regulation&#xff0c;简称MiCA&#xff09;终于在5月16日尘埃落定。 尽管在4月20日&#xff0c;该方案已在欧洲议会全体会议上投票通过&#xff0c;但直到5月16日&#xff0c;包括27个国家的欧盟…