【Pytorch项目实战】之生成式模型:DeepDream、风格迁移、图像修复

news2024/9/19 11:02:40

文章目录

  • 生成式模型
    • (算法一)深度梦境(DeepDream)
    • (算法二)风格迁移(Style Transfer)
    • (算法三)图像修复(Image Inpainting)
    • (一)实战:基于VGG19的深度梦境(DeepDream)
    • (二)实战:基于VGG19的风格迁移(Style Transfer)
    • (三)实战:基于上下文解码器(Context Encoders)的图像修复(Image Inpainting)

生成式模型

(算法一)深度梦境(DeepDream)

DeepDearm 模型在2015年由谷歌提出,理论基础是2013年所提出的《Visualizing and Understanding Convolutional Neural Networks》
具体方式: 使用梯度上升的方法可视化网络每一层的特征

  • (1)用一张噪声图像输入网络,但反向更新的时候不更新网络权重,而是更新初始图像的像素值,以这种" 训练图像 "的方式可视化网络。此外输入图像也可以是一些正常的图片,这样的话就是生成背景图像之类的。
  • (2)为了提高训练质量,采用高斯模糊以使图像更平滑,并使用多尺度(又称为八度)的图像进行计算:先连续缩小输入图像,然后,再逐步放大,并将结果合并为一个图像输出。
     

如何放大图像特征? 现有一个猫狗分类网络模型,当输入一张云的图像进行判断时,假设这朵云比较像狗,则机器提取的特征也会偏向于狗的特征。假设特征对应的概率分别为:[狗,猫] = [x1,x2] = [0.6,0.4],那么采用L2范数(L2 = x1 ^ 2 + x2 ^ 2)可以很好达到放大特征的效果,最终图像越来越像狗。

在这里插入图片描述
先对图像连续做二次等比例缩小,该比例是1.5,之所以要缩小,图像缩小是为了让图像的像素点调整后所得结果图案能显示的更加平滑,过程主要是抑制了图像的高频成分,放大了低频成分。缩小二次后,把图像每个像素点当作参数,对它们求偏导,这样就可以知道如何调整图像像素点能够对给定网络层的输出产生最大化的刺激。机器学习:DeepDearm模型(与书上完全相同)

(算法二)风格迁移(Style Transfer)

2015年,德国Gatys提出基于神经网络的图像风格迁移。
输入数据包括:一张代表内容的图像(上海外滩图)一张代表风格的图像(梵高的星空图)

  • 主要原理:将风格图像的艺术风格应用于内容图像,同时保留内容图像的内容(比如:人物、景物等)。
     
  • 核心思想:定义损失函数,包括内容损失和风格损失。(其中,风格损失权重系数远大于内容损失,达到风格迁移的效果)
  • 内容损失:卷积网络不同层学到的图像特征是不一样的。研究发现,使用靠近底层但不能靠太近的层来衡量图像内容比较理想。
    (1)靠近底层的卷积层(输入端):学到的是图像局部特征。如:位置、形状、颜色、纹理等。
    (2)靠近顶部的卷积层(输出端):学到的图像特征更全面、更抽象,但也会丢失图像的详细信息。
    卷积神经网络图像风格转移(演示)在这里插入图片描述
  • 风格损失:采用基于通道的格拉姆矩阵(Gram Matrix)来衡量风格。内积可以理解为该层特征之间相互关系的映射,映射关系反映了图像的纹理统计规律。
    计算过程:计算向量与其转置向量的内积,从而得到该向量的格拉姆矩阵。
    矩阵特点:对称。一个n维的向量可以得到n ∗ n维的格拉姆矩阵,其中每一个元素都可以表示为特征 i 与特征 j 的相关性,而矩阵对角线上的元素可以表示为某个特征 i 在整个图像中的强度。
    格拉姆矩阵(Gram matrix)详细解读
    在这里插入图片描述

(算法三)图像修复(Image Inpainting)

本文主要介绍一种基于上下文编码器(Context Encoders)的图像修复方法,上下文编码器主要构成:编码器解码器。上下文编码器的特点:根据上下文预测的原理。

网络架构:
在这里插入图片描述

  • (1)编码器与解码器之间不是采样全连接层,而是采样Channel-Wise Fully-Connected Layer,可以极大地降低参数量。
  • (2)采用对抗判定器,用来区分预测值与真实值。
  • (3)解码器。基于 AlexNet 网络:5个卷积+池化。主要通过一系列操作,使其恢复到与原图一样的大小。 如果输入图像大小为227 x 227,可以得到一个6 x 6 6 256的特征图。
  • (4)损失函数:整个模型的损失值由重构损失(Reconstruction Loss)与对抗损失(Adversarial Loss)组成。
    在这里插入图片描述

