模型训练与预测
- 项目列表
- 前言
- 损失函数
- one_hot
- Dice Loss
- Focal Loss
- 模型参数与训练
- 预测
项目列表
语义分割项目(一)——数据概况及预处理
语义分割项目(二)——标签转换与数据加载
语义分割项目(三)——语义分割模型(U-net和deeplabv3+)
语义分割项目(四)——模型训练与预测
前言
在本系列的前几篇文章中我们介绍了数据与模型,在本篇中我们将数据与模型相结合进行模型训练与预测。
损失函数
在正式构建损失函数之前,我们首先介绍一下 Dice loss。与一般的分类任务不同,语义分割不仅要对单个像素进行分类,还要关注像素所处的位置(即预测区域与标注区域的重合程度)。对于像素的分类,我们可以直接采用交叉熵,让预测尽可能地逼近标签以达到分类的效果(这里我使用的是改进后的交叉熵——Focal loss);而对于像素所处位置的损失,我们用下面的公式来表示:
$$\text{Dice loss} = 1 - \frac{2\,|label \cap target|}{|label| + |target|}$$
也就是说,Dice loss 等于 1 减去一个比值:分子是标签像素位置与预测像素位置交集大小的 2 倍,分母是标签像素总数与预测像素总数之和。
one_hot
为了得到各类别像素所处的位置,我们需要对标签进行 one-hot 编码:对于每个类别,属于该类别的像素记为 1,不属于的记为 0。
def one_hot(target, num_classes=6, device=None):
    """One-hot encode an integer label map along a new trailing class axis.

    Args:
        target: integer tensor of class indices, e.g. of shape (b, h, w).
        num_classes: number of classes to encode.
        device: device for the result. Defaults to ``target.device`` so the
            function also works on CPU-only machines (the previous default
            hard-coded ``'cuda'``).

    Returns:
        Float32 tensor of shape ``target.shape + (num_classes,)`` holding 1.0
        where ``target == class`` and 0.0 elsewhere. Pixels whose label is
        outside ``[0, num_classes)`` get an all-zero vector.
    """
    if device is None:
        device = target.device
    classes = torch.arange(num_classes, device=target.device)
    # Broadcast (..., 1) against (num_classes,) -> (..., num_classes).
    hot = (target.unsqueeze(-1) == classes).float()
    return hot.to(device)
Dice Loss
def Dice_loss(inputs, target, smooth=1e-5):
    """Soft multi-class Dice loss between network logits and integer labels.

    Computes ``1 - (2|X ∩ Y| + smooth) / (|X| + |Y| + smooth)`` per image and
    per class, where X is the softmax probability map and Y the one-hot label
    map, then averages over images and classes.

    Args:
        inputs: logits of shape (n, c, h, w).
        target: integer class labels of shape (n, ht, wt).
        smooth: small constant added to numerator and denominator so a class
            absent from both prediction and label scores 1 instead of 0/0.

    Returns:
        Scalar tensor that is differentiable with respect to ``inputs``.
    """
    _, c, h, w = inputs.size()
    ht, wt = target.shape[-2:]
    if (h, w) != (ht, wt):
        # Same resizing rule as Focal_Loss: bring logits to label resolution.
        inputs = torch.nn.functional.interpolate(
            inputs, size=(ht, wt), mode="bilinear", align_corners=True)
    # Softmax rather than argmax: argmax has no gradient, so the previous
    # version only added a constant to the loss and never affected training.
    probs = torch.softmax(inputs, dim=1).permute(0, 2, 3, 1)  # (n, h, w, c)
    target_hot = one_hot(target, num_classes=c, device=inputs.device)
    # Reduce over the spatial axes (1, 2), not the class axis, so every class
    # gets its own overlap score as in the Dice formula.
    inter = (probs * target_hot).sum(dim=(1, 2))
    total = probs.sum(dim=(1, 2)) + target_hot.sum(dim=(1, 2))
    scores = (2 * inter + smooth) / (total + smooth)
    return 1 - scores.mean()
Focal Loss
def Focal_Loss(inputs, target, alpha=0.5, gamma=2):
    """Per-pixel focal loss built on top of cross-entropy.

    Args:
        inputs: logits of shape (n, c, h, w). They are bilinearly resized to
            the label resolution when the spatial sizes differ.
        target: integer class labels of shape (n, ht, wt).
        alpha: scalar multiplier on the log-probability term (a single
            constant, not per-class weights); ``None`` disables it.
        gamma: focusing exponent; ``(1 - p_t) ** gamma`` down-weights pixels
            that are already classified with high confidence.

    Returns:
        Scalar tensor: the mean focal loss over all pixels.
    """
    _, c, h, w = inputs.size()
    ht, wt = target.shape[-2:]
    # Resize when *either* spatial dimension differs. The previous ``and``
    # skipped resizing when only one matched, and the flatten below then
    # produced mismatched logit/label counts.
    if (h, w) != (ht, wt):
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
    # (n, c, h, w) -> (n*h*w, c): one row of class logits per pixel.
    temp_inputs = inputs.permute(0, 2, 3, 1).reshape(-1, c)
    temp_target = target.reshape(-1)
    # log(p_t): per-pixel log-probability of the true class.
    logpt = -F.cross_entropy(temp_inputs, temp_target, reduction='none')
    pt = torch.exp(logpt)
    if alpha is not None:
        logpt = logpt * alpha
    loss = -((1 - pt) ** gamma) * logpt
    return loss.mean()
