【PyTorch】多对象分割项目

news2024/9/21 11:13:00

对象分割任务的目标是找到图像中目标对象的边界。实际应用例如自动驾驶汽车和医学成像分析。这里将使用PyTorch开发一个深度学习模型来完成多对象分割任务。多对象分割的主要目标是自动勾勒出图像中多个目标对象的边界。

对象的边界通常由与图像大小相同的分割掩码定义,在分割掩码中属于目标对象的所有像素基于预定义的标记被标记为相同。

目录

创建数据集

创建数据加载器

创建模型

部署模型

定义损失函数和优化器

训练和验证模型


创建数据集

from torchvision.datasets import VOCSegmentation
from PIL import Image   
from torchvision.transforms.functional import to_tensor, to_pil_image

class myVOCSegmentation(VOCSegmentation):
    def __getitem__(self, index):
        img = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.masks[index])

        if self.transforms is not None:
            augmented= self.transforms(image=np.array(img), mask=np.array(target))
            img = augmented['image']
            target = augmented['mask']                  
            target[target>20]=0

        img= to_tensor(img)            
        target= torch.from_numpy(target).type(torch.long)
        return img, target

from albumentations import (
    HorizontalFlip,
    Compose,
    Resize,
    Normalize)

mean = [0.485, 0.456, 0.406] 
std = [0.229, 0.224, 0.225]
h,w=520,520

transform_train = Compose([ Resize(h,w),
                HorizontalFlip(p=0.5), 
                Normalize(mean=mean,std=std)])

transform_val = Compose( [ Resize(h,w),
                          Normalize(mean=mean,std=std)])            

path2data="./data/"    
train_ds=myVOCSegmentation(path2data, 
                year='2012', 
                image_set='train', 
                download=False, 
                transforms=transform_train) 
print(len(train_ds)) 


val_ds=myVOCSegmentation(path2data, 
                year='2012', 
                image_set='val', 
                download=False, 
                transforms=transform_val)
print(len(val_ds)) 

import torch
import numpy as np
from skimage.segmentation import mark_boundaries
import matplotlib.pylab as plt
%matplotlib inline
np.random.seed(0)
num_classes=21
COLORS = np.random.randint(0, 2, size=(num_classes+1, 3),dtype="uint8")

def show_img_target(img, target):
    if torch.is_tensor(img):
        img=to_pil_image(img)
        target=target.numpy()
    for ll in range(num_classes):
        mask=(target==ll)
        img=mark_boundaries(np.array(img) , 
                            mask,
                            outline_color=COLORS[ll],
                            color=COLORS[ll])
    plt.imshow(img)


def re_normalize (x, mean = mean, std= std):
    x_r= x.clone()
    for c, (mean_c, std_c) in enumerate(zip(mean, std)):
        x_r [c] *= std_c
        x_r [c] += mean_c
    return x_r

 展示训练数据集示例图像

img, mask = train_ds[10]
print(img.shape, img.type(),torch.max(img))
print(mask.shape, mask.type(),torch.max(mask))

plt.figure(figsize=(20,20))

img_r= re_normalize(img)
plt.subplot(1, 3, 1) 
plt.imshow(to_pil_image(img_r))

plt.subplot(1, 3, 2) 
plt.imshow(mask)

plt.subplot(1, 3, 3) 
show_img_target(img_r, mask)
    

展示验证数据集示例图像

img, mask = val_ds[10]
print(img.shape, img.type(),torch.max(img))
print(mask.shape, mask.type(),torch.max(mask))

plt.figure(figsize=(20,20))

img_r= re_normalize(img)
plt.subplot(1, 3, 1) 
plt.imshow(to_pil_image(img_r))

plt.subplot(1, 3, 2) 
plt.imshow(mask)

plt.subplot(1, 3, 3) 
show_img_target(img_r, mask)

创建数据加载器

 通过torch.utils.data针对训练和验证集分别创建Dataloader,打印示例观察效果

from torch.utils.data import DataLoader
train_dl = DataLoader(train_ds, batch_size=4, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=8, shuffle=False) 

for img_b, mask_b in train_dl:
    print(img_b.shape,img_b.dtype)
    print(mask_b.shape, mask_b.dtype)
    break