(一)实战:基于VGG19的深度梦境(DeepDream)

在这里插入图片描述

######################################################################
# 取VGG19模型为预训练模型,将获取的特征最大化之后展示在一张普通的图像上(梵高的星空图)。
# 为了训练更加有效,还使用对图像进行不同大小的缩放处理。
######################################################################
import torch
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageFilter, ImageChops
from torchvision import models
from torchvision import transforms

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'         # "OMP: Error #15: Initializing libiomp5md.dll"
######################################################################
def deprocess(image, device):
    """反归一化:在图像处理过程中有归一化的操作,所以要"反归一化"。"""
    image = image * torch.tensor([0.229, 0.224, 0.225], device=device) + torch.tensor([0.485, 0.456, 0.406], device=device)
    return image


def prod(image, feature_layers, iterations, lr, transform, device, vgg, modulelist):
    """主要功能:传入输入图像,正向传播到VGG19的指定层(如:第8层、32层等),然后,用梯度上升更新输入图像的特征值。"""
    # 传入输入图像,正 向传播到VGG19的指定层,然后,用梯度上升更新 输入图像的特征值。
    input = transform(image).unsqueeze(0)           # 对图像进行resize,转成tensor和归一化操作,要增加一个维度,表示一个样本,[1, C, H, W]
    input = input.to(device).requires_grad_(True)   # 对图片进行追踪计算梯度
    vgg.zero_grad()  # 梯度清零
    for i in range(iterations):
        out = input
        for j in range(feature_layers):     # 遍历features模块的各层
            out = modulelist[j + 1](out)    # 以上一层的输出特征作为下一层的输入特征
        loss = out.norm()                   # 计算特征的二范数
        loss.backward()                     # 反向传播计算梯度,其中图像的每个像素点都是参数
        with torch.no_grad():
            input += lr * input.grad        # 更新原始图像的像素值
    input = input.squeeze()                 # 训练完成后将表示样本数的维度去除
    # 交互维度
    # input = input.transpose(0, 1)
    # input = input.transpose(1, 2)
    input = input.permute(1, 2, 0)          # 维度转换,因为tensor的维度是(C, H, W),而array是(H, W, C)
    input = np.clip(deprocess(input, device).detach().cpu().numpy(), 0, 1)      # 将像素值限制在(0, 1)之间
    image = Image.fromarray(np.uint8(input * 255))      # 将array类型的图像转成PIL类型图像,要乘以255是因为转成tensor时函数自动除以了255
    return image


def deep_dream_vgg(image, feature_layers, iterations, lr, transform, device, vgg, modulelist, octave_scale=2, num_octaves=20):
    """递归函数,多次缩小图像,然后调用函数prod。接着再放大结果,并与按一定比例图像混合在一起,最终得到与输入图像相同大小的输出图像。"""
    # (1)octave_scale参数决定了有多少个尺度的图像, num_octaves参数决定一共有多少张图像
    # (2)octave_scale和num_octaves两个参数的选定对生成图像的影响很大。
    if num_octaves > 0:
        image1 = image.filter(ImageFilter.GaussianBlur(2))  # 高斯模糊
        if (image1.size[0] / octave_scale < 1 or image1.size[1] / octave_scale < 1):  # 当图像的大小小于octave_scale时图像尺度不再变化
            size = image1.size
        else:
            size = (int(image1.size[0] / octave_scale), int(image1.size[1] / octave_scale))
        image1 = image1.resize(size, Image.ANTIALIAS)       # 连续缩小图片
        """递归"""
        image1 = deep_dream_vgg(image1, feature_layers, iterations, lr, transform, device, vgg, modulelist, octave_scale, num_octaves - 1)
        size = (image.size[0], image.size[1])

        image1 = image1.resize(size, Image.ANTIALIAS)       # 放大图像
        image = ImageChops.blend(image, image1, 0.6)        # 按一定比例将图像混合在一起
        # PIL.ImageChops.blend(image1, image2, alpha)
        # out = image1 * (1.0 - alpha) + image2 * alpha
    """调用"""
    img_result = prod(image, feature_layers, iterations, lr, transform, device, vgg, modulelist)
    img_result = img_result.resize(image.size)
    return img_result


