基于预训练模型的Unet【超级简单】【懒人版】【Pytorch版】

news2025/1/23 22:44:08

基于预训练模型的Unet【超级简单】【懒人版】【Pytorch版】

在本项目开始前,首先给大家保证,本次项目只是一个最简单的Unet实现,使用现成的代码,不需要手写代码,使用预训练模型,不需要标注数据集和训练。所以,如果只是想稍微接触一下语义分割的话,放心观看!!!保证不需要脑子!!!
在这里插入图片描述大家好哇!其实在计算机视觉领域,一直有一个我很感兴趣,但是至今还没有接触的任务,就是语义分割。我们实验室面有人做语义分割,每次看到展示工作的时候,都觉得好神奇哇!智能抠图!好有意思!
现在让我们开始吧!

实验

首先我们在GitHub上面下载Pytorch版的Unet官方代码:
在这里插入图片描述下载之后,我们可以看到,在predict.py文件里面,这里‘–model’,默认是‘MODEL.pth’,这里需要我们下载一个预训练模型.pth文件,放在文件夹下,这样我们就可以直接使用预训练模型进行预测啦!
在这里插入图片描述我们继续下拉界面
可以看到这里有个Pretrained model 的蓝色字体,点击会跳转
在这里插入图片描述
接下来就跳转到预训练模型界面啦,大家可以选择下载!
在这里插入图片描述下载到本地后,就给可以更改‘–model’的默认值,

parser.add_argument('--model', '-m', default='unet_carvana_scale0.5_epoch2.pth', metavar='FILE',
                        help='Specify the file in which the model is stored')

接下来就可以快乐预测啦!

注意

Unet官方有提供预训练模型unet_carvana_scale0.5_epoch2.pth,该模型是在Carvana数据集上进行训练。
2017 年 7 月,美国二手汽车零售平台 Carvana 在知名机器学习竞赛平台 kaggle 上发布了名为 Carvana 图像掩模挑战赛(Carvana Image Masking Challenge)的比赛项目,吸引了许多计算机视觉等相关领域的研究者参与。Carvana 希望为消费者提供全面、透明的购车信息,以提升购买体验。传统的二手车销售平台向消费者提供的车辆展示图片往往是模糊的,缺少标准规范的汽车信息图片往往也不能全面地向消费者展示全面的信息。这严重降低了二手车的销售效率。为了解决这一问题,Carvana 设计了一套用以展示 16 张可旋转的汽车图片的系统。然而,反光以及车身颜色与背景过于相似等问题会引起一系列视觉错误,使得 Carvana 不得不聘请专业的图片编辑来修改汽车图片。这无疑是一件费时费力的工作。因此,Carvana 希望此次比赛的参赛者设计出能够自动将图片中的汽车从背景中抽离的算法,以便日后将汽车融合到新的背景中去。
所以,该模型其实是一个汽车语义分割的2分类模型,大家在测试的时候,一定记得测试的图片是汽车的图片,最好背景也干净一点,这样效果会比价好。

将两张图片水平拼接

因为我想看到一个语义分割结果和原图的对比,所以就增加了一个图像水平拼接函数。

# 定义图像拼接函数
def join_two_image(img_1, img_2, flag='horizontal'):  # 默认是水平参数
    size1, size2 = img_1.size, img_2.size
    if flag == 'horizontal':
        joint = Image.new("RGB", (size1[0] + size2[0], size1[1]))
        loc1, loc2 = (0, 0), (size1[0], 0)
        joint.paste(img_1, loc1)
        joint.paste(img_2, loc2)
    return joint

测试结果

请添加图片描述

请添加图片描述

请添加图片描述

完整的predict.py代码

import argparse
import logging
import os

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from utils.data_loading import BasicDataset
from unet import UNet
from utils.utils import plot_img_and_mask

def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5):
    net.eval()
    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor, is_mask=False))
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)

        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)[0]
        else:
            probs = torch.sigmoid(output)[0]

        tf = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((full_img.size[1], full_img.size[0])),
            transforms.ToTensor()
        ])

        full_mask = tf(probs.cpu()).squeeze()

    if net.n_classes == 1:
        return (full_mask > out_threshold).numpy()
    else:
        return F.one_hot(full_mask.argmax(dim=0), net.n_classes).permute(2, 0, 1).numpy()