for img_b, mask_b in val_dl:
    print(img_b.shape,img_b.dtype)
    print(mask_b.shape, mask_b.dtype)
    break

创建模型

创建并打印deeplab_resnet模型结构,使用预训练权重

from torchvision.models.segmentation import deeplabv3_resnet101
import torch

model=deeplabv3_resnet101(pretrained=True, num_classes=21)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model=model.to(device)
print(model)

部署模型

在验证数据集的数据批次上部署模型观察效果 

from torch import nn

model.eval()
with torch.no_grad():
    for xb, yb in val_dl:
        yb_pred = model(xb.to(device))
        yb_pred = yb_pred["out"].cpu()
        print(yb_pred.shape)    
        yb_pred = torch.argmax(yb_pred,axis=1)
        break
print(yb_pred.shape)

plt.figure(figsize=(20,20))

n=2
img, mask= xb[n], yb_pred[n]
img_r= re_normalize(img)
plt.subplot(1, 3, 1) 
plt.imshow(to_pil_image(img_r))

plt.subplot(1, 3, 2) 
plt.imshow(mask)

plt.subplot(1, 3, 3) 
show_img_target(img_r, mask)

可见勾勒对象方面效果很好 

定义损失函数和优化器

from torch import nn
criterion = nn.CrossEntropyLoss(reduction="sum")
from torch import optim
opt = optim.Adam(model.parameters(), lr=1e-6)

def loss_batch(loss_func, output, target, opt=None):   
    loss = loss_func(output, target)
    
    if opt is not None:
        opt.zero_grad()
        loss.backward()
        opt.step()

    return loss.item(), None

from torch.optim.lr_scheduler import ReduceLROnPlateau
lr_scheduler = ReduceLROnPlateau(opt, mode='min',factor=0.5, patience=20,verbose=1)

def get_lr(opt):
    for param_group in opt.param_groups:
        return param_group['lr']

current_lr=get_lr(opt)
print('current lr={}'.format(current_lr))

训练和验证模型

def loss_epoch(model,loss_func,dataset_dl,sanity_check=False,opt=None):
    running_loss=0.0
    len_data=len(dataset_dl.dataset)

    for xb, yb in dataset_dl:
        xb=xb.to(device)
        yb=yb.to(device)
        
        output=model(xb)["out"]
        loss_b, _ = loss_batch(loss_func, output, yb, opt)
        running_loss += loss_b
        
        if sanity_check is True:
            break
    
    loss=running_loss/float(len_data)
    return loss, None

import copy
def train_val(model, params):
    num_epochs=params["num_epochs"]
    loss_func=params["loss_func"]
    opt=params["optimizer"]
    train_dl=params["train_dl"]
    val_dl=params["val_dl"]
    sanity_check=params["sanity_check"]
    lr_scheduler=params["lr_scheduler"]
    path2weights=params["path2weights"]
    
    loss_history={
        "train": [],
        "val": []}
    
    metric_history={
        "train": [],
        "val": []}    
    
    
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss=float('inf')    
    
    for epoch in range(num_epochs):
        current_lr=get_lr(opt)
        print('Epoch {}/{}, current lr={}'.format(epoch, num_epochs - 1, current_lr))   

        model.train()
        train_loss, train_metric=loss_epoch(model,loss_func,train_dl,sanity_check,opt)

        loss_history["train"].append(train_loss)
        metric_history["train"].append(train_metric)
        
        model.eval()
        with torch.no_grad():
            val_loss, val_metric=loss_epoch(model,loss_func,val_dl,sanity_check)
       
        loss_history["val"].append(val_loss)
        metric_history["val"].append(val_metric)   
        
        if val_loss < best_loss:
            best_loss = val_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            
            torch.save(model.state_dict(), path2weights)
            print("Copied best model weights!")
            
        lr_scheduler.step(val_loss)
        if current_lr != get_lr(opt):
            print("Loading best model weights!")
            model.load_state_dict(best_model_wts) 
            
        print("train loss: %.6f" %(train_loss))
        print("val loss: %.6f" %(val_loss))
        print("-"*10) 
    model.load_state_dict(best_model_wts)
    return model, loss_history, metric_history        
import os
opt = optim.Adam(model.parameters(), lr=1e-6)
lr_scheduler = ReduceLROnPlateau(opt, mode='min',factor=0.5, patience=20,verbose=1)