if __name__ == '__main__':
    # (1)图像预处理(裁剪、格式转换、标准化)
    tranform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(),
                                   transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    #############################################################################
    # VGG19预模型:基于ImageNet大数据集训练的模型,该数据集共有1000个类别。
    #       包括三种不同的模块:(1)特征提取模块(features) 共有36层;(2)池化层(avgpool) 只有一层;(3)分类层(classifier) 共有6层。
    #       备注:越靠近顶部的层,其激活值表现就越全面或抽象,如像某些类别(比如狗)的图案。
    #############################################################################
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    vgg = models.vgg19(pretrained=True).to(device)      # 下载预训练模型
    print(vgg)                                          # 打印网络层
    modulelist = list(vgg.features.modules())           # 网络层转成列表结构:下标0表示全部网络层,下标从1开始为迭代网络层
    # (2)开始训练
    night_sky = Image.open(r'starry_night.jpg')         # 加载图像
    night_sky_04 = deep_dream_vgg(image=night_sky, feature_layers=4, iterations=6, lr=0.2,
                                  transform=tranform, device=device, vgg=vgg, modulelist=modulelist,
                                  octave_scale=2, num_octaves=6)
    night_sky_08 = deep_dream_vgg(image=night_sky, feature_layers=8, iterations=6, lr=0.2,
                                  transform=tranform, device=device, vgg=vgg, modulelist=modulelist,
                                  octave_scale=2, num_octaves=6)
    night_sky_16 = deep_dream_vgg(image=night_sky, feature_layers=16, iterations=6, lr=0.2,
                                  transform=tranform, device=device, vgg=vgg, modulelist=modulelist,
                                  octave_scale=2, num_octaves=6)
    night_sky_20 = deep_dream_vgg(image=night_sky, feature_layers=20, iterations=6, lr=0.2,
                                  transform=tranform, device=device, vgg=vgg, modulelist=modulelist,
                                  octave_scale=2, num_octaves=6)
    night_sky_24 = deep_dream_vgg(image=night_sky, feature_layers=24, iterations=6, lr=0.2,
                                  transform=tranform, device=device, vgg=vgg, modulelist=modulelist,
                                  octave_scale=2, num_octaves=6)
    night_sky_28 = deep_dream_vgg(image=night_sky, feature_layers=28, iterations=6, lr=0.2,
                                  transform=tranform, device=device, vgg=vgg, modulelist=modulelist,
                                  octave_scale=2, num_octaves=6)
    night_sky_32 = deep_dream_vgg(image=night_sky, feature_layers=32, iterations=6, lr=0.2,
                                  transform=tranform, device=device, vgg=vgg, modulelist=modulelist,
                                  octave_scale=2, num_octaves=6)
    plt.subplot(241), plt.imshow(night_sky, 'gray'), plt.title('night_sky')
    plt.subplot(242), plt.imshow(night_sky_04, 'gray'), plt.title('night_sky_04')
    plt.subplot(243), plt.imshow(night_sky_08, 'gray'), plt.title('night_sky_08')
    plt.subplot(244), plt.imshow(night_sky_16, 'gray'), plt.title('night_sky_16')
    plt.subplot(245), plt.imshow(night_sky_20, 'gray'), plt.title('night_sky_20')
    plt.subplot(246), plt.imshow(night_sky_24, 'gray'), plt.title('night_sky_24')
    plt.subplot(247), plt.imshow(night_sky_28, 'gray'), plt.title('night_sky_28')
    plt.subplot(248), plt.imshow(night_sky_32, 'gray'), plt.title('night_sky_32')
    plt.show()

(二)实战:基于VGG19的风格迁移(Style Transfer)

在这里插入图片描述

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision.models as models
import copy

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'  # "OMP: Error #15: Initializing libiomp5md.dll"
###################################################################################
def image_loader(image_name):
    """加载图像"""
    image = Image.open(image_name)          # 打开图像
    image = loader(image).unsqueeze(0)      # 增加一个维度,其值为1
    return image.to(device, torch.float)


def Image_trans(tensor):
    """格式转换"""
    image = tensor.cpu().clone()            # 为避免因image修改影响tensor的值,这里采用clone
    image = image.squeeze(0)                # 去掉一个维度
    unloader = transforms.ToPILImage()      # reconvert into PIL image
    image = unloader(image)
    return image


class ContentLoss(nn.Module):
    """内容损失函数"""
    def __init__(self, target, ):
        super(ContentLoss, self).__init__()
        # 必须要用detach来分离出target,这时候target不再是一个Variable
        # 这是为了动态计算梯度,否则forward会出错,不能向前传播
        self.target = target.detach()

    def forward(self, input):
        self.loss = F.mse_loss(input, self.target)
        return input


