unet学习(初学者 自用)

news2025/4/16 16:40:54

代码解读 | 极简代码遥感语义分割,结合GDAL从零实现,以U-Net和建筑物提取为例

以上面链接中的代码为例,逐行解释。

训练

unet的train.py如下:

import torch.nn as nn
import torch
import gdal
import numpy as np
from torch.utils.data import Dataset, DataLoader


class UNet(nn.Module):
    def __init__(self, input_channels, out_channels):
        super(UNet, self).__init__() # 在 Python 中,如果一个类继承了另一个类(例如 UNet 继承了 nn.Module),那么子类需要调用父类的构造函数来初始化父类的属性。

        # 定义encoder1-4、中心部分、decoder4-1和最终的卷积层
        self.enc1 = self.conv_block(input_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)
        self.center = self.conv_block(512, 1024)
        self.dec4 = self.conv_block(1024 + 512, 512)
        self.dec3 = self.conv_block(512 + 256, 256)
        self.dec2 = self.conv_block(256 + 128, 128)
        self.dec1 = self.conv_block(128 + 64, 64)
        self.final = nn.Conv2d(64,out_channels, kernel_size=1)
        #定义最大池化层,用于下采样;定义上采样层
        self.pool = nn.MaxPool2d(2, 2) 
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 

    #定义一个卷积块,包含两个卷积层。每个卷积层后面跟着 ReLU 激活函数和批量归一化。
    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels)
        )
    
    #定义前向传播过程,torch.cat 用于将编码器的特征图与解码器的特征图拼接在一起。
    def forward(self, x):
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))

        center = self.center(self.pool(enc4))

        dec4 = self.dec4(torch.cat([enc4, self.up(center)], 1))
        dec3 = self.dec3(torch.cat([enc3, self.up(dec4)], 1))
        dec2 = self.dec2(torch.cat([enc2, self.up(dec3)], 1))
        dec1 = self.dec1(torch.cat([enc1, self.up(dec2)], 1))
        final = self.final(dec1).squeeze()

        return torch.sigmoid(final)


class RSDataset(Dataset):
    def __init__(self, images_dir, labels_dir):
        self.images = self.read_multiband_images(images_dir)
        self.labels = self.read_singleband_labels(labels_dir)

    def read_multiband_images(self, images_dir):#读取多波段图像,并将其堆叠成一个三维数组。
        images = []
        for image_path in images_dir:
            rsdl_data = gdal.Open(image_path)
            images.append(np.stack([rsdl_data .GetRasterBand(i).ReadAsArray() for i in range(1, 4)], axis=0))
        return images

    def read_singleband_labels(self, labels_dir):#读取单波段标签图像。
        labels = []
        for label_path in labels_dir:
            rsdl_data = gdal.Open(label_path)
            labels.append(rsdl_data .GetRasterBand(1).ReadAsArray())
        return labels

    def __len__(self):#返回数据集长度
        return len(self.images)

    def __getitem__(self, idx):#返回指定索引的图像和标签,并将其转换为 PyTorch 张量。
        image = self.images[idx]
        label = self.labels[idx]
        return torch.tensor(image), torch.tensor(label)

images_dir = ['data/2_95_sat.tif', 'data/2_96_sat.tif',  'data/2_97_sat.tif', 'data/2_98_sat.tif', 'data/2_976_sat.tif']
labels_dir =['data/2_95_mask.tif', 'data/2_96_mask.tif',  'data/2_97_mask.tif', 'data/2_98_mask.tif', 'data/2_976_mask.tif']

#创建 RSDataset 实例,并使用 DataLoader 加载数据,设置批量大小为 2,并打乱数据
dataset = RSDataset(images_dir, labels_dir)
trainloader = DataLoader(dataset, batch_size=2, shuffle=True)

model = UNet(3, 1) #输入通道数3,输出通道数1
criterion = nn.BCELoss()#定义loss
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)#优化器
num_epochs=50

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(trainloader):
        images = images.float()
        labels = labels.float()/255.0
        outputs = model(images)
        labels = labels.squeeze(0)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

torch.save(model.state_dict(), 'models_building_50.pth')

Q1:dec4 = self.dec4(torch.cat([enc4, self.up(center)], 1))为什么要将编码器的特征图与解码器的特征图拼接在一起?这个拼接是怎么拼接,我不理解

A1:

Q2:final = self.final(dec1).squeeze() 这个squeeze是什么

A2:

推理 

infer.py内容如下:

import torch.nn as nn
import gdal
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import cv2

class UNet(nn.Module):
    def __init__(self, input_channels, out_channels):
        super(UNet, self).__init__()

        self.enc1 = self.conv_block(input_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)
        self.center = self.conv_block(512, 1024)
        self.dec4 = self.conv_block(1024 + 512, 512)
        self.dec3 = self.conv_block(512 + 256, 256)
        self.dec2 = self.conv_block(256 + 128, 128)
        self.dec1 = self.conv_block(128 + 64, 64)
        self.final = nn.Conv2d(64,out_channels, kernel_size=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)


    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))

        center = self.center(self.pool(enc4))

        dec4 = self.dec4(torch.cat([enc4, self.up(center)], 1))
        dec3 = self.dec3(torch.cat([enc3, self.up(dec4)], 1))
        dec2 = self.dec2(torch.cat([enc2, self.up(dec3)], 1))
        dec1 = self.dec1(torch.cat([enc1, self.up(dec2)], 1))
        final = self.final(dec1).squeeze()

        return torch.sigmoid(final)

model = UNet(3, 1)
model.load_state_dict(torch.load('models_building_50.pth'))
model.eval()

image_file='data/2_955_sat.tif'
rsdataset = gdal.Open(image_file)
images=(np.stack([rsdataset.GetRasterBand(i).ReadAsArray() for i in range(1, 4)], axis=0))
test_images = torch.tensor(images).float().unsqueeze(0)

outputs = model(test_images)
outputs = (outputs > 0.8).float()

cv2.imshow('Prediction', outputs.numpy())
cv2.waitKey(0)

要学习的点:

Q1:train和eval模式有什么区别

A1:

Q2:我不能理解,为什么model = UNet(3, 1) 初始化了一个unet网络,然后就可以outputs = model(test_images)?这个model的输入输出是什么呢

A2:

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2298026.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CCFCSP第34次认证第一题——矩阵重塑(其一)