path2models= "./models/"
if not os.path.exists(path2models):
        os.mkdir(path2models)

params_train={
    "num_epochs": 10,
    "optimizer": opt,
    "loss_func": criterion,
    "train_dl": train_dl,
    "val_dl": val_dl,
    "sanity_check": True,
    "lr_scheduler": lr_scheduler,
    "path2weights": path2models+"sanity_weights.pt",
}

model, loss_hist, _ = train_val(model, params_train)

绘制了训练和验证损失曲线 

num_epochs=params_train["num_epochs"]

plt.title("Train-Val Loss")
plt.plot(range(1,num_epochs+1),loss_hist["train"],label="train")
plt.plot(range(1,num_epochs+1),loss_hist["val"],label="val")
plt.ylabel("Loss")
plt.xlabel("Training Epochs")
plt.legend()
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1967424.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【MySQL进阶】事务隔离级别 MVCC

目录 MySQL事务隔离级别 1. 读未提交&#xff08;Read Uncommitted&#xff09; 2. 读已提交&#xff08;Read Committed&#xff09; 3. 可重复读&#xff08;Repeatable Read&#xff09;(默认隔离级别) 4. 串行化&#xff08;Serializable&#xff09; 表格总结 MVCC …

C++栈和队列(容器适配器)

目录 1.什么是适配器&#xff1f; 2.栈(stack) 3.队列(queue) 4.双端队列(deque) 5.优先级队列(priority_queue) 1.什么是仿函数&#xff1f; 2.仿函数有什么用&#xff1f; 3.优先级队列(priority_queue) 1.什么是适配器&#xff1f; 我们之前实现栈和队列&#xff0…

QTCreate中使用git进行代码的备份

一开始使用QTCreate设计UI时&#xff0c;都是手动保存&#xff0c;后面觉得应该升级一下自己的技术栈&#xff0c;把git工具学了一些&#xff0c;摸索两天。首先&#xff0c;git是一个版本控制工具&#xff0c;正常开发需要一个master和一个你自己需要的分支&#xff0c;在分支…

批量下载 B 站 视频的工具 downkyi

批量下载 B 站 视频的工具 downkyi 亲测好用 图片&#xff1a; 下载地址&#xff1a; https://github.com/leiurayer/downkyi

SQL Server Profiler 只查看当前操作的语句

1.打开Sql Server Manage Studio&#xff0c;登录->工具->SQL Server Profiler->链接 点击事件选择 点击列筛选器 点击spid 输入对应的spid 如果不知道你的spid是多少的话&#xff0c;你可以先不筛选&#xff0c;直接运行&#xff0c;然后开启跟踪再运行一个独特的…

【项目管理】项目经理管理表单(及全套资料)

PM项目管理模板 甘特图 OKR周报 团队任务 工作总结

有没有比较好用的家用洗地机推荐?一文搞懂洗地机哪种牌子好

如今洗地机在我们家庭清洁中&#xff0c;已经很常见了&#xff0c;它可以让我们快速的完成地面清洁的工作&#xff0c;无需我们手动去清洗滚布&#xff0c;大大的节省了我们清洁时间&#xff0c;而且清洁效果也更加到位。但是目前市面上的洗地机型号多到让人眼花缭乱&#xff0…

Android开发之组件化

#来自ウルトラマンゼロ&#xff08;哉阿斯&#xff09; 1 简介 通俗来讲&#xff0c;将一个功能完整的 App 或模块拆分成多个子模块, 每个子模块可以独立编译和运行。也可以任意组合成另一个新的 App 或模块, 每个模块即不相互依赖但又可以相互交互, 遇到某些特殊情况甚至可以升…

RN 开发环境搭建(Windows For Android)

传送门&#xff1a;官网的搭建步骤&#xff08;英文&#xff09; 传送门&#xff1a;官网的搭建步骤&#xff08;中文&#xff09; 注&#xff1a;教程写于2022年11月21日&#xff0c;当时也是根据官网步骤一步步操作的。现在时隔2年&#xff0c;最新的 RN 版本&#xff08;V…

在现有的vue3项目中 配置electron