def get_args():
    parser = argparse.ArgumentParser(description='Predict masks from input images')
    parser.add_argument('--model', '-m', default='unet_carvana_scale0.5_epoch2.pth', metavar='FILE',
                        help='Specify the file in which the model is stored')
    parser.add_argument('--input',  default='images', metavar='INPUT', help='Filenames of input images')
    parser.add_argument('--output', '-o', metavar='OUTPUT', nargs='+', help='Filenames of output images')
    parser.add_argument('--viz', '-v', action='store_true',
                        help='Visualize the images as they are processed')
    parser.add_argument('--no-save', '-n', action='store_true', help='Do not save the output masks')
    parser.add_argument('--mask-threshold', '-t', type=float, default=0.5,
                        help='Minimum probability value to consider a mask pixel white')
    parser.add_argument('--scale', '-s', type=float, default=0.5,
                        help='Scale factor for the input images')
    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')

    return parser.parse_args()


def get_output_filenames(args):
    def _generate_name(fn):
        return f'{os.path.splitext(fn)[0]}_OUT.png'

    return args.output or list(map(_generate_name, args.input))


def mask_to_image(mask: np.ndarray):
    if mask.ndim == 2:
        return Image.fromarray((mask * 255).astype(np.uint8))
    elif mask.ndim == 3:
        return Image.fromarray((np.argmax(mask, axis=0) * 255 / mask.shape[0]).astype(np.uint8))
# 定义图像拼接函数
def join_two_image(img_1, img_2, flag='horizontal'):  # 默认是水平参数
    size1, size2 = img_1.size, img_2.size
    if flag == 'horizontal':
        joint = Image.new("RGB", (size1[0] + size2[0], size1[1]))
        loc1, loc2 = (0, 0), (size1[0], 0)
        joint.paste(img_1, loc1)
        joint.paste(img_2, loc2)
    return joint

if __name__ == '__main__':
    args = get_args()
    in_files = args.input
    out_files = get_output_filenames(args)

    net = UNet(n_channels=3, n_classes=2, bilinear=args.bilinear)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Loading model {args.model}')
    logging.info(f'Using device {device}')

    net.to(device=device)
    net.load_state_dict(torch.load(args.model, map_location=device))

    logging.info('Model loaded!')
    print(in_files)
    for filename in os.listdir(in_files):
        print(filename)
        logging.info(f'\nPredicting image {filename} ...')

        img = Image.open(os.path.join(in_files, filename))

        mask = predict_img(net=net,
                           full_img=img,
                           scale_factor=args.scale,
                           out_threshold=args.mask_threshold,
                           device=device)
        result = mask_to_image(mask)
        result = join_two_image(img, result)
        result.save(os.path.join('out', filename))

嘿嘿!完结撒花!!!
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/47002.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

NTFS及文件共享

一,NTFS安全权限概述 1、给文件和文件夹设置权限,通过设置权限,实现不同的用户访问不同文件和文件夹的权限。 2、分配了正确的访问权限后,用户才能访问对应资源。 3、设置权限防止资源被篡改、删除。 二、文件系统概述 文件系统…

[附源码]Python计算机毕业设计SSM旅游服务平台(程序+LW)

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

Whistle 前端抓包

whistle文档:http://wproxy.org/whistle/install.html 1.确保电脑安装了node node -v如果能正常输出Node的版本号,表示Node已安装成功 2.安装whistle npm install -g whistlewhistle安装完成后,执行命令 whistle help 或 w2 help&#xf…

Spring——三级缓存解决循环依赖详解

三级缓存解决循环依赖详解一、什么是三级缓存二、三级缓存详解Bean实例化前属性赋值/注入前初始化后总结三、怎么解决的循环依赖四、不用三级缓存不行吗五、总结一、什么是三级缓存 就是在Bean生成流程中保存Bean对象三种形态的三个Map集合,如下: // 一…

IPv6进阶:IPv6 过渡技术之 NAT64(IPv6 节点主动访问 IPv4 节点-地址池方式)

实验拓扑 PC1是IPv4网络的一个节点,处于Trust安全域;PC2是IPv6网络的一个节点,处于Untrust安全域。 实验需求 完成防火墙IPv4、IPv6接口的配置,并将接口添加到相应的安全域;在防火墙上配置NAT64的IPv6前缀3001::/64&…

cpu设计和实现(数据访问)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 在cpu设计当中,数据访问是比较重要的一个环节。一般认为,数据访问就是内存访问。其实不然。我们都知道,cpu访问…

【微服务】SpringCloud中Ribbon的轮询(RoundRobinRule)与重试(RetryRule)策略

💖 Spring家族及微服务系列文章 ✨【微服务】SpringCloud中Ribbon集成Eureka实现负载均衡 ✨【微服务】SpringCloud轮询拉取注册表及服务发现源码解析 ✨【微服务】SpringCloud微服务续约源码解析 ✨【微服务】SpringCloud微服务注册源码解析 ✨