def gram_matrix(input):
    """格拉姆矩阵"""
    a, b, c, d = input.size()               # a is batch size. b is number of channels. c is height and d is width.
    features = input.view(a * b, c * d)     # x矩阵
    G = torch.mm(features, features.t())    # 计算内积(矩阵 * 转置矩阵)
    return G.div(a * b * c * d)             # 除法(格拉姆矩阵 - 归一化处理)


class StyleLoss(nn.Module):
    """风格损失函数"""
    def __init__(self, target_feature):
        super(StyleLoss, self).__init__()
        self.target = gram_matrix(target_feature).detach()

    def forward(self, input):
        G = gram_matrix(input)
        self.loss = F.mse_loss(G, self.target)
        return input


class Normalization(nn.Module):
    """标准化处理"""
    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        # .view the mean and std to make them [C x 1 x 1] so that they can directly work with image Tensor of shape [B x C x H x W].
        # B is batch size. C is number of channels. H is height and W is width.
        self.mean = mean.clone().detach().view(-1, 1, 1)    # self.mean = torch.tensor(mean).view(-1, 1, 1)
        self.std = std.clone().detach().view(-1, 1, 1)      # self.std = torch.tensor(std).view(-1, 1, 1)

    def forward(self, img):
        return (img - self.mean) / self.std


###################################################################################
# 卷积网络不同层学到的图像特征是不一样的。研究发现,使用靠近底层但不能太靠近的层来衡量图像内容比较理想。
#       (1)靠近底层的卷积层(输入端):学到的是图像局部特征。如:位置、形状、颜色、纹理等。
#       (2)靠近顶部的卷积层(输出端):学到的图像特征更全面、更抽象,但也会丢失图像详细信息。
###################################################################################
def get_style_model_and_losses(cnn,                                     # VGG19网络主要用来做内容识别
                               normalization_mean, normalization_std, style_img, content_img,
                               content_layers=['conv_4'],               # 为计算内容损失和风格损失,指定使用的卷积层
                               # 研究发现:使用前三层已经能够达到比较好的内容重建工作,而后两层保留一些比较高层的特征。
                               style_layers=['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']):
    cnn = copy.deepcopy(cnn)                    # 深复制
    normalization = Normalization(normalization_mean, normalization_std).to(device)  # 标准化
    content_losses = []                         # 初始化(内容)损失值
    style_losses = []                           # 初始化(风格)损失值
    model = nn.Sequential(normalization)        # 使用sequential方法构建模型

    i = 0                                       # 每次迭代增加1
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):        # isinstance(object, classinfo):判断两个对象类型是否相同。
            i += 1
            name = 'conv_{}'.format(i)
        elif isinstance(layer, nn.ReLU):
            name = 'relu_{}'.format(i)
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_{}'.format(i)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(i)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))
        model.add_module(name, layer)          # 添加指定的网络层(conv、relu、pool、bn)到模型中

        if name in content_layers:                          # 累加内容损失
            target = model(content_img).detach()            # 前向传播
            content_loss = ContentLoss(target)              # 内容损失
            model.add_module("content_loss_{}".format(i), content_loss)
            content_losses.append(content_loss)

        if name in style_layers:                            # 累加风格损失
            target_feature = model(style_img).detach()      # 前向传播
            style_loss = StyleLoss(target_feature)          # 风格损失
            model.add_module("style_loss_{}".format(i), style_loss)
            style_losses.append(style_loss)

    # 我们需要对在内容损失和风格损失之后的层进行修剪
    for i in range(len(model) - 1, -1, -1):
        if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
            break
    model = model[:(i + 1)]
    return model, style_losses, content_losses


def run_style_transfer(cnn, normalization_mean, normalization_std, content_img, style_img, input_img, num_steps=600, style_weight=1000000, content_weight=1):
    """风格迁移"""
    model, style_losses, content_losses = get_style_model_and_losses(cnn, normalization_mean, normalization_std, style_img, content_img)
    optimizer = optim.LBFGS([input_img.requires_grad_()])       # requires_grad_():需要对输入图像进行梯度计算。采用LBFGS优化方法
    run = [0]
    while run[0] <= num_steps:                                  # 批次数
        def closure():
            input_img.data.clamp_(0, 1)                         # 将输入张量的每个元素收紧到区间内,并返回结果到一个新张量。
            optimizer.zero_grad()                               # 梯度清零
            model(input_img)                                    # 前向传播

            # 计算当前批次的损失
            style_score = 0
            content_score = 0
            for sl in style_losses:
                style_score += sl.loss                          # (叠加)风格得分
            for cl in content_losses:
                content_score += cl.loss                        # (叠加)内容得分
            style_score *= style_weight                         # 风格权重系数:1000000
            content_score *= content_weight                     # 内容权重系数:1
            loss = style_score + content_score                  # 总损失
            loss.backward()                                     # 反向传播

            # 打印损失值
            run[0] += 1
            if run[0] % 50 == 0:
                print("run {}:".format(run))
                print('Style Loss : {:4f} Content Loss: {:4f}'.format(style_score.item(), content_score.item()))
                print()
            return style_score + content_score

        optimizer.step(closure)                                 # 梯度更新
    input_img.data.clamp_(0, 1)
    return input_img


