2-2 A pretrained model that fakes it until it makes it

news2025/1/10 1:22:51

1. The GAN game
Let’s suppose, for a moment, that we’re career criminals who want to move into selling forgeries of “lost” paintings by famous artists. We’re criminals, not painters, so as we paint our fake Rembrandts and Picassos, it quickly becomes apparent that they’re amateur imitations rather than the real deal. Even if we spend a bunch of time practicing until we get a canvas that we can’t tell is fake, trying to pass it off at the local art auction house is going to get us kicked out instantly. Even worse, being told “This is clearly fake; get out,” doesn’t help us improve! We’d have to randomly try a bunch of things, gauge which ones took slightly longer to recognize as forgeries, and emphasize those traits on our future attempts, which would take far too long.

Instead, we need to find an art historian of questionable moral standing to inspect our work and tell us exactly what it was that tipped them off that the painting wasn’t legit. With that feedback, we can improve our output in clear, directed ways, until our sketchy scholar can no longer tell our paintings from the real thing.

In the context of deep learning, what we’ve just described is known as the GAN game, where two networks, one acting as the painter and the other as the art historian, compete to outsmart each other at creating and detecting forgeries. GAN stands for generative adversarial network, where generative means something is being created (in this case, fake masterpieces), adversarial means the two networks are competing to outsmart the other, and well, network is pretty obvious. These networks are one of the most original outcomes of recent deep learning research.

[Figure 2.5]

Figure 2.5 shows a rough picture of what’s going on. The end goal for the generator is to fool the discriminator into mixing up real and fake images. The end goal for the discriminator is to find out when it’s being tricked, but it also helps inform the generator about the identifiable mistakes in the generated images. At the start, the generator produces confused, three-eyed monsters that look nothing like a Rembrandt portrait. The discriminator is easily able to distinguish the muddled messes from the real paintings. As training progresses, information flows back from the discriminator, and the generator uses it to improve. By the end of training, the generator is able to produce convincing fakes, and the discriminator no longer is able to tell which is which.
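The back-and-forth described above can be sketched as a single training step. This is a minimal illustration under assumptions that are not in the text: the "paintings" are toy 2-D points, both networks are tiny hypothetical MLPs, and the loss is binary cross-entropy on logits. It is not the setup used to train the model loaded later in this post.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup: "real paintings" are 2-D points clustered near (2, 2).
latent_dim, data_dim, batch = 4, 2, 16
generator = nn.Sequential(nn.Linear(latent_dim, 16), nn.ReLU(), nn.Linear(16, data_dim))
discriminator = nn.Sequential(nn.Linear(data_dim, 16), nn.ReLU(), nn.Linear(16, 1))

opt_g = torch.optim.Adam(generator.parameters(), lr=1e-3)
opt_d = torch.optim.Adam(discriminator.parameters(), lr=1e-3)
bce = nn.BCEWithLogitsLoss()


def train_step():
    real = torch.randn(batch, data_dim) * 0.1 + 2.0
    fake = generator(torch.randn(batch, latent_dim))

    # Discriminator (the art historian): real samples are labeled 1, fakes 0.
    # detach() keeps this update from reaching the generator's weights.
    d_loss = (bce(discriminator(real), torch.ones(batch, 1))
              + bce(discriminator(fake.detach()), torch.zeros(batch, 1)))
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()

    # Generator (the forger): it is rewarded when the discriminator says "real".
    # The gradient flows back through the discriminator into the generator.
    g_loss = bce(discriminator(fake), torch.ones(batch, 1))
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()
    return d_loss.item(), g_loss.item()


d_loss, g_loss = train_step()
print(f"discriminator loss {d_loss:.3f}, generator loss {g_loss:.3f}")
```

The generator never sees a real sample directly; everything it learns arrives through the discriminator's gradient, which is the "feedback from the art historian" in the analogy.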

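Since 2014, the introduction and refinement of generative adversarial networks (GANs) have made it possible to synthesize photorealistic images with deep neural networks (DNNs). The performance of these methods has improved steadily, and the synthesized images are now hard to tell from real ones with the naked eye, which has attracted broad attention from academia and industry.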

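For example, face age editing aims to modify the age attributes of a face image realistically and accurately (as shown in Figure 1). It plays a key role in scenarios close to everyday life, such as digital entertainment and public safety. FaceApp, a face-editing app popular worldwide, can adjust the apparent age of any face image; it is highly entertaining and sparked very wide discussion on social platforms such as Twitter and Facebook. As a result, the app topped the download charts on iOS and Android in 2019, and it is still used in a variety of digital content creation tasks. Beyond that, face age editing can predict the current appearance of missing children or wanted persons from old photos and improve the accuracy of cross-age face recognition systems, which helps solve cases more efficiently.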


2. CycleGAN
A CycleGAN can turn images of one domain into images of another domain (and back), without the need for us to explicitly provide matching pairs in the training set.

As the figure shows, the first generator learns to produce an image conforming to a target distribution (zebras, in this case) starting from an image belonging to a different distribution (horses), so that the discriminator can’t tell if the image produced from a horse photo is actually a genuine picture of a zebra or not.

At the same time—and here’s where the Cycle prefix in the acronym comes in—the resulting fake zebra is sent through a different generator going the other way, to be judged by another discriminator on the other side.


So, we have a CycleGAN workflow for the task of turning a photo of a horse into a zebra, and vice versa. Note that there are two separate generator networks, as well as two distinct discriminators.
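To make the two-generator, two-discriminator layout concrete, here is a rough sketch of the horse-to-zebra half of the loss. Several things are assumptions rather than anything stated in the text: the tiny stand-in networks, the binary cross-entropy adversarial loss, the L1 cycle-consistency term (the idea comes from the original CycleGAN formulation), and the weight `lambda_cycle = 10.0`, which is only illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


def tiny_generator():
    # Hypothetical stand-in for an image-to-image generator (same size in and out).
    return nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                         nn.Conv2d(8, 3, 3, padding=1), nn.Tanh())


def tiny_discriminator():
    # Hypothetical stand-in that outputs a grid of real/fake logits.
    return nn.Sequential(nn.Conv2d(3, 8, 4, stride=2, padding=1), nn.LeakyReLU(0.2),
                         nn.Conv2d(8, 1, 4, stride=2, padding=1))


g_h2z, g_z2h = tiny_generator(), tiny_generator()       # two generators
d_zebra, d_horse = tiny_discriminator(), tiny_discriminator()  # two discriminators

horse = torch.rand(2, 3, 32, 32) * 2 - 1   # fake batch of "horse photos" in [-1, 1]
fake_zebra = g_h2z(horse)                  # horse -> zebra
recovered_horse = g_z2h(fake_zebra)        # zebra -> back to horse (the "cycle")

# Adversarial term: the zebra discriminator should be fooled by the fake zebra.
zebra_score = d_zebra(fake_zebra)
adversarial_loss = F.binary_cross_entropy_with_logits(
    zebra_score, torch.ones_like(zebra_score))

# Cycle term: after the round trip we should get the original horse back.
cycle_loss = F.l1_loss(recovered_horse, horse)

lambda_cycle = 10.0  # illustrative weight, an assumption
generator_loss = adversarial_loss + lambda_cycle * cycle_loss
print(generator_loss.item())
```

The zebra-to-horse direction is symmetric: swap the roles of the two generators and use `d_horse` in place of `d_zebra`. Because the cycle term compares an image with its own round trip, no matched horse/zebra pairs are needed.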

3. Let’s try to do it
First, we bring in two classes (a general understanding is enough; the comments are for reference only).

ResNetBlock implements a residual block from ResNet. Its shortcut connection passes information directly across layers, which helps mitigate vanishing gradients and makes deep networks easier to train.

import torch
import torch.nn as nn

class ResNetBlock(nn.Module):
    def __init__(self, dim):
        super(ResNetBlock, self).__init__()  # initialize the parent nn.Module
        self.conv_block = self.build_conv_block(dim)  # build the convolutional branch and store it

    def build_conv_block(self, dim):  # build the convolutional branch of the residual block
        conv_block = []  # list that collects the layers
        conv_block += [nn.ReflectionPad2d(1)]  # reflection-pad 1 pixel before the convolution

        # 3x3 convolution, instance normalization, and ReLU. dim is both the input
        # and the output channel count. Instance normalization normalizes the
        # convolution output and can act as a mild regularizer; ReLU adds the nonlinearity.
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                       nn.InstanceNorm2d(dim),
                       nn.ReLU(True)]

        conv_block += [nn.ReflectionPad2d(1)]  # reflection-pad 1 pixel again

        # second convolution and instance normalization, this time without an activation
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                       nn.InstanceNorm2d(dim)]

        return nn.Sequential(*conv_block)  # chain the layers into one sequential module

    def forward(self, x):  # forward pass
        out = x + self.conv_block(x)  # shortcut connection: add the input to the branch output
        return out
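The shortcut addition `x + self.conv_block(x)` only works if the convolutional branch returns a tensor with the same shape as its input. The small check below, a sketch with arbitrary sizes chosen for illustration, shows that a 1-pixel reflection pad followed by an unpadded 3x3 convolution preserves height and width.

```python
import torch
import torch.nn as nn

# Reflection padding mirrors the pixels next to the border instead of adding zeros.
grid = torch.arange(9.0).reshape(1, 1, 3, 3)
padded = nn.ReflectionPad2d(1)(grid)
print(padded.shape)  # 3x3 grows to 5x5

# One half of the residual branch: pad by 1, then a 3x3 convolution with padding=0.
x = torch.randn(1, 8, 32, 32)  # arbitrary batch: 8 channels, 32x32 pixels
branch = nn.Sequential(
    nn.ReflectionPad2d(1),
    nn.Conv2d(8, 8, kernel_size=3, padding=0),
)
y = branch(x)

# Same shape in and out, so the shortcut addition is well defined.
out = x + y
print(x.shape, y.shape, out.shape)
```

Padding by 1 adds 2 pixels per dimension and the 3x3 kernel removes 2, so the sizes cancel out.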

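ResNetGenerator is a generator built on a residual network. It converts an input image into an output image in the target domain. The residual blocks and skip connections strengthen the model's feature representation and help it preserve image detail, which improves the quality of the generated images.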

class ResNetGenerator(nn.Module):
    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9):  # define the generator architecture
        assert(n_blocks >= 0)  # the number of residual blocks must be non-negative
        super(ResNetGenerator, self).__init__()  # initialize the parent nn.Module
        # store the input channel count, output channel count, and base number of filters
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        # model is a list that collects the generator's layers
        model = [nn.ReflectionPad2d(3),  # reflection-pad 3 pixels
                 # 7x7 convolution with no extra padding and ngf output channels
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
                 # instance normalization followed by ReLU
                 nn.InstanceNorm2d(ngf),
                 nn.ReLU(True)]
        # Downsampling: each step is a 3x3 convolution with stride 2 and padding 1,
        # which halves the height and width, followed by instance normalization and ReLU.
        # mult is the channel multiplier: each step takes ngf * mult channels in
        # and produces ngf * mult * 2 channels out.
        n_downsampling = 2  # number of downsampling steps
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=True),
                      nn.InstanceNorm2d(ngf * mult * 2),
                      nn.ReLU(True)]
        mult = 2**n_downsampling
        # Add n_blocks residual blocks, each with ngf * mult input and output channels.
        for i in range(n_blocks):
            model += [ResNetBlock(ngf * mult)]
        # Upsampling: each step is a 3x3 transposed convolution with stride 2, padding 1,
        # and output_padding 1, followed by instance normalization and ReLU.
        # Each step doubles the height and width and halves the channel count,
        # mirroring the downsampling path.
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=True),
                      nn.InstanceNorm2d(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        # Output head: reflection-pad 3 pixels, apply a 7x7 convolution down to
        # output_nc channels, and finish with Tanh so the output lies in [-1, 1].
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]
        self.model = nn.Sequential(*model)

    # Forward pass: run the input through self.model and return the generated output.
    def forward(self, input):
        return self.model(input)
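To see how the channel count and the spatial size move through the generator, the layer hyperparameters above can be plugged into the standard output-size formulas for convolutions and transposed convolutions. The sketch below uses plain Python arithmetic only; the helper names are hypothetical, and a 256x256 input is assumed for illustration.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Output size of a convolution along one spatial dimension.
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel, stride=1, padding=0, output_padding=0):
    # Output size of a transposed convolution along one spatial dimension.
    return (size - 1) * stride - 2 * padding + kernel + output_padding


def trace_generator(size=256, input_nc=3, output_nc=3, ngf=64, n_downsampling=2):
    stages = [("input", input_nc, size)]
    # ReflectionPad2d(3) adds 3 pixels on each side before the 7x7 convolution.
    size = conv_out(size + 2 * 3, kernel=7)
    channels = ngf
    stages.append(("stem", channels, size))
    for i in range(n_downsampling):
        channels *= 2
        size = conv_out(size, kernel=3, stride=2, padding=1)
        stages.append((f"down{i + 1}", channels, size))
    # The residual blocks keep both the channel count and the size unchanged.
    stages.append(("resnet blocks", channels, size))
    for i in range(n_downsampling):
        channels //= 2
        size = conv_transpose_out(size, kernel=3, stride=2, padding=1, output_padding=1)
        stages.append((f"up{i + 1}", channels, size))
    size = conv_out(size + 2 * 3, kernel=7)
    stages.append(("output", output_nc, size))
    return stages


stages = trace_generator()
for name, channels, size in stages:
    print(f"{name:14s} channels={channels:4d} size={size}x{size}")
```

With the default arguments the residual blocks operate on 256 channels at 64x64, and the output has the same 3 channels and 256x256 size as the input, which is what an image-to-image generator needs.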

Create an instance of ResNetGenerator called netG:

netG = ResNetGenerator()

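At this point netG holds only its initial weights. The pretrained weights are stored in a .pth file; we can load them into ResNetGenerator using the model’s load_state_dict method: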

model_path = 'D:/Deep-Learning/资料/dlwpt-code-master/data/p1ch2/horse2zebra_0.4.0.pth' # location of the pretrained model file
model_data = torch.load(model_path) # load the parameter file
netG.load_state_dict(model_data) # copy the loaded parameters into netG so that it matches the pretrained model
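If the mechanics of `load_state_dict` are unfamiliar, the round trip below may help. It is a sketch with a tiny hypothetical model and an in-memory buffer standing in for the .pth file; `map_location="cpu"` is an optional argument of `torch.load` that is handy when the weights were saved on a GPU machine.

```python
import io

import torch
import torch.nn as nn


def make_model():
    # Hypothetical stand-in for the generator: same layer types, far fewer of them.
    return nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.InstanceNorm2d(4), nn.ReLU())


trained = make_model()

# A state dict maps parameter names to tensors; this is what a .pth file of weights holds.
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# A fresh model of the same architecture starts with different random weights.
fresh = make_model()
state = torch.load(buffer, map_location="cpu")
result = fresh.load_state_dict(state)

print(list(state.keys()))
print(result)  # reports missing or unexpected keys, if any
```

The architecture has to match the saved keys and shapes, which is why the ResNetGenerator class must be defined exactly as above before the weights can be loaded.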

At this point, netG has acquired all the knowledge it achieved during training.

Let’s put the network in eval mode, as we did for resnet101:

netG.eval()

It takes an image, recognizes one or more horses in it by looking at pixels, and individually modifies the values of those pixels so that what comes out looks like a credible zebra.


First, we need to import PIL and torchvision. Then we define a few input transformations to make sure data enters the network with the right shape and size:

from PIL import Image
from torchvision import transforms
preprocess = transforms.Compose([transforms.Resize(256), # resize so that the shorter side is 256 pixels, keeping the aspect ratio
                                 transforms.ToTensor()]) # convert the image to a tensor with values in [0, 1]

Let’s open an image file of a horse:

img = Image.open('D:/Deep-Learning/资料/dlwpt-code-master/data/p1ch2/horse.jpg')
img

Anyhow, let’s pass it through preprocessing and turn it into a properly shaped variable:

img_t = preprocess(img) # apply the preprocessing pipeline
batch_t = torch.unsqueeze(img_t, 0) # unsqueeze adds a batch dimension at position 0
batch_out = netG(batch_t) # forward pass through netG, that is, inference with the generator
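Since we only want a prediction here, the forward pass can also be wrapped in `torch.no_grad()`, which skips building the autograd graph. The sketch below uses a small hypothetical module in place of netG and a random tensor in place of the preprocessed horse image.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for netG, so the example runs without the pretrained weights.
net = nn.Sequential(nn.Conv2d(3, 3, kernel_size=3, padding=1), nn.Tanh())
net.eval()

img_t = torch.rand(3, 64, 64)   # stand-in for the preprocessed image
batch_t = img_t.unsqueeze(0)    # add the batch dimension: (3, 64, 64) -> (1, 3, 64, 64)

with torch.no_grad():           # no gradients are tracked inside this block
    batch_out = net(batch_t)

print(batch_t.shape, batch_out.shape, batch_out.requires_grad)
```

The output of a forward pass run this way does not require gradients, so it can be converted to an image directly.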


batch_out is now the output of the generator, which we can convert back to an image:

out_t = (batch_out.data.squeeze() + 1.0) / 2.0 # squeeze drops the batch dimension; then shift and scale the values from [-1, 1] to [0, 1]
out_img = transforms.ToPILImage()(out_t) # convert the tensor to a PIL image so it can be displayed
out_img
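The `(x + 1.0) / 2.0` step exists because the generator ends in Tanh, so its output lies in [-1, 1], while an image tensor is expected in [0, 1]. A small sketch with hand-picked values (not real generator output) shows the mapping:

```python
import torch

# Hand-picked values covering the Tanh range, repeated over 3 channels.
batch_out = torch.tensor([[[[-1.0, 0.0],
                            [0.5, 1.0]]]]).repeat(1, 3, 1, 1)  # shape (1, 3, 2, 2)

# squeeze(0) removes only the batch dimension; a bare squeeze() would also
# drop any other dimension that happens to have size 1.
out_t = (batch_out.squeeze(0) + 1.0) / 2.0
out_t = out_t.clamp(0.0, 1.0)  # guard against small numerical overshoot

print(out_t.shape)  # torch.Size([3, 2, 2])
print(out_t[0])     # -1 -> 0, 0 -> 0.5, 0.5 -> 0.75, 1 -> 1
```

The clamp is an addition of this sketch and is not part of the original code.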


The generator has learned to produce an image that would fool the discriminator into thinking that was a zebra, and there was nothing fishy about the image.

Summary

import torch
import torch.nn as nn

class ResNetBlock(nn.Module):
    def __init__(self, dim):
        super(ResNetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim)

    def build_conv_block(self, dim):
        conv_block = []
        conv_block += [nn.ReflectionPad2d(1)]
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                       nn.InstanceNorm2d(dim),
                       nn.ReLU(True)]
        conv_block += [nn.ReflectionPad2d(1)]

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                       nn.InstanceNorm2d(dim)]
        return nn.Sequential(*conv_block)
        
    def forward(self, x):
        out = x + self.conv_block(x)
        return out

class ResNetGenerator(nn.Module):
    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9):
        assert(n_blocks >= 0)
        super(ResNetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
                 nn.InstanceNorm2d(ngf),
                 nn.ReLU(True)]
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=True),
                      nn.InstanceNorm2d(ngf * mult * 2),
                      nn.ReLU(True)]
        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResNetBlock(ngf * mult)]
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=True),
                      nn.InstanceNorm2d(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)

netG = ResNetGenerator()
model_path = 'D:/Deep-Learning/资料/dlwpt-code-master/data/p1ch2/horse2zebra_0.4.0.pth'
model_data = torch.load(model_path)
netG.load_state_dict(model_data)
netG.eval()

from PIL import Image
from torchvision import transforms
preprocess = transforms.Compose([transforms.Resize(256),
                                 transforms.ToTensor()])
img = Image.open('D:/Deep-Learning/资料/dlwpt-code-master/data/p1ch2/horse.jpg') # choose the input image
img_t = preprocess(img)
batch_t = torch.unsqueeze(img_t, 0)
batch_out = netG(batch_t)
out_t = (batch_out.data.squeeze() + 1.0) / 2.0
out_img = transforms.ToPILImage()(out_t)
out_img # display the result

To try a different image, just change the path.


Downloads:
Extraction code: t0is
Model: horse2zebra_0.4.0.pth
Horse image: horse.jpg