Nginx的操作

一、什么是nginx。 Nginx (engine x) 是一个高性能的HTTP和反向代理web服务器 , 其特点是占有内存少,并发能力强,事实上nginx的并发能力在同类型的网页服务器中表现较好。 Nginx代码完全用C语言从头写成 . 能够支持高达 50,000 个并发连接数的响应. 现在…

【pen200-lab】10.11.1.72

pen200-lab 学习笔记 【pen200-lab】10.11.1.72 🔥系列专栏:pen200-lab 🎉欢迎关注🔎点赞👍收藏⭐️留言📝 📆首发时间:🌴2022年11月27日🌴 🍭作…

aws beanstalk 使用eb cli配置和启动环境

Elastic Beanstalk 不额外收费,只需为存储和运行应用程序所需的 AWS 资源付费 EB CLI 是 Amazon Elastic Beanstalk 的命令行界面,它提供了可简化从本地存储库创建、更新和监控环境的交互式命令 安装eb cli $ pip install virtualenv $ virtualenv ebve…

2023年考研数学测试卷(预测)

2023年考研数学测试卷 原题再现: 多的我也不说了,直接把预测的2023年考研数学卷子分享给大家好吧,准确详细全面是我的宗旨,我的博客创立初衷和发展方向肯定不应该只是"考试有用",而是面对社会生产生活的有用…

Day815.数据库参数设置优化 -Java 性能调优实战

数据库参数设置优化 Hi,我是阿昌,今天学习记录的是关于数据库参数设置优化。 MySQL 是一个灵活性比较强的数据库系统,提供了很多可配置参数,便于根据应用和服务器硬件来做定制化数据库服务。 数据库主要是用来存取数据的&#…

视频编解码 — DCT变换和量化

目录 视频编码流程 DCT变换 Hadamard变换 量化 H264中的DCT变换和量化 H264各模式的DCT变换和量化过程 1、亮度16x16帧内预测块 2,其它模式亮度块 3,色度块 小结 视频编码流程 DCT变换 离散余弦变换 它能将空域信号转换到频率上表示&#xff0…

建造者模式

文章目录定义优点使用场景代码实现定义 将一个复杂对象的构建与它的表示分离,使得同样的构建过程可以创建不同的表示。 4个角色: Product产品类:通常是实现了模板方法模式,也就是有模板方法和基本方法Builder抽象建造者&#xf…

PyQt5可视化编程-事件、信号和对话框

一、概述: 所有的应用都是事件驱动的。事件大部分都是由用户的行为产生的,当然也有其他的事件产生方式,比如网络的连接,窗口管理器或者定时器等。调用应用的exec_()方法时,应用会进入主循环,主循环会监听和分发事件。…

算法题:整数除法

一.题目描述以及来源 给定两个整数 a 和 b ,求它们的除法的商 a/b ,要求不得使用乘号 *、除号 / 以及求余符号 % 。 注意: 整数除法的结果应当截去(truncate)其小数部分,例如:truncate(8.345…

MP157-2-TF-A移植:

MP157-2-TF-A移植:1. TF-A移植:1.1 新建开发板的设备树1.2 修改设备树电源管理1.3修改TF卡和EMMC设备树1.4 修改USBOTG设备树2 编译测试2.1 Makefile.sdk 修改内容:2.2 编译命令:正点原子第九章内容:自己记的笔记&…

SpringBoot(One·上)

SpringBoot一、简介概述Spring Boot特性SpringBoot四大核心二、SpringBoot项目分析1、创建第一个案例结构目录和pom文件2、Springboot集成mvcSpringboot核心配置文件application.propertiesSpringboot核心配置文件application.yml或者application.yamlapplication.ymlapplicati…

Allegro削铜皮详细操作指导

Allegro削铜皮详细操作指导 Allegro可以编辑任意形状的铜皮,下面介绍几种削铜皮的方式 任意形状,shape-manual Void/cavity-Polygon 鼠标左键点击铜皮,铜皮会被亮起来 画出需要的形状 完成后如下图 方形shape-manual Void/cavity-Rectangular 同样的选择铜皮,画出需要…

通过 js 给元素添加动画样式animation属性 ,以及 perspective 属性探究

学习关键语句: js添加动画效果 js控制元素animation属性 写在前面 在制作组件的过程中呢 , 突然觉得这个动画啊应该由用户来决定到底是个啥样 , 但是怎么让用户操作这一步呢 ? 总不能让用户自己去写 css keyframe 吧 , 所以便有了这篇文章 , 同时 , 这篇文章的下半部分我们会…