使用pytorch构建GAN模型的评估

news2024/12/23 17:11:07

本文为此系列的第六篇对GAN的评估,上一篇为Controllable GAN。文中使用训练好的分类模型的部分网络提取特征将真实分布与生成分布进行对比来评估模型的好坏,若有不懂的无监督知识点可以看本系列第一篇。

原理

1.评估模型的指标
一般来说,我们评估模型的好坏可以通过对测试集的错误率来体现:比如图像分类我们可以统计几张分错几张分对来量化错误率、目标检测我们可以通过比对每个框得到mAP从而量化错误率…但是我们怎么通过生成的图像来评估GAN的好坏呢?
在这里插入图片描述
我们总不能说,生成的某一个像素要更绿色一点比较好,或者某个像素要更黄色一点比较好吧?
先进行概括一下,全文主要围绕着生成质量(保真度fidelity)、多样性(diversity)进行讲解。
在这里插入图片描述
2. 图像对比有两种方法,pixel distance、feature distance。
第一种像素对比,直接做相减运算。这样做的缺点是尽管两张图片可能非常相似,但是每个像素的像素值会有一些细微的差异,即使我们肉眼看不出来,最终的差值也会非常大,太过于关注细节。
在这里插入图片描述
第二种则是特征对比,通俗的说是成片的像素区域进行对比是否相似,这样的对比更符合我们人眼观察标准。
在这里插入图片描述
那么,接下来的问题就是如何进行特征提取。
3. 特征提取的方法
我们训练好的分类器是一个很好的特征提取器,比如我们训练了一个识别猫狗的分类器,那它必然是学习到了猫狗的特征才会对他们进行分类。
在这里插入图片描述
直接将分类部分的最后一层分类层去掉,其余的都是对我们有价值的。我们一般选择的是连接最后一个全连接层的池化层作为输出特征的层,我们成为特征层,输出的特征我们称为embedding。
选择这个位置并不固定,只是选择的位置越后面,每个单元的感受野越大,所包含的信息就越多,更符合我们的要求。很前面的层获取到的特征可能只是一横或者一竖或者一个弧度等。

  • 我们使用Inception v3作为我们的特征提取器,Inception使用超1400万张图片、2万多类别的ImageNet数据库作为训练集。提取详细流程如图:
    在这里插入图片描述

对总的概括可以概括为一下流程:
在这里插入图片描述
最终我们就是对真实数据提取的特征于生成数据提取的特征进行对比。
4. Frechet Inception Distance(FID)
我们使用FID来量化真假特征的差异。
通俗来说Frechet Distance是用来衡量两条曲线之间的的最小距离,比如人狗同时走所需的最短牵引绳的长度。
在这里插入图片描述
严格来说,Frechet Distance是衡量两个分布之间的差异。
在这里插入图片描述
①我们可以使用以下公式来表示两个单维正态分布的Frechet Distance:
在这里插入图片描述
分别从真实数据和生成数据里面提取大量的特征,分别作为真实特征分布于生成特征分布,计算出各自的均值和标准差即可计算出真假之间的差值。
②两个多变量正态分布的Frechet Distance
我们可以为每个维度提供一个单变量的正态分布,假设是两个变量的(便于举例),如图:
在这里插入图片描述

协方差矩阵:
比如(x1,x2)代表第一变量的正态分布的随机变量与第二正态分布的随机变量之间的协方差。非对角线元素代表不同变量之间的协方差,即不同变量之间的相关性。若两个变量变化趋势一致则协方差为正值,反之负值,若没有线性关系则为0。上图就代表两个变量之间相互不影响相互独立,下图代表两变量之间负相关;
比如(x1,x1)代表第一变量的正态分布的方差。对角线元素代表每个变量分布的方差,即每个变量本身的变化程度。
在这里插入图片描述
由此可以计算我们的多变量正态分布之间的Frechet Distance,可以将单维正态分布之间的Frechet Distance公式展开进行对比发现他们之间其实是相似的:
在这里插入图片描述
Tr运算为矩阵的对角线元素之和,例如上面那个负相关的协方差矩阵的Tr运算结果为2+2=4。
将多变量正态分布之间的Frechet Distance应用于真假特征的分布就是FID了:
在这里插入图片描述
FID越小,就代表着真假分布就越接近,那么GAN就越好。

代码

import torch
import numpy as np
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import CelebA
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0) # Set for our testing purposes, please do not change!