###################################################################################
# (1)图像预处理
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")                   # 可用设备
# torchvision.transforms.Normalize		# Normalize的作用是用均值和标准差对Tensor进行标准化处理。
loader = transforms.Compose([transforms.Resize((512, 600)), transforms.ToTensor()])     # 数据预处理
style_img = image_loader(r"./pytorch-12/油画.jpg")                                      # 风格图像
content_img = image_loader(r"./pytorch-12/大黄蜂.jpg")                                  # 内容图像
print("style size:", style_img.size())
print("content size:", content_img.size())
assert style_img.size() == content_img.size(), "we need to import style and content images of the same size"
###################################################################################
# (2)风格迁移
cnn = models.vgg19(pretrained=True).features.to(device).eval()                          # 下载预训练模型
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)                 # 标准化(均值)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)                  # 标准化(标准差)
input_img = content_img.clone()                                                         # 复制图像
# input_img = torch.randn(content_img.data.size(), device=device)       				# 随机添加白噪声
output_img = run_style_transfer(cnn=cnn, normalization_mean=cnn_normalization_mean,
                                normalization_std=cnn_normalization_std,
                                content_img=content_img, style_img=style_img, input_img=input_img,
                                num_steps=500, style_weight=1000000, content_weight=1)
###################################################################################
# (3)画图
style_img = Image_trans(style_img)          # 格式转换
content_img = Image_trans(content_img)      # 格式转换
output_img = Image_trans(output_img)        # 格式转换
plt.subplot(131), plt.imshow(style_img, 'gray'),    plt.title('style_img')
plt.subplot(132), plt.imshow(content_img, 'gray'),  plt.title('content_img')
plt.subplot(133), plt.imshow(output_img, 'gray'),   plt.title('style_img + content_img')
plt.show()

(三)实战:基于上下文解码器(Context Encoders)的图像修复(Image Inpainting)

在这里插入图片描述

from __future__ import print_function
import argparse
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torchvision.models as models

from PIL import Image
import matplotlib.image as mpimg
import matplotlib.pyplot as plt

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'  # "OMP: Error #15: Initializing libiomp5md.dll"
##########################################################################################################
parser = argparse.ArgumentParser(description='Process some params')
parser.add_argument('--dataset',  default='streetview', help='cifar10 | lsun | imagenet | folder | lfw ')
parser.add_argument('--test_image', default='pytorch-12/油画.jpg', required=False, help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imageSize', type=int, default=128, help='the height / width of the input image to network')
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--nc', type=int, default=3)
parser.add_argument('--niter', type=int, default=50, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--netG', default='pytorch_12_xf/model/netG_streetview.pth', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')
parser.add_argument('--manualSeed', type=int, help='manual seed')

parser.add_argument('--nBottleneck', type=int, default=4000,help='of dim for bottleneck of encoder')
parser.add_argument('--overlapPred', type=int, default=4,help='overlapping edges')
parser.add_argument('--nef', type=int, default=64, help='of encoder filters in first conv layer')
parser.add_argument('--wtl2', type=float, default=0.999, help='0 means do not use else use with this weight')
opt = parser.parse_args(['--dataset', 'streetview'])
print(opt)


##########################################################################################################
def load_image(filename, size=None, scale=None):
    """加载图像"""
    img = Image.open(filename)
    if size is not None:
        img = img.resize((size, size), Image.ANTIALIAS)
    elif scale is not None:
        img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)), Image.ANTIALIAS)
    return img


def save_image(filename, data):
    """保存图像"""
    img = data.clone().add(1).div(2).mul(255).clamp(0, 255).cpu().numpy()
    img = img.transpose(1, 2, 0).astype("uint8")
    img = Image.fromarray(img)
    img.save(filename)