第34次认证第一题——矩阵重塑(其一) 官网链接 时间限制: 1.0 秒 空间限制: 512 MiB 相关文件: 题目目录(样例文件) 题目背景 矩阵(二维)的重塑(reshap…

探索B-树系列

🌈前言🌈 本文将讲解B树系列,包含 B-树,B树,B*树,其中主要讲解B树底层原理,为什么用B树作为外查询的数据结构,以及B-树插入操作并用代码实现;介绍B树、B*树。 &#x1f4…

GRN前沿:DeepMCL:通过深度多视图对比学习从单细胞基因表达数据推断基因调控网络

1.论文原名:Inferring gene regulatory networks from single-cell gene expression data via deep multi-view contrastive learning 2.发表日期:2023 摘要: 基因调控网络(GRNs)的构建对于理解细胞内复杂的调控机制…

Linux 内核架构入门:从基础概念到面试指南*

1. 引言 Linux 内核是现代操作系统的核心,负责管理硬件资源、提供系统调用、处理进程调度等功能。对于初学者来说,理解 Linux 内核的架构是深入操作系统开发的第一步。本篇博文将详细介绍 Linux 内核的架构体系,结合硬件、子系统及软件支持的…

【竞技宝】PGL瓦拉几亚S4预选:Tidebound2-0轻取spiky

北京时间2月13日,DOTA2的PGL瓦拉几亚S4预选赛继续进行,昨日进行的中国区预选赛胜者组首轮Tidebound对阵的spiky比赛中,以下是本场比赛的详细战报。 第一局: 首局比赛,spiky在天辉方,Tidebound在夜魇方。阵容方面,spiky点出了幻刺、火枪、猛犸、小强、巫妖,Tidebound则是拿到飞…

EasyRTC智能硬件:小体积,大能量,开启音视频互动新体验

在万物互联的时代,智能硬件正以前所未有的速度融入我们的生活。然而,受限于硬件性能和网络环境,许多智能硬件在音视频互动体验上仍存在延迟高、卡顿、回声等问题,严重影响了用户的使用体验。 EasyRTC智能硬件,凭借其强…

【ESP32指向鼠标】——icm20948与esp32通信

【ESP32指向鼠标】——icm20948与esp32通信 ICM-20948介绍 ICM-20948 是一款由 InvenSense(现为 TDK 的一部分)生产的 9 轴传感器集成电路。它结合了 陀螺仪、加速度计和磁力计。 内置了 DMP(Digital Motion Processor)即负责执…

算法——结合实例了解深度优先搜索(DFS)

一,深度优先搜索(DFS)详解 DFS是什么? 深度优先搜索(Depth-First Search,DFS)是一种用于遍历或搜索树、图的算法。其核心思想是尽可能深地探索分支,直到无法继续时回溯到上一个节点…

SpringMVC学习使用

一、SpringMVC简单理解 1.1 Spring与Web环境集成 1.1.1 ApplicationContext应用上下文获取方式 应用上下文对象是通过new ClasspathXmlApplicationContext(spring配置文件) 方式获取的,但是每次从容器中获得Bean时都要编写new ClasspathXmlApplicationContext(sp…

运维-自动访问系统并截图

需求背景 因项目甲方要求需要对系统进行巡检,由于系统服务器较多,并且已经采用PrometheusGrafana对系统服务器进行管理,如果要完成该任务,需要安排一个人力对各个系统和服务器进行一一截图等操作,费时费力&#xff0c…

在CodeBlocks搭建SDL2工程虚拟TFT彩屏解码带压缩形式的Bitmap(BMP)图像显示

在CodeBlocks搭建SDL2工程虚拟TFT彩屏解码带压缩形式的Bitmap BMP图像显示 参考文章文章说明一、创建和退出SDL2二、 Bitmap(BMP)图片解码图三、Bitmap解码初始化四、测试代码五、主函数六、测试结果 参考文章 解码带压缩形式的Bitmap(BMP)图像并使用Python可视化解码后实际图…

mapbox进阶,添加绘图扩展插件,绘制任意方向矩形

👨‍⚕️ 主页: gis分享者 👨‍⚕️ 感谢各位大佬 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍⚕️ 收录于专栏:mapbox 从入门到精通 文章目录 一、🍀前言1.1 ☘️mapboxgl.Map 地图对象1.2 ☘️mapboxgl.Map style属性1.3 ☘️MapboxDraw 绘图控件二、🍀添加绘图扩…

初阶c语言(循环语句习题,完结)

前言: c语言为b站鹏哥,嗯对应视频37集 昨天做的c语言,今天在来做一遍,发现做错了 今天改了平均值的计算, 就是说最大值加上最小值,如果说这个数值非常大的话,两个值加上会超过int类型的最大…

提升编程效率,体验智能编程助手—豆包MarsCode一键Apply功能测评

提升编程效率,体验智能编程助手—豆包MarsCode一键Apply功能测评 🌟 嗨,我是LucianaiB! 🌍 总有人间一两风,填我十万八千梦。 🚀 路漫漫其修远兮,吾将上下而求索。 目录 引言豆包…

【deepseek-r1本地部署】

首先需要安装ollama,之前已经安装过了,这里不展示细节 在cmd中输入官网安装命令:ollama run deepseek-r1:32b,开始下载 出现success后,下载完成 接下来就可以使用了,不过是用cmd来运行使用 可以安装UI可视化界面&a…

多用户商城系统的客服管理体系建设

多用户商城系统的运营,客服管理体系建设至关重要。优质的客服服务不仅能提升用户购物体验,还能增强用户对商城的信任与忠诚度,进而促进商城业务的持续增长。以下从四个关键方面探讨如何建设完善的客服管理体系,信息化客服系统在其…

C++设计模式 - 模板模式

一:概述 模板方法(Template Method)是一种行为型设计模式。它定义了一个算法的基本框架,并且可能是《设计模式:可复用面向对象软件的基础》一书中最常用的设计模式之一。 模板方法的核心思想很容易理解。我们需要定义一…

CZML 格式详解,javascript加载导出CZML文件示例

示例地址:https://dajianshi.blog.csdn.net/article/details/145573994 CZML 格式详解 1. 什么是 CZML? CZML(Cesium Zipped Markup Language)是一种基于 JSON 的文件格式,用于描述地理空间数据和时间动态场景。它专…

OpenAI推出全新AI助手“Operator”:让人工智能帮你做事的新时代!

引言 随着人工智能技术的不断发展,OpenAI 再次推出令人兴奋的功能——Operator,一个全新的 AI 助手平台。这不仅仅是一个普通的助手,它代表了人工智能技术的又一次飞跃,将改变我们工作和生活的方式。 什么是“Operator”&#xff…

重看Spring聚焦BeanFactory分析

目录 一、理解BeanFactory (一)功能性理解 (二)BeanFactory和它的子接口 (三)BeanFactory的实现类 二、BeanFactory根接口 (一)源码展示和理解 (二)基…