class Generator(nn.Module):
    def __init__(self, z_dim=10, im_chan=3, hidden_dim=64):
        super(Generator, self).__init__()
        self.z_dim = z_dim
        # Build the neural network
        self.gen = nn.Sequential(
            self.make_gen_block(z_dim, hidden_dim * 8),
            self.make_gen_block(hidden_dim * 8, hidden_dim * 4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2),
            self.make_gen_block(hidden_dim * 2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True),
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.Tanh(),
            )

    def forward(self, noise):
        x = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(x)

def get_noise(n_samples, z_dim, device='cpu'):
    return torch.randn(n_samples, z_dim, device=device)

z_dim = 64
image_size = 299
device = 'cuda'

transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = CelebA(".", download=True, transform=transform)

gen = Generator(z_dim).to(device)
gen.load_state_dict(torch.load(f"pretrained_celeba.pth", map_location=torch.device(device))["gen"])
gen = gen.eval()

from torchvision.models import inception_v3
inception_model = inception_v3(pretrained=False)
inception_model.load_state_dict(torch.load("inception_v3_google-1a9a5a14.pth"))
inception_model.to(device)
inception_model = inception_model.eval() # Evaluation mode

inception_model.fc = torch.nn.Identity()

from torch.distributions import MultivariateNormal
import seaborn as sns # This is for visualization
mean = torch.Tensor([0, 0]) # Center the mean at the origin
covariance = torch.Tensor( # This matrix shows independence - there are only non-zero values on the diagonal
    [[1, 0],
     [0, 1]]
)
independent_dist = MultivariateNormal(mean, covariance)
samples = independent_dist.sample((10000,))
res = sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind="kde")
plt.show()

mean = torch.Tensor([0, 0])
covariance = torch.Tensor(
    [[2, -1],
     [-1, 2]]
)
covariant_dist = MultivariateNormal(mean, covariance)
samples = covariant_dist.sample((10000,))
res = sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind="kde")
plt.show()

import scipy
def matrix_sqrt(x):
    y = x.cpu().detach().numpy()
    y = scipy.linalg.sqrtm(y)
    return torch.Tensor(y.real, device=x.device)
    
def frechet_distance(mu_x, mu_y, sigma_x, sigma_y):
    return (mu_x - mu_y).dot(mu_x - mu_y) + torch.trace(sigma_x) + torch.trace(sigma_y) - 2*torch.trace(matrix_sqrt(sigma_x @ sigma_y))

def preprocess(img):
    img = torch.nn.functional.interpolate(img, size=(299, 299), mode='bilinear', align_corners=False)
    return img

import numpy as np
def get_covariance(features):
    return torch.Tensor(np.cov(features.detach().numpy(), rowvar=False))

fake_features_list = []
real_features_list = []

n_samples = 512 # The total number of samples
batch_size = 4 # Samples per iteration

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True)

cur_samples = 0
with torch.no_grad(): # You don't need to calculate gradients here, so you do this to save memory
    try:
        for real_example, _ in tqdm(dataloader, total=n_samples // batch_size): # Go by batch
            real_samples = real_example
            real_features = inception_model(real_samples.to(device)).detach().to('cpu') # Move features to CPU
            real_features_list.append(real_features)

            fake_samples = get_noise(len(real_example), z_dim).to(device)
            fake_samples = preprocess(gen(fake_samples))
            fake_features = inception_model(fake_samples.to(device)).detach().to('cpu')
            fake_features_list.append(fake_features)
            cur_samples += len(real_samples)
            if cur_samples >= n_samples:
                break
    except:
        print("Error in loop")

fake_features_all = torch.cat(fake_features_list)
real_features_all = torch.cat(real_features_list)

mu_fake = fake_features_all.mean(0)
mu_real = real_features_all.mean(0)
sigma_fake = get_covariance(fake_features_all)
sigma_real = get_covariance(real_features_all)

indices = [2, 4, 5]
fake_dist = MultivariateNormal(mu_fake[indices], sigma_fake[indices][:, indices])
fake_samples = fake_dist.sample((5000,))
real_dist = MultivariateNormal(mu_real[indices], sigma_real[indices][:, indices])
real_samples = real_dist.sample((5000,))

import pandas as pd
df_fake = pd.DataFrame(fake_samples.numpy(), columns=indices)
df_real = pd.DataFrame(real_samples.numpy(), columns=indices)
df_fake["is_real"] = "no"
df_real["is_real"] = "yes"
df = pd.concat([df_fake, df_real])
sns.pairplot(df, plot_kws={'alpha': 0.1}, hue='is_real')
plt.show()

with torch.no_grad():
    print(frechet_distance(mu_real, mu_fake, sigma_real, sigma_fake).item())

代码中使用的生成器模型可以从上一篇当中下载,inception_v3_google-1a9a5a14.pth模型可以从这里下载。

代码解析

  • 去掉分类层
inception_model.fc = torch.nn.Identity()

将最后一层的全连接层替换为恒等函数,它将输入的数据不做任何操作、原封不动地输出。
通常Inception模型的全连接层用于图像分类任务,它将提取的特征映射到类别预测上。然而我们不需要进行图像分类,而是想要利用Inception模型的前面部分来提取图像的特征。
这样就将Inception模型从原始的分类任务模型转变为一个特征提取器,从而不再执行图像分类任务,而是将图像转换为特征向量。

  • 可视化多变量正态分布
from torch.distributions import MultivariateNormal
import seaborn as sns # This is for visualization
mean = torch.Tensor([0, 0]) # Center the mean at the origin
covariance = torch.Tensor( # This matrix shows independence - there are only non-zero values on the diagonal
    [[1, 0],
     [0, 1]]
)
independent_dist = MultivariateNormal(mean, covariance)
samples = independent_dist.sample((10000,))
res = sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind="kde")
plt.show()