class netG(nn.Module):
    """测试-网络模型"""
    def __init__(self, opt):
        super(netG, self).__init__()
        self.ngpu = opt.ngpu                # ngpu表示gpu个数,如果大于1,将使用并发处理
        self.main = nn.Sequential(          # 输入通道数opt.nc, 输出通道数为opt.nef, opt.ngf为该层输出通道数
            nn.Conv2d(opt.nc,opt.nef,4,2,1, bias=False),        nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(opt.nef,opt.nef,4,2,1, bias=False),       nn.BatchNorm2d(opt.nef),            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(opt.nef,opt.nef*2,4,2,1, bias=False),     nn.BatchNorm2d(opt.nef*2),          nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(opt.nef*2,opt.nef*4,4,2,1, bias=False),   nn.BatchNorm2d(opt.nef*4),          nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(opt.nef*4,opt.nef*8,4,2,1, bias=False),   nn.BatchNorm2d(opt.nef*8),          nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(opt.nef*8,opt.nBottleneck,4, bias=False), nn.BatchNorm2d(opt.nBottleneck),    nn.LeakyReLU(0.2, inplace=True),
            # 后面采用转置卷积
            nn.ConvTranspose2d(opt.nBottleneck, opt.ngf * 8, 4, 1, 0, bias=False),  nn.BatchNorm2d(opt.ngf * 8), nn.ReLU(True),
            nn.ConvTranspose2d(opt.ngf * 8, opt.ngf * 4, 4, 2, 1, bias=False),      nn.BatchNorm2d(opt.ngf * 4), nn.ReLU(True),
            nn.ConvTranspose2d(opt.ngf * 4, opt.ngf * 2, 4, 2, 1, bias=False),      nn.BatchNorm2d(opt.ngf * 2), nn.ReLU(True),
            nn.ConvTranspose2d(opt.ngf * 2, opt.ngf, 4, 2, 1, bias=False),          nn.BatchNorm2d(opt.ngf),     nn.ReLU(True),
            nn.ConvTranspose2d(opt.ngf, opt.nc, 4, 2, 1, bias=False), nn.Tanh())

    def forward(self, input):
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output


##########################################################################################################
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")         # 可选设备
# 加载预训练模型实现图像修复(该模型基于大量街道数据训练得到)
# checkpoint = torch.load('checkpoint.pth')				# 加载已经训练好的权重参数
netG = netG(opt)                                        # 模型实例化
# netG.load_state_dict(checkpoint['state_dict'])		    # 加载该模型的权重参数
netG.eval()                                             # 验证模型

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
image = load_image(opt.test_image, opt.imageSize)       # 加载测试图像
image = transform(image)                                # 图像预处理
image = image.repeat(1, 1, 1, 1)