Vue项目已创建&#xff0c;在此基础上安装electron 配置步骤&#xff1a; 装依赖 yarn install装electron安装concurrently ( 一条命令实现同时启动vue项目和electron)安装nodemon (实现热更新) 一、配置途中遇到的问题&#xff1a; 1. 安装 yarn add electron -D 一直卡在这…

连锁企业组网的优化解决方案

对于连锁企业来说&#xff0c;建立高效的网络组网很重要&#xff0c;因为它直接影响到各分支机构之间的信息共享、管理效率和业务流程的顺畅。一个理想的解决方案需要从多个角度入手&#xff0c;以确保网络的稳定性、安全性和可扩展性。 首先&#xff0c;需要选择合适的网络拓扑…

stm32番外-----0.96寸OLED播放电影《你的名字》

目录 前言 OLED播放视频 1.简述 2.现象 3.电路连接图​编辑 4.项目主要文件 5.代码 6.注意事项 前言 刚好前面学习了USART串口通信&#xff0c;本期咱们来玩个有意思的&#xff0c;就是去通过USART实现视频的播放&#xff0c;本期内容程序是来自江协科技的&#xf…

Getty 携手英伟达升级商业文生图 AI 模型;苹果新专利探索「心跳」解锁 iPhone 丨 RTE 开发者日报

开发者朋友们大家好&#xff1a; 这里是 「RTE 开发者日报」 &#xff0c;每天和大家一起看新闻、聊八卦。我们的社区编辑团队会整理分享 RTE&#xff08;Real-Time Engagement&#xff09; 领域内「有话题的 新闻 」、「有态度的 观点 」、「有意思的 数据 」、「有思考的 文…

中科院4区救命神刊!主打不让任何一个人延毕~沾边可录!

【SciencePub学术】本期&#xff0c;小编给大家推荐一本JCR2区中科院4区的“救命神刊”&#xff01;征稿领域可谓是相当广泛&#xff0c;且国人友好&#xff0c;计算机领域的学者可以考虑一下这本期刊&#xff01; 期刊解析 KNOWLEDGE AND INFORMATION SYSTEMS 《知识与信息系统…

计算机毕业设计选题推荐-学院教学工作量统计系统-Java/Python项目实战

✨作者主页&#xff1a;IT毕设梦工厂✨ 个人简介&#xff1a;曾从事计算机专业培训教学&#xff0c;擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Py…

Unity后处理(Post-processing)

Unity post-processing 就像是对图片采用滤镜一样&#xff08;如下图对比&#xff09;对当前场景显示做一定的显示处理&#xff0c;使得场景更漂亮、有趣或者有型。 视觉风格与视觉保真 游戏场景后处理能够达到所需的视觉风格&#xff08;visual style&#xff09;同时也保证视…

【Stack和Queue模拟实现】

Stack和Queue模拟实现 小杨 在模拟实现之前&#xff0c;有必要介绍一下什么是容器适配器 容器适配器 适配器是一种设计模式&#xff0c;该种模式是将一个类的接口转换成客户希望的另一个接口。 虽然stack和queue中也可以存放元素&#xff0c;但在STL中并没有将其划分在容器的…

星环科技推出革新性智能业务分析洞察平台——无涯·问数

5月30-31日&#xff0c;2024向星力未来数据技术峰会期间&#xff0c;星环科技重磅发布一款新产品无涯问数——智能业务分析洞察平台。该产品旨在解决传统BI在数据获取、使用门槛和效率方面的挑战&#xff0c;为决策者和业务人员带来前所未有的数据分析体验。 无涯问数的问世&am…

丰田生产方式:拒绝表面效率!!!

在需要的时间&#xff0c;一件一件地生产所需要的东西&#xff0c;就可以避免“过量生产的浪费”。但是&#xff0c;这时你必须知道“需要的时间”是在什么时候。于是&#xff0c;“单位时间”的意义就很重要了。 所谓“单位时间”&#xff0c;就是指制造一件产品的时间。这必须…

【开源分享】2024最新php在线客服系统源码|聊天系统 附搭建教程

源码的主要特色 自动回复和机器人知识库&#xff1a;通过后台设置机器人知识库&#xff0c;系统可以根据关键词自动回复用户&#xff0c;提高响应速度和服务效率。 内容过滤&#xff1a;支持设置违禁词&#xff0c;避免接收包含不良信息的用户消息&#xff0c;维护平台健康。…