mean = torch.Tensor([0, 0])
covariance = torch.Tensor(
    [[2, -1],
     [-1, 2]]
)
covariant_dist = MultivariateNormal(mean, covariance)
samples = covariant_dist.sample((10000,))
res = sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind="kde")
plt.show()

首先定义均值和协方差矩阵(原理中举的两个例子),然后使用MultivariateNormal构建一个多变量正态分布对象covariant_dist。然后从这个分布中抽取了10000个样本,每个样本是一个shape为(samples, 2)的二维向量。最后将生成的样本可视化为二维核密度估计图(Kernel Density Estimate,KDE)。
在这里插入图片描述
在这里插入图片描述

  • 计算矩阵的平方根
def matrix_sqrt(x):
    y = x.cpu().detach().numpy()
    y = scipy.linalg.sqrtm(y)
    return torch.Tensor(y.real, device=x.device)

首先将输入矩阵转移到CPU上并将其转换为NumPy数组。这是因为scipy.linalg.sqrtm函数只能接受NumPy数组作为输入,不能接受PyTorch张量,且在CPU上计算更高效。
然后使用scipy.linalg.sqrtm函数计算平方根且返回一个复数矩阵,所以需要取其实部(real)部分,然后再转换为PyTorch张量。同时,函数还会确保新的张量与输入矩阵在相同的设备(device)上。

  • 计算FID
def frechet_distance(mu_x, mu_y, sigma_x, sigma_y):
    return (mu_x - mu_y).dot(mu_x - mu_y) + torch.trace(sigma_x) + torch.trace(sigma_y) - 2*torch.trace(matrix_sqrt(sigma_x @ sigma_y))

给定两个分布的均值和协方差矩阵,利用原理中的公式进行计算。

  • 对生成图像进行处理
def preprocess(img):
    img = torch.nn.functional.interpolate(img, size=(299, 299), mode='bilinear', align_corners=False)
    return img

将输入的图像进行插值操作,插值方法使用双线性插值,参数align_corners=False指示在进行插值操作时不对齐图像的角点,这在图像处理中常用于避免不必要的插值偏差。
在这里插入图片描述

  • 计算协方差矩阵
def get_covariance(features):
    return torch.Tensor(np.cov(features.detach().numpy(), rowvar=False))

使用NumPy的np.cov()函数计算特征向量集合的协方差矩阵,rowvar=False参数表示传递的数据中每一列代表一个特征向量的观测值,而不是每一行代表一个观测样本。

  • 提取特征
