Code walkthrough | Remote sensing semantic segmentation with minimal code, built from scratch with GDAL, using U-Net and building extraction as the example
Taking the code from the link above as the example, this post explains it line by line.
Training
The U-Net train.py is as follows:
import torch
import torch.nn as nn
import numpy as np
from osgeo import gdal  # modern GDAL installs expose the bindings under osgeo
from torch.utils.data import Dataset, DataLoader

class UNet(nn.Module):
    def __init__(self, input_channels, out_channels):
        # In Python, when a class inherits from another (here UNet inherits from nn.Module),
        # the subclass must call the parent constructor to initialize the parent's attributes.
        super(UNet, self).__init__()
        # Define encoders 1-4, the center block, decoders 4-1, and the final conv layer
        self.enc1 = self.conv_block(input_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)
        self.center = self.conv_block(512, 1024)
        self.dec4 = self.conv_block(1024 + 512, 512)
        self.dec3 = self.conv_block(512 + 256, 256)
        self.dec2 = self.conv_block(256 + 128, 128)
        self.dec1 = self.conv_block(128 + 64, 64)
        self.final = nn.Conv2d(64, out_channels, kernel_size=1)
        # Max pooling for downsampling; bilinear upsampling for the decoder path
        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    # A convolution block with two conv layers, each followed by a ReLU activation and batch normalization.
    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels)
        )

    # Forward pass; torch.cat concatenates encoder feature maps with decoder feature maps.
    def forward(self, x):
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))
        center = self.center(self.pool(enc4))
        dec4 = self.dec4(torch.cat([enc4, self.up(center)], 1))
        dec3 = self.dec3(torch.cat([enc3, self.up(dec4)], 1))
        dec2 = self.dec2(torch.cat([enc2, self.up(dec3)], 1))
        dec1 = self.dec1(torch.cat([enc1, self.up(dec2)], 1))
        final = self.final(dec1).squeeze()
        return torch.sigmoid(final)

class RSDataset(Dataset):
    def __init__(self, images_dir, labels_dir):
        self.images = self.read_multiband_images(images_dir)
        self.labels = self.read_singleband_labels(labels_dir)

    # Read the multiband images and stack the first three bands into a (3, H, W) array.
    def read_multiband_images(self, images_dir):
        images = []
        for image_path in images_dir:
            rsdl_data = gdal.Open(image_path)
            images.append(np.stack([rsdl_data.GetRasterBand(i).ReadAsArray() for i in range(1, 4)], axis=0))
        return images

    # Read the single-band label images.
    def read_singleband_labels(self, labels_dir):
        labels = []
        for label_path in labels_dir:
            rsdl_data = gdal.Open(label_path)
            labels.append(rsdl_data.GetRasterBand(1).ReadAsArray())
        return labels

    # Return the dataset length.
    def __len__(self):
        return len(self.images)

    # Return the image and label at the given index, converted to PyTorch tensors.
    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        return torch.tensor(image), torch.tensor(label)

images_dir = ['data/2_95_sat.tif', 'data/2_96_sat.tif', 'data/2_97_sat.tif', 'data/2_98_sat.tif', 'data/2_976_sat.tif']
labels_dir = ['data/2_95_mask.tif', 'data/2_96_mask.tif', 'data/2_97_mask.tif', 'data/2_98_mask.tif', 'data/2_976_mask.tif']

# Create an RSDataset instance and load it with a DataLoader: batch size 2, shuffled
dataset = RSDataset(images_dir, labels_dir)
trainloader = DataLoader(dataset, batch_size=2, shuffle=True)
model = UNet(3, 1)  # 3 input channels, 1 output channel
criterion = nn.BCELoss()  # loss function
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # optimizer
num_epochs = 50
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(trainloader):
        images = images.float()
        labels = labels.float() / 255.0  # map 0/255 masks to 0/1 for BCELoss
        outputs = model(images)
        labels = labels.squeeze(0)  # no-op for batch size > 1; aligns shapes when a batch holds a single sample
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, loss.item()))
torch.save(model.state_dict(), 'models_building_50.pth')
Q1: In dec4 = self.dec4(torch.cat([enc4, self.up(center)], 1)), why are the encoder's feature maps concatenated with the decoder's? How exactly does this concatenation work?
A1: This is the U-Net skip connection. Repeated pooling in the encoder discards spatial detail, so the decoder on its own would produce blurry boundaries; the encoder feature maps still hold that high-resolution detail, while the upsampled decoder maps carry semantic context. torch.cat joins the two tensors along dimension 1, the channel dimension: enc4 has shape (N, 512, H, W) and self.up(center) has shape (N, 1024, H, W), so the result is (N, 1536, H, W), exactly the 1024 + 512 input channels that self.dec4 expects. The following conv block then fuses the two sources of information.
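A minimal shape demo of channel-wise concatenation (the tensor contents here are random placeholders, only the shapes matter):

import torch
enc4 = torch.randn(2, 512, 32, 32)        # encoder features: (N, 512, H, W)
up_center = torch.randn(2, 1024, 32, 32)  # upsampled center features: (N, 1024, H, W)
merged = torch.cat([enc4, up_center], 1)  # concatenate along dim 1, the channel dim
print(merged.shape)  # torch.Size([2, 1536, 32, 32]) -- N, H, W must match; channels add up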
Q2: In final = self.final(dec1).squeeze(), what does squeeze do?
A2: squeeze() removes every dimension of size 1. Since out_channels is 1, self.final(dec1) has shape (N, 1, H, W); squeeze() drops the channel dimension, giving (N, H, W), which matches the label shape that BCELoss expects. Be aware that with a batch of one, squeeze() also drops the batch dimension, leaving (H, W); squeeze(1) would be the more precise choice if only the channel dimension should go.
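A quick demonstration of that behavior (the shapes are illustrative):

import torch
x = torch.randn(2, 1, 256, 256)  # (N, 1, H, W) from the 1-channel final conv
print(x.squeeze().shape)   # torch.Size([2, 256, 256]) -- only the size-1 channel dim is removed
y = torch.randn(1, 1, 256, 256)  # batch of one
print(y.squeeze().shape)   # torch.Size([256, 256]) -- the batch dim is removed too
print(y.squeeze(1).shape)  # torch.Size([1, 256, 256]) -- squeeze(1) targets only dim 1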
Inference
infer.py is as follows:
import torch
import torch.nn as nn
import numpy as np
import cv2
from osgeo import gdal

# Same UNet definition as in train.py
class UNet(nn.Module):
    def __init__(self, input_channels, out_channels):
        super(UNet, self).__init__()
        self.enc1 = self.conv_block(input_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)
        self.center = self.conv_block(512, 1024)
        self.dec4 = self.conv_block(1024 + 512, 512)
        self.dec3 = self.conv_block(512 + 256, 256)
        self.dec2 = self.conv_block(256 + 128, 128)
        self.dec1 = self.conv_block(128 + 64, 64)
        self.final = nn.Conv2d(64, out_channels, kernel_size=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))
        center = self.center(self.pool(enc4))
        dec4 = self.dec4(torch.cat([enc4, self.up(center)], 1))
        dec3 = self.dec3(torch.cat([enc3, self.up(dec4)], 1))
        dec2 = self.dec2(torch.cat([enc2, self.up(dec3)], 1))
        dec1 = self.dec1(torch.cat([enc1, self.up(dec2)], 1))
        final = self.final(dec1).squeeze()
        return torch.sigmoid(final)

model = UNet(3, 1)
model.load_state_dict(torch.load('models_building_50.pth'))
model.eval()  # switch layers such as BatchNorm to inference behavior
image_file = 'data/2_955_sat.tif'
rsdataset = gdal.Open(image_file)
images = np.stack([rsdataset.GetRasterBand(i).ReadAsArray() for i in range(1, 4)], axis=0)
test_images = torch.tensor(images).float().unsqueeze(0)  # add a batch dimension: (1, 3, H, W)
with torch.no_grad():  # no gradients needed at inference time
    outputs = model(test_images)
outputs = (outputs > 0.8).float()  # threshold probabilities into a binary mask
cv2.imshow('Prediction', outputs.numpy())
cv2.waitKey(0)
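Since the input is a georeferenced raster, a natural follow-up is to write the prediction back to disk as a GeoTIFF that inherits the input's geotransform and projection. A minimal sketch using the GDAL API, reusing rsdataset and outputs from the script above (the output filename is an arbitrary choice):

mask = (outputs.numpy() * 255).astype(np.uint8)  # 0/1 floats -> 0/255 bytes
driver = gdal.GetDriverByName('GTiff')
out_ds = driver.Create('prediction.tif', rsdataset.RasterXSize, rsdataset.RasterYSize, 1, gdal.GDT_Byte)
out_ds.SetGeoTransform(rsdataset.GetGeoTransform())  # copy the input's georeferencing
out_ds.SetProjection(rsdataset.GetProjection())
out_ds.GetRasterBand(1).WriteArray(mask)
out_ds.FlushCache()
out_ds = None  # closing the dataset flushes it to disk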
Points worth studying:
Q1: What is the difference between train and eval mode?
A1: model.train() and model.eval() toggle the behavior of layers whose forward pass differs between training and inference. In train mode, BatchNorm normalizes with the current batch's statistics and updates its running mean and variance, and Dropout randomly zeroes activations; in eval mode, BatchNorm uses the stored running statistics and Dropout is disabled. Because this network uses BatchNorm2d, forgetting model.eval() at inference time would normalize a single test image with its own statistics and give unstable predictions. Note that eval() does not turn off gradient tracking; that is what torch.no_grad() is for.
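A small sketch of the difference using a standalone BatchNorm layer (the tensor values are random placeholders):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
x = torch.randn(1, 3, 8, 8)
bn.train()         # training mode: normalize with the batch's own statistics
out_train = bn(x)  # this call also updates bn.running_mean / bn.running_var
bn.eval()          # eval mode: normalize with the stored running statistics
out_eval = bn(x)
print(torch.allclose(out_train, out_eval))  # False -- the two modes compute different outputs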
Q2: I don't understand why, after model = UNet(3, 1) initializes a U-Net, we can simply call outputs = model(test_images). What are this model's inputs and outputs?
A2: nn.Module implements the __call__ method, so an instance can be called like a function: model(test_images) runs any registered hooks and then dispatches to UNet.forward(test_images). The input here is a float tensor of shape (N, 3, H, W), with H and W divisible by 16 so that the four rounds of pooling and upsampling line up; the output is the tensor of per-pixel building probabilities in [0, 1] produced by the sigmoid, of shape (N, H, W) (or (H, W) when N = 1, because of the squeeze()).
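A toy module showing that calling the instance routes to forward (everything here is illustrative):

import torch
import torch.nn as nn

class Doubler(nn.Module):
    def forward(self, x):
        return 2 * x  # whatever forward returns is what the call returns

m = Doubler()
x = torch.tensor([1.0, 2.0])
print(m(x))          # tensor([2., 4.]) -- m(x) invokes nn.Module.__call__, which calls forward
print(m.forward(x))  # same result, but skips hooks; prefer m(x)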