def loss(inputs, target):
    """Combined training objective: focal loss plus Dice loss.

    Both terms are evaluated on the same ``(inputs, target)`` pair, focal
    first, and their sum is returned.
    """
    focal_term = Focal_Loss(inputs, target)
    dice_term = Dice_loss(inputs, target)
    return focal_term + dice_term
模型参数与训练
def train(net, epochs, train_iter, test_iter, device, loss, optimizer):
    """Train ``net`` for ``epochs`` epochs, reporting the test loss after each.

    Args:
        net: model to train; it is moved to ``device``.
        epochs: number of passes over ``train_iter``.
        train_iter: iterable of ``(X, y)`` batches that also supports ``len()``.
        test_iter: iterable of ``(X, y)`` evaluation batches supporting ``len()``.
        device: device the model and every batch are moved to.
        loss: callable ``loss(y_hat, y)`` returning a scalar tensor.
        optimizer: optimizer already bound to ``net``'s parameters.

    Running mean losses are shown on ``tqdm`` progress bars; nothing is
    returned.
    """
    print("device in : ", device)
    net = net.to(device)
    for epoch in range(epochs):
        # Use the ``epochs`` argument; the previous version read the
        # script-level ``num_epochs`` global here.
        progress = "epoch {}/{}".format(epoch + 1, epochs)
        _train_one_epoch(net, train_iter, device, loss, optimizer, "train " + progress)
        _evaluate(net, test_iter, device, loss, "test " + progress)


def _train_one_epoch(net, data_iter, device, loss, optimizer, desc):
    """Run one optimisation pass over ``data_iter``, showing the running mean loss."""
    net.train()
    total_loss, num_batches = 0.0, 0
    with tqdm(data_iter, total=len(data_iter), ncols=100, colour='red', desc=desc) as pbar:
        for X, y in pbar:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            batch_loss = loss(net(X), y)
            batch_loss.backward()
            optimizer.step()
            total_loss += batch_loss.item()
            num_batches += 1
            pbar.set_postfix({'loss': "{:.6f}".format(total_loss / num_batches)})


def _evaluate(net, data_iter, device, loss, desc):
    """Show the running mean loss over ``data_iter`` without tracking gradients."""
    net.eval()
    total_loss, num_batches = 0.0, 0
    with tqdm(data_iter, total=len(data_iter), ncols=100, colour='blue', desc=desc) as pbar:
        # The forward pass belongs inside no_grad as well; previously only the
        # loss was, so every test batch still built an autograd graph.
        with torch.no_grad():
            for X, y in pbar:
                X, y = X.to(device), y.to(device)
                total_loss += loss(net(X), y).item()
                num_batches += 1
                pbar.set_postfix({'loss': "{:.6f}".format(total_loss / num_batches)})
# ---- Training configuration ----
batch_size = 2
crop_size = (600, 600)  # crop size
model_choice = 'U-net'  # one of 'U-net', 'deeplabv3+'
in_channels = 3  # input image channels
out_channels = 6  # number of output label classes
num_epochs = 25  # total number of epochs
lr = 5e-6
wd = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
train_iter, test_iter = load_data_voc(batch_size, crop_size, data_dir='dataset')
# Build the network and pick its weight file in one place.
if model_choice == 'U-net':
    net = U_net()
    model_path = os.path.join('model_weights', 'u-net-vgg16.pth')
elif model_choice == 'deeplabv3+':
    net = deeplabv3(in_channels, out_channels)
    model_path = os.path.join('model_weights', 'Semantic-deeplabv3.pth')
else:
    raise ValueError("unknown model_choice: {!r}".format(model_choice))
# map_location lets weights saved on a GPU be loaded on a CPU-only machine.
net.load_state_dict(torch.load(model_path, map_location=device))
trainer = torch.optim.SGD(net.parameters(), lr=lr, weight_decay=wd)
# Pass the detected device rather than a hard-coded 'cuda'.
train(net, num_epochs, train_iter, test_iter, device=device, loss=loss, optimizer=trainer)
torch.save(net.state_dict(), model_path)
预测
在预测前我们还需要做一些额外处理,例如将数值标签转换成 RGB 图像标签,本篇使用 label2image 函数实现这一步。
import os
from math import ceil
import torch
import torchvision
from torchvision import io
from utils.dataLoader import load_data_voc
from utils.dataConvert import loadColorMap
from utils.model import U_net, deeplabv3
from torchvision import transforms
import matplotlib.pyplot as plt
def label2image(pred, device):
    """Colourise a map of class indices using the project colour map.

    Args:
        pred: tensor of class indices; it is cast to ``long`` for indexing.
        device: device on which the colour-map lookup table is built.

    Returns:
        Tensor holding, for every element of ``pred``, the colour-map row
        selected by that class index.
    """
    palette = torch.tensor(loadColorMap(), device=device)
    indices = pred.long()
    return palette[indices, :]
def predict(net, device, img, means, stds):
    """Predict a per-pixel class-index map for a single image.

    Args:
        net: segmentation network, presumably already on ``device``.
        device: device the normalised input is moved to.
        img: single image tensor (C, H, W); it is divided by 255, so values
            are presumably in [0, 255] (e.g. from ``torchvision.io.read_image``).
        means: per-channel means for the [0, 1]-scaled image.
        stds: per-channel standard deviations for the [0, 1]-scaled image.

    Returns:
        2-D tensor of class indices (argmax over the class axis) with the
        spatial size of the network output.
    """
    normalize = torchvision.transforms.Normalize(mean=means, std=stds)
    X = normalize(img / 255).unsqueeze(0)  # add a batch axis of size 1
    # Inference only: avoid building an autograd graph on every call.
    with torch.no_grad():
        logits = net(X.to(device))
    return logits.argmax(dim=1)[0]  # drop the batch axis
def read_voc_images(data_dir, is_train=True):
    """Load every image/label pair listed in the train or test split file.

    Args:
        data_dir: dataset root containing ``train.txt``/``test.txt`` plus the
            ``images`` and ``labels`` sub-directories.
        is_train: read ``train.txt`` when True, otherwise ``test.txt``.

    Returns:
        ``(images, labels)``: two lists of tensors returned by
        ``io.read_image``, in split-file order. Each split-file line holds an
        integer id; files are named ``{id:03d}.jpg`` / ``{id:03d}.png``.
    """
    split_file = os.path.join(data_dir, 'train.txt' if is_train else 'test.txt')
    with open(split_file) as f:
        # Skip blank lines (e.g. a trailing empty line) that int() would reject.
        ids = [int(line.strip()) for line in f if line.strip()]
    images, labels = [], []
    for idx in ids:
        images.append(io.read_image(os.path.join(data_dir, 'images', '{:03d}.jpg'.format(idx))))
        labels.append(io.read_image(os.path.join(data_dir, 'labels', '{:03d}.png'.format(idx))))
    return images, labels
def plotPredictAns(imgs):
    """Display ``imgs`` in a three-column grid without axis ticks.

    Only the first three panels receive titles ("original images",
    "predict label", "true label"). This labels the columns when ``imgs`` is
    a flat sequence of (original, prediction, ground truth) triples.
    """
    column_titles = ("original images", "predict label", "true label")
    rows = ceil(len(imgs) / 3)
    for position, picture in enumerate(imgs, start=1):
        plt.subplot(rows, 3, position)
        plt.imshow(picture)
        plt.xticks([])
        plt.yticks([])
        if position <= len(column_titles):
            plt.title(column_titles[position - 1])
    plt.show()
if __name__ == '__main__':
    voc_dir = './dataset/'
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Per-channel statistics used by predict() to normalise [0, 1]-scaled images.
    means = [0.4813, 0.4844, 0.4919]
    stds = [0.2467, 0.2478, 0.2542]
    test_images, test_labels = read_voc_images(voc_dir, False)
    n = 4  # number of test images to visualise
    crop_size = (600, 600)  # crop size; assumed (height, width)
    model_choice = 'U-net'  # one of 'U-net', 'deeplabv3+'
    if model_choice == 'U-net':
        net = U_net()
        model_path = os.path.join('model_weights', 'u-net-vgg16.pth')
    elif model_choice == 'deeplabv3+':
        net = deeplabv3(3, 6)
        model_path = os.path.join('model_weights', 'Semantic-deeplabv3.pth')
    else:
        raise ValueError("unknown model_choice: {!r}".format(model_choice))
    # map_location lets GPU-saved weights load on a CPU-only machine.
    net.load_state_dict(torch.load(model_path, map_location=device))
    # eval() puts any BatchNorm/Dropout layers into inference mode; the
    # previous version predicted with the network still in training mode.
    net = net.to(device).eval()
    crop_rect = (0, 0, *crop_size)  # (top, left, height, width) for crop()
    imgs = []
    with torch.no_grad():
        # Slicing caps the count at the number of available test images.
        for image, label in zip(test_images[:n], test_labels[:n]):
            X = torchvision.transforms.functional.crop(image, *crop_rect)
            pred = label2image(predict(net, device, X, means, stds), device)
            true_label = torchvision.transforms.functional.crop(label, *crop_rect)
            imgs += [X.permute(1, 2, 0), pred.cpu(), true_label.permute(1, 2, 0)]
    plotPredictAns(imgs)