for real_example, _ in tqdm(dataloader, total=n_samples // batch_size): # Go by batch
    real_samples = real_example
    real_features = inception_model(real_samples.to(device)).detach().to('cpu') # Move features to CPU
    real_features_list.append(real_features)

    fake_samples = get_noise(len(real_example), z_dim).to(device)
    fake_samples = preprocess(gen(fake_samples))
    fake_features = inception_model(fake_samples.to(device)).detach().to('cpu')
    fake_features_list.append(fake_features)
    cur_samples += len(real_samples)
    if cur_samples >= n_samples:
        break

使用预训练的Inception模型提取真实图像和生成图像的特征,并将这些特征存储在列表中,以备后续计算Fréchet Distance。
在这里需要对生成的图像进行preprocess()处理为299的宽高是因为真实数据的宽高为299,而生成数据的宽高为64。
我们可以将生成数据和preprocess处理后的数据显示出来看效果:

import matplotlib.pyplot as plt

# 选择其中一个样本进行显示
sample_index = 0

# 显示生成图像
fake_image = fake[sample_index].permute(1, 2, 0)  # 将张量形状转换为图像的形状(C, H, W)->(H, W, C)
plt.imshow(fake_image)
plt.axis('off')
plt.show()

# 显示经过处理的图像
fake_image = fake_samples[sample_index].permute(1, 2, 0)  # 将张量形状转换为图像的形状(C, H, W)->(H, W, C)
plt.imshow(fake_image)
plt.axis('off')
plt.show()

在这里插入图片描述
在这里插入图片描述
可以看到插值操作后平滑很多。

  • 可视化真实数据分布与生成数据分布,并计算FID
indices = [2, 4, 5]
import pandas as pd
df_fake = pd.DataFrame(fake_samples.numpy(), columns=indices)
df_real = pd.DataFrame(real_samples.numpy(), columns=indices)
df_fake["is_real"] = "no"
df_real["is_real"] = "yes"
df = pd.concat([df_fake, df_real])
sns.pairplot(df, plot_kws={'alpha': 0.1}, hue='is_real')
plt.show()

with torch.no_grad():
    print(frechet_distance(mu_real, mu_fake, sigma_real, sigma_fake).item())

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1620711.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

多级嵌套对象数组:根据最里层id找出它所属的每层父级,适用于树形数据格式

文章目录 需求分析 需求 已知一个树形格式数据如下: // 示例数据 const data [{"id": "1","parentId": null,"children": [{"id": "1.1","parentId": "1","children"…

粒子群算法与优化储能策略python实践

粒子群优化算法(Particle Swarm Optimization,简称PSO), 是1995年J. Kennedy博士和R. C. Eberhart博士一起提出的,它是源于对鸟群捕食行为的研究。粒子群优化算法的基本核心是利用群体中的个体对信息的共享从而使得整个群体的运动…

Wireshark数据包分析入门

Wireshark数据包分析 1. 网络协议基础1.1. 应传网数物(应表会传网数物) 2. 三次握手2.1. 第一次握手2.2. 第二次握手2.3. 第三次握手2.4. 三次握手后流量特征 3. 第一层---物理层(以太网)4. 第二层---数据链路层(PPP L…

139GB,台北倾斜摄影OSGB数据V0.1版

本月初发布了谷歌倾斜摄影数据OSGB转换工具V0.2版(更新!谷歌倾斜摄影转换生成OSGB瓦片V0.2版),并免费分享了基于V0.2版转换工具生产的澳门地区OSGB数据(首发!澳门地区OSGB数据V0.2版免费分享),V0.2版本在生产速度、显示效率和OSGB数据轻量化方面进行了优…

软考:高级系统架构师案例必备概念

根据2013年-2023年真题整理 必背案例概念 软件架构风格 软件架构风格是指描述特定软件系统组织方式的惯用模式。 组织方式描述了系统的组成构件和这些构件的组织方式。 惯用模式则反映众多系统共有的结构和语义。 架构风险 架构风险是指架构设计中潜在的、存在问题的架构…

零碳家庭 “光”的力量

有行业专家乐观预测,在供给充足、基础设施建设与时俱进的情况下,2025年,我国新能源汽车市场的占有率将会达到50%,2030年更有望突破90%的大关。为了方便新能源汽车的出行,在家中安装一个智能充电桩是越来越多驾驶者的选…

计算机网络---第十一天

生成树协议 stp作用: 作用:stp用于解决二层环路问题。 BPDU: 含义:桥协议数据单元,用于传递stp协议相关报文 分类:配置bpdu---用于传递stp的配置信息 tcn bpdu---用于通告拓扑变更信息 包含信息&…

基于SpringBoot的智慧物业管理设计与实现论文

摘  要 随着我国发展和城市开发,物业管理已形成规模,其效益也越来越明显。在经济效益对地方政府而言,主要体现为:减少了大量的财政补贴,对住宅区开发企业而言,能提高物业市场竞争力,使开发企…

系统思考—啤酒游戏

最近有不少的合作伙伴来询问我啤酒游戏这个来自于MIT(麻省理工学院)经典的沙盘,上周刚刚结束Midea旗下的一家公司市场运营部《啤酒游戏沙盘-应对动态性复杂的系统思考智慧》的课程。 参与这次沙盘体验的团队成员深刻体会到了全局思考的重要性…

dtc、fdtdump、fdtget、fdtput、convert-dtsv0

目录标题 1. dtc(Device Tree Compiler)2. fdtdump3. fdtget4. fdtput5. convert-dtsv0 dtc、fdtdump、fdtget、fdtput、convert-dtsv0这些工具都与Linux设备树(Device Tree)的处理有关。 设备树是一种数据结构,用于描…

JavaSE学习文档(上)

JavaSE学习文档 第一章 Java概述1.2 计算机编程语言1.3 Java语言版本概述1.4 Java语言分类1.5 JDK,JRE,JVM的关系1.6 JDK安装1.7 DOS命令1.8 Java程序执行过程1.9 编写HelloWorld1.10 常见错误1.11 编写程序时要注意的点 第二章 Java基础语法2.1 Java中的注释文档注释 2.2 关键…

免杀技术之白加黑的攻击防御

一、介绍 1. 什么是白加黑 通俗的讲白加黑中的白就是指被杀软列入到可信任列表中的文件。比如说微软自带的系统文件或者一些有有效证书签名的文件,什么是微软文件,或者什么是有效签名文件在后面我们会提到他的辨别方法。黑就是指我们自己的文件,没有有…

【办公类-26-01】20240422 UIBOT网络教研(自动登录并退出多个账号,半自动半人工)

作品展示: 背景需求: 每学期有多次网络教研 因为我有历任搭档的进修编号和登录密码, 所以每次学习时,我会把历任搭档的任务也批量完成。 但是每次登录都要从EXCEL里复制一位老师的“进修编号”“密码”,还要点击多次…

53.基于微信小程序与SpringBoot的戏曲文化系统设计与实现(项目 + 论文)

项目介绍 本站采用SpringBoot Vue框架,MYSQL数据库设计开发,充分保证系统的稳定性。系统具有界面清晰、操作简单,功能齐全的特点,使得基于SpringBoot Vue技术的戏曲文化系统设计与实现管理工作系统化、规范化。 技术选型 后端:…

最快2周录用!多领域EI,征稿范围广!各指标优秀!

计算机工程类EI(最快2周录用) 【期刊简介】最新EI期刊目录内源刊 【检索情况】EI&Scopus双检 【版面情况】仅10篇版面 【年发文量】60篇左右 【国人占比】约13% 【收录年份】2009年被EI数据库收录 【审稿周期】预计1个月左右录用 【征稿领域…

aqs 条件队列和同步队列、独占模式和共享模式

同步/条件队列 先上代码 import java.util.LinkedList; import java.util.Queue; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.ReentrantLock;public class ProducerConsumerExample {private static final int CAPACITY 5;private fi…

二维码存储图片如何实现?相册二维码的制作技巧

如何将照片生成二维码后存储展示?现在很多人会将图片生成二维码以后,用于分享或者储存的用途,减少个人内存的占用量,而且分享照片也会更加的方便,只需要扫描二维码就可以让其他人查看图片。 想要制作图片二维码的步骤…

【C++题解】1033. 判断奇偶数

问题:1033. 判断奇偶数 类型:分支 题目描述: 输入一个整数,判断是否为偶数。是输出 y e s ,否则输出n o。 输入: 输入只有一行,包括 1 个整数(该整数在 1∼10000 的范围内&#…

【算法刷题 | 贪心算法02】4.24(摆动序列)

文章目录 3.摆动序列3.1题目3.2解法:贪心3.2.1贪心思路3.2.2代码实现 3.摆动序列 3.1题目 如果连续数字之间的差严格地在正数和负数之间交替,则数字序列称为 摆动序列 。 第一个差(如果存在的话)可能是正数或负数。仅有一个元素…

python爬虫 - 爬取html中的script数据(zum.com新闻信息 )

文章目录 1. 分析页面内容数据格式2. 使用re.findall方法,编写爬虫代码3. 使用re.search 方法,编写爬虫代码 1. 分析页面内容数据格式 (1)打开 https://zum.com/ (2)按F12(或 在网页上右键 --…