criterionMSE = nn.MSELoss()                                                     # 定义MSE损失函数
input_real = torch.FloatTensor(1, 3, opt.imageSize, opt.imageSize)              # 真实
input_cropped = torch.FloatTensor(1, 3, opt.imageSize, opt.imageSize)           # 裁剪
real_center = torch.FloatTensor(1, 3, opt.imageSize//2, opt.imageSize//2)       # 中心
# 将数据加载到设备中
netG.to(device)
criterionMSE.to(device)
input_real, input_cropped, real_center = input_real.to(device), input_cropped.to(device), real_center.to(device)

input_cropped.data.resize_(image.size()).copy_(image)
real_center_cpu = image[:, :, opt.imageSize//4:opt.imageSize//4+opt.imageSize//2, opt.imageSize//4:opt.imageSize//4+opt.imageSize//2]
real_center.data.resize_(real_center_cpu.size()).copy_(real_center_cpu)
input_cropped.data[:, 0, opt.imageSize//4+opt.overlapPred:opt.imageSize//4+opt.imageSize//2-opt.overlapPred, opt.imageSize//4+opt.overlapPred:opt.imageSize//4+opt.imageSize//2-opt.overlapPred] = 2*117.0/255.0 - 1.0
input_cropped.data[:, 1, opt.imageSize//4+opt.overlapPred:opt.imageSize//4+opt.imageSize//2-opt.overlapPred, opt.imageSize//4+opt.overlapPred:opt.imageSize//4+opt.imageSize//2-opt.overlapPred] = 2*104.0/255.0 - 1.0
input_cropped.data[:, 2, opt.imageSize//4+opt.overlapPred:opt.imageSize//4+opt.imageSize//2-opt.overlapPred, opt.imageSize//4+opt.overlapPred:opt.imageSize//4+opt.imageSize//2-opt.overlapPred] = 2*123.0/255.0 - 1.0
# 开始测试
fake = netG(input_cropped)                      # 前向传播
errG = criterionMSE(fake, real_center)          # 损失函数
recon_image = input_cropped.clone()             # 复制图像
recon_image.data[:, :, opt.imageSize//4:opt.imageSize//4+opt.imageSize//2, opt.imageSize//4:opt.imageSize//4+opt.imageSize//2] = fake.data
##########################################################################################################
# 保存图像
save_image('val_real_samples.png', image[0])                        # 真实图像
save_image('val_cropped_samples.png', input_cropped.data[0])        # 裁剪图像
save_image('val_recon_samples.png', recon_image.data[0])            # 修复图像
print('%.4f' % errG.item())
# 加载图像
val_real_samples = 'val_real_samples.png'
val_recon_samples = 'val_recon_samples.png'
val_cropped_samples = 'val_cropped_samples.png'
# reconsPath = 'pytorch-12/油画.jpg'
val_real_samples = mpimg.imread(val_real_samples)
val_cropped_samples = mpimg.imread(val_cropped_samples)
val_recon_samples = mpimg.imread(val_recon_samples)
# 画图
plt.subplot(131), plt.imshow(val_real_samples, 'gray'),         plt.title('real_samples')
plt.subplot(132), plt.imshow(val_cropped_samples, 'gray'),      plt.title('cropped_samples')
plt.subplot(133), plt.imshow(val_recon_samples, 'gray'),        plt.title('recon_samples')
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/183036.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

(13)工业界推荐系统-小红书推荐场景及内部实践【用户行为序列建模】

&#xff08;1&#xff09;工业界推荐系统-小红书推荐场景及内部实践【业务指标、链路、ItemCF】 &#xff08;2&#xff09;工业界推荐系统-小红书推荐场景及内部实践【UserCF、离线特征处理】 &#xff08;3&#xff09;工业界推荐系统-小红书推荐场景及内部实践【矩阵补充、…

Docker搭建LNMP+Wordpress

一、服务器环境 容器操作系统IP地址主要软件nginxCentOS 7172.18.0.10Docker-NginxmysqlCentOS 7172.18.0.20Docker-MysqlmysqlCentOS 7172.18.0.20Docker-Mysql 二、Linux系统基础镜像 systemctl stop firewalld setenforce 0 docker pull centos:7 #从公有仓库中下载cento…

cubeIDE开发, stm32人工智能开发应用实践(Cube.AI).篇三

一、cube.AI实际项目应用 接篇二&#xff0c;前文都是采用FP-AI-SENSING1案例和配套的B-L475E-IOT01A开发板来阐述的&#xff0c;而实际项目中&#xff0c;我们都是基于自身项目硬件平台来训练模型及部署模型的&#xff0c;我们仅仅需要cube.AI软件包&#xff08;作为可调用库&…

技术大佬说我对「压测目标」的分析不够细

前言 前面总结压测类型的时候有简单描述了不同压测类型的从准备-脚本设计-压测的整体过程&#xff0c;但是对于压测对象没有更深入的进行分析总结&#xff0c;导致在压测执行结束后&#xff0c;出现压测结果不准确的情况。所以这边就压测的对象进行单独的总结分析。 在执行压测…

lego-loam学习笔记(三)

前言&#xff1a; 对于lego-loam中点云聚类源码的学习&#xff0c;它使用了广度优先算法&#xff0c;并且使用了数组双指针技巧。 主要分为两个部分&#xff1a; 第一个是labelComponents函数&#xff0c;它的功能是为每个点及其相邻的4个点运算角度&#xff0c;在对角度小于…

微信小程序开发

微信小程序开发 | 前言&#xff1a;本文章中的很大一部分内容的图片&#xff0c;文字信息来源于微信小程序官方文档和网络资源&#xff0c;感谢大家的支持&#xff0c;如文章中有不足和错误的地方&#xff0c;请及时联系作者-白泽。并协同修改&#xff0c;相信大家的帮助会使这…

屏蔽360阻止远程执行变更注册表自启动数据的办法

屏蔽360阻止远程执行变更注册表自启动数据的办法 运程服务器上的程序&#xff0c;由于需要。我在服务器中&#xff0c;加入更新升级自身&#xff08;exe&#xff09;文件&#xff0c;并变更操作系统自启动数据的代码。 实践证明&#xff0c;通过客户端&#xff0c;调用运程服务…

spring 声明式事务 @Transactional 运行原理

注意&#xff1a;如果想要理解spring 的声明式事务&#xff0c;必须先理解AOP 的原理。 一、spring注册 InfrastructureAdvisorAutoProxyCreator 通过 EnableTransactionManagement 可以看到先把TransactionManagementConfigurationSelector通过Import注册到spring。同时注意…

VULNCMS靶机

环境准备 靶机链接&#xff1a;百度网盘 请输入提取码 提取码&#xff1a;i3j0 虚拟机网络链接模式&#xff1a;桥接模式 攻击机系统&#xff1a;kali linux 2022.03 信息收集 1.查看靶机ip地址 2.探测目标靶机开放端口和服务情况。 nmap -p- -sV -A 192.168.1.108 漏洞…

嵌入式串行接口标准

在嵌入式系统中&#xff0c;经常使用UART接口实现通讯、调试日志数据等功能&#xff0c;但UART是一种异步通信协议&#xff0c;并未定义物理层的电气接口标准。 在板件通信时&#xff0c;UART接口之间通常基于IO直接连接进行通信&#xff08;TTL/CMOS电平标准&#xff0c;3.3V电…

梦熊杯-十二月月赛-白银组题解-B.契约

B. Problem B.契约&#xff08;contract.cpp&#xff09; 内存限制&#xff1a;256 MiB 时间限制&#xff1a;1000 ms 标准输入输出 题目类型&#xff1a;传统 评测方式&#xff1a;文本比较 题目描述&#xff1a; 「璃月」是「契约」的国度。 摩拉克斯认为&#xff0c…

Lua 字符串

Lua 字符串 参考至菜鸟教程。 字符串或串(String)是由数字、字母、下划线组成的一串字符。 Lua 语言中字符串可以使用以下三种方式来表示&#xff1a; 单引号间的一串字符。双引号间的一串字符。[[ 与 ]] 间的一串字符。 以上三种方式的字符串实例如下&#xff1a; string1 …

基于语义分割Ground Truth(GT)转换yolov5目标检测标签(路面积水检测例子)

基于语义分割Ground Truth&#xff08;GT&#xff09;转换yolov5目标检测标签&#xff08;路面积水检测例子&#xff09; 概述 许多目标检测的数据是通过直接标注或者公开平台获得&#xff0c;如果存在语义分割Ground Truth的标签文件&#xff0c;怎么样实现yolov5的目标检测…

【图论】求欧拉回路

前言 你的qq密码是否在圆周率中出现&#xff1f; 一个有意思的编码问题&#xff1a;假设密码是固定位数&#xff0c;设有nnn位&#xff0c;每位是数字0-9&#xff0c;那么这样最短的“圆周率”的长度是多少&#xff1f;或者说求一个最短的数字串定包含所有密码。 理论 一些…

acwing1264_动态求连续区间和

目录 算法分类&#xff1a; 问题描述 算法适用题目范围&#xff1a; 实现代码&#xff1a; 算法分类&#xff1a; 树状数组/线段树 问题描述 给定 n个数组成的一个数列&#xff0c;规定有两种操作&#xff0c;一是修改某个元素&#xff0c;二是求子数列 [a,b]的连续和。 …

1602_MIT 6.828试验环境搭建

全部学习汇总&#xff1a; GreyZhang/g_unix: some basic learning about unix operating system. (github.com) 最近尝试看一下MIT的操作系统教程&#xff0c;找到了一个6.828的课程。看了一下网络上的介绍&#xff0c;看起来这个大家的认可度还是很高的。开动之前&#xff0c…

Android面经_111道安卓基础问题(四大组件BroadCast、内容提供者篇)

该文章涉及的内容主要是&#xff1a;BroadCast、内容提供者&#xff1b; Android基础问题——四大组件之BroadCast、ContentProvider 内容提供者1、BroadCast1.1、Android的广播分类1.2、Android的广播注册方式1.3、广播作用域2、内容提供者Content provider2.1、什么是内容提供…

Google Protobuf 实践使用开发

Android 敏捷开发助手 Lottie动画 轻松使用PNG、JPG等普通图片高保真转SVG图Android 完美的蒙层方案Android MMKV框架引入使用强大无匹的自定义下拉列表Google Protobuf 实践使用开发 Protobuf 实践使用前言Protobuf基本介绍Protobuf 使用配置protobuf 基本语法1. 基本使用2. …

JavaWeb-Ajax

JavaWeb-Ajax 3&#xff0c;Ajax 3.1 概述 AJAX (Asynchronous JavaScript And XML)&#xff1a;异步的 JavaScript 和 XML。 我们先来说概念中的 JavaScript 和 XML&#xff0c;JavaScript 表明该技术和前端相关&#xff1b;XML 是指以此进行数据交换。 3.1.1 作用 AJAX…

用Python绘制傅里叶级数和泰勒级数逼近已知函数的动态过程

文章目录Taylor级数Fourier级数本文代码&#xff1a; Fourier级数和Taylor级数对原函数的逼近动画Taylor级数 级数是对已知函数的一种逼近&#xff0c;比较容易理解的是Taylor级数&#xff0c;通过多项式来逼近有限区间内的函数&#xff0c;其一般形式为 f(x)∑n0Nanxnf(x)\su…