[Matting] Paper Reading: Deep Image Matting, a Detailed Walkthrough
Original post: https://blog.csdn.net/XxxLittleOne/article/details/114435137

Contents: I. Abstract; II. Method; 2.1 Part One (Matting encoder-decoder stage); 2.2 Part Two (Matting refinement stage); III. Experimental Results; IV. Thoughts

Preface
Paper link: link
GitHub reproduction link: link

As deep learning has matured, matting has moved from the earlier propagation-based and sampling-based methods to pipelines built on convolutional neural networks, and the transition has come with a number of challenges: 1. As is well known, deep learning's strong results depend on large amounts of data…
Code:
model = DIMModel(n_classes=1, in_channels=4, is_unpooling=True)  # 4 input channels: RGB + trimap
- down1, indices_1, unpool_shape1 = self.down1(inputs)
- down2, indices_2, unpool_shape2 = self.down2(down1)
- down3, indices_3, unpool_shape3 = self.down3(down2)
- down4, indices_4, unpool_shape4 = self.down4(down3)
- down5, indices_5, unpool_shape5 = self.down5(down4)
- up5 = self.up5(down5, indices_5, unpool_shape5)
- up4 = self.up4(up5, indices_4, unpool_shape4)
- up3 = self.up3(up4, indices_3, unpool_shape3)
- up2 = self.up2(up3, indices_2, unpool_shape2)
- up1 = self.up1(up2, indices_1, unpool_shape1)
- x = torch.squeeze(up1, dim=1) # [N, 1, 320, 320] -> [N, 320, 320]
- x = self.sigmoid(x)
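These down/up calls follow a SegNet-style encoder-decoder: each down block records its max-pooling indices and pre-pooling shape, and the matching up block passes them to MaxUnpool2d so features are restored at the exact positions they were pooled from. Below is a minimal sketch of one such pair; the class names, channel widths, and two-conv layout are illustrative assumptions, not the repo's exact code:

import torch.nn as nn

class SegDown(nn.Module):
    # conv-BN-ReLU twice, then max pooling that also returns the pooled indices
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)

    def forward(self, x):
        x = self.conv(x)
        unpool_shape = x.size()  # saved so unpooling can restore odd spatial sizes
        x, indices = self.pool(x)
        return x, indices, unpool_shape

class SegUp(nn.Module):
    # max-unpooling with the saved indices, then conv-BN-ReLU
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
        )

    def forward(self, x, indices, unpool_shape):
        x = self.unpool(x, indices, output_size=unpool_shape)
        return self.conv(x)

With in_channels=4 (RGB + trimap) feeding down1 and up1 reducing to a single channel, chaining five such pairs reproduces the forward pass above; the final sigmoid maps the output to alpha values in [0, 1].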
train_dataset = DIMDataset("train")
- img, alpha, fg, bg = process(im_name, bg_name)  # composites are generated ahead of time in data_gen
- different_sizes = [(320, 320), (480, 480), (640, 640)]
- crop_size = random.choice(different_sizes)
- trimap = gen_trimap(alpha)  # trimap over the full image, used only to pick the crop location
- x, y = random_choice(trimap, crop_size)  # crop corner biased toward the unknown region
- img = safe_crop(img, x, y, crop_size)
- alpha = safe_crop(alpha, x, y, crop_size)
- trimap = gen_trimap(alpha)  # regenerated on the cropped alpha (see the sketch after this block)
- x = torch.zeros((4, im_size, im_size), dtype=torch.float)
- img = img[..., ::-1]  # BGR (OpenCV) -> RGB
- img = transforms.ToPILImage()(img)
- img = self.transformer(img)
- x[0:3, :, :] = img
- x[3, :, :] = torch.from_numpy(trimap.copy() / 255.) # RGB+Trimap
- y = np.empty((2, im_size, im_size), dtype=np.float32)
- y[0, :, :] = alpha / 255.
- mask = np.equal(trimap, 128).astype(np.float32)  # 1 where the trimap is unknown
- y[1, :, :] = mask  # label packs alpha (channel 0) + unknown mask (channel 1)
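gen_trimap and random_choice carry most of the augmentation logic. Here is a sketch of a common implementation (the morphology kernel sizes and iteration ranges are assumptions, not necessarily this repo's values): the alpha is dilated and eroded to carve out a random-width unknown band, and the crop is centered on a random unknown pixel so every patch contains hard regions:

import random
import cv2
import numpy as np

def gen_trimap(alpha):
    # dilate/erode the alpha matte; the band between the two is marked unknown (128)
    k_size = random.choice(range(1, 5))
    iterations = np.random.randint(1, 20)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k_size, k_size))
    dilated = cv2.dilate(alpha, kernel, iterations=iterations)
    eroded = cv2.erode(alpha, kernel, iterations=iterations)
    trimap = np.full(alpha.shape, 128, dtype=np.uint8)  # unknown by default
    trimap[eroded >= 255] = 255                         # certain foreground
    trimap[dilated <= 0] = 0                            # certain background
    return trimap

def random_choice(trimap, crop_size=(320, 320)):
    # return a crop corner (x, y) centered on a random unknown pixel
    crop_h, crop_w = crop_size
    ys, xs = np.where(trimap == 128)
    if len(ys) == 0:
        return 0, 0
    i = np.random.randint(len(ys))
    x = max(0, int(xs[i]) - crop_w // 2)
    y = max(0, int(ys[i]) - crop_h // 2)
    return x, y

In typical reproductions, safe_crop then zero-pads the patch when it runs past the image border and resizes crops larger than 320x320 back down to the network's input size.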
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8)
valid_dataset = DIMDataset('valid')
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8)
for epoch in range(start_epoch, end_epoch):
train_loss = train()
- model.train()
- for i, (img, alpha_label) in enumerate(train_loader):
alpha_out = model(img)
loss = alpha_prediction_loss(alpha_out, alpha_label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
...
valid_loss = valid()
save_checkpoint()
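alpha_prediction_loss is the paper's alpha-prediction loss, sqrt((alpha_p - alpha_g)^2 + eps^2), averaged only over the unknown region, which is exactly why the dataset packs the mask into channel 1 of the label. A minimal sketch, assuming y_pred has shape [N, 320, 320] (after the squeeze in the forward pass) and y_true is the [N, 2, 320, 320] label built above:

import torch

def alpha_prediction_loss(y_pred, y_true, epsilon=1e-6):
    # channel 0: ground-truth alpha in [0, 1]; channel 1: unknown-region mask
    alpha_gt = y_true[:, 0, :, :]
    mask = y_true[:, 1, :, :]
    diff = (y_pred - alpha_gt) * mask  # zero out pixels outside the unknown region
    # Charbonnier-style penalty, normalized by the number of unknown pixels
    return torch.sqrt(diff ** 2 + epsilon ** 2).sum() / (mask.sum() + epsilon)

The epsilon inside the square root keeps the gradient finite at zero error, and normalizing by mask.sum() rather than the full image keeps the loss scale independent of how wide the unknown band happens to be.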