【信号处理】基于变分自编码器(VAE)的图片典型增强方法实现

news2024/10/11 18:20:54

关于

深度学习中,经常面临图片数据量较小的问题,此时,对数据进行增强,显得比较重要。传统的图片增强方法包括剪切,增加噪声,改变对比度等等方法,但是,对于后端任务的性能提升有限。所以,变分自编码器被用来实现深度数据增强。

变分自编码器的主要缺点在于生成图像过于平滑和模糊,图像细节重建不足。

常见的图像增强方法:https://www.tensorflow.org/tutorials/images/data_augmentation

工具

数据集下载地址: CIFAR-10 and CIFAR-100 datasets

方法实现

加载数据和必要的库函数
import tensorflow.compat.v1.keras.backend as K
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
import matplotlib.pyplot as plt
import numpy as np
from numpy import random
import tensorflow_datasets as tfds
import keras
from keras.models import Model
from keras.layers import Conv2D, Conv2DTranspose, Input, Flatten, Dense, Lambda, Reshape


xtrain , ytrain = tfds.as_numpy(tfds.load('cifar10',split='train',batch_size=-1,as_supervised=True,))
xtest , ytest = tfds.as_numpy(tfds.load('cifar10',split='test',batch_size=-1,as_supervised=True,))
xtrain = (xtrain.astype('float32'))/255
xtest = (xtest.astype('float32'))/255

height=32
width=32
channels=3
print(f"Train Shape: {xtrain.shape},Test Shape: {xtest.shape}")
plt.imshow(xtrain[0])

编码器模型搭建
input_shape=(height,width,channels)
latent_dims=3072

input_img= Input(shape=input_shape, name='encoder_input')
x=Conv2D(128, 4, padding='same', activation='relu',strides=2)(input_img)
x=Conv2D(256, 4, padding='same', activation='relu',strides=2)(x)
x=Conv2D(512, 4, padding='same', activation='relu',strides=2)(x)
x=Conv2D(1024, 4, padding='same', activation='relu',strides=2)(x)
conv_shape = K.int_shape(x)
x=Flatten()(x)
x=Dense(3072, activation='relu')(x)
z_mean=Dense(latent_dims, name='latent_mean')(x)
z_sigma=Dense(latent_dims, name='latent_sigma')(x)

def sampler(args):
  z_mean, z_sigma = args
  eps = K.random_normal(shape=(K.shape(z_mean)[0], K.int_shape(z_mean)[1]))
  return z_mean + K.exp(z_sigma / 2) * eps


z = Lambda(sampler, output_shape=(latent_dims, ), name='z')([z_mean, z_sigma])

encoder = Model(input_img, [z_mean, z_sigma, z], name='encoder')
print(encoder.summary())

 解码器模型构建
decoder_input = Input(shape=(latent_dims, ), name='decoder_input')
x = Dense(conv_shape[1]*conv_shape[2]*conv_shape[3], activation='relu')(decoder_input)
x = Reshape((conv_shape[1], conv_shape[2], conv_shape[3]))(x)
x = Conv2DTranspose(256, 3, padding='same', activation='relu',strides=(2, 2))(x)
x = Conv2DTranspose(128, 3, padding='same', activation='relu',strides=(2, 2))(x)
x = Conv2DTranspose(64, 3, padding='same', activation='relu',strides=(2, 2))(x)
x = Conv2DTranspose(3, 3, padding='same', activation='relu',strides=(2, 2))(x)
x = Conv2DTranspose(channels, 3, padding='same', activation='sigmoid', name='decoder_output')(x)
decoder = Model(decoder_input, x, name='decoder')
decoder.summary()
z_decoded = decoder(z)

class CustomLayer(keras.layers.Layer):

    def vae_loss(self, x, z_decoded):
        x = K.flatten(x)
        z_decoded = K.flatten(z_decoded)
        
        # Reconstruction loss (as we used sigmoid activation we can use binarycrossentropy)
        recon_loss = keras.metrics.binary_crossentropy(x, z_decoded)
        
        # KL divergence
        kl_loss = -5e-4 * K.mean(1 + z_sigma - K.square(z_mean) - K.exp(z_sigma), axis=-1)
        return K.mean(recon_loss + kl_loss)

    # add custom loss to the class
    def call(self, inputs):
        x = inputs[0]
        z_decoded = inputs[1]
        loss = self.vae_loss(x, z_decoded)
        self.add_loss(loss, inputs=inputs)
        return x

 

整体模型构建
y = CustomLayer()([input_img, z_decoded])

vae = Model(input_img, y, name='vae')
vae.compile(optimizer='adam', loss=None)
vae.summary()

 

模型训练

history=vae.fit(xtrain, verbose=2, epochs = 100, batch_size = 64, validation_split = 0.2)
 训练可视化
f = plt.figure(figsize=(10,7))
f.add_subplot()
#Adding Subplot
plt.plot(history.epoch, history.history['loss'], label = "loss") # Loss curve for training set
plt.plot(history.epoch, history.history['val_loss'], label = "val_loss") # Loss curve for validation set

plt.title("Loss Curve",fontsize=18)
plt.xlabel("Epochs",fontsize=15)
plt.ylabel("Loss",fontsize=15)
plt.grid(alpha=0.3)
plt.legend()
plt.savefig("VAE_Loss_Trial5.png")
plt.show()

 中间编码特征可视化
mu, _, _ = encoder.predict(xtest)
#Plot dim1 and dim2 for mu
plt.figure(figsize=(10, 10))
plt.scatter(mu[:, 0], mu[:, 1], c=ytest, cmap='brg')
plt.xlabel('dim 1')
plt.ylabel('dim 2')
plt.colorbar()
plt.show()
plt.savefig("VAE_Colourbar_Trial5.png")

 

数据增强生成
#RANDOM GENERATION
def generate():
    n=20
    figure = np.zeros((width *2 , height * 10, channels))

#Create a Grid of latent variables, to be provided as inputs to decoder.predict
#Creating vectors within range -5 to 5 as that seems to be the range in latent space

    for k in range(2):
        for l in range(10):
            z_sample =random.rand(3072)
            z_out=np.array([z_sample])
            x_decoded = decoder.predict(z_out)
            digit = x_decoded[0].reshape(width, height, channels)
            figure[k * width: (k + 1) * width,
                    l * height: (l + 1) * height] = digit

    plt.figure(figsize=(10, 10))
#Reshape for visualization
    fig_shape = np.shape(figure)
    figure = figure.reshape((fig_shape[0], fig_shape[1],3))

    plt.imshow(figure, cmap='gnuplot2')
    plt.show()  
 
    plt.savefig("VAE_imagesgen_Trial5.png")

解码器图像重建
#IMAGE RECONSTRUCT USING TEST SET IMGS
def reconstruct():
    num_imgs = 6
    rand = np.random.randint(1, xtest.shape[0]-6) 

    xtestsample = xtest[rand:rand+num_imgs]
    x_encoded = np.array(encoder.predict(xtestsample))
    latent_xtest=x_encoded[2]
    x_decoded = decoder.predict(latent_xtest)

    rows = 2 # defining no. of rows in figure
    cols = 3 # defining no. of colums in figure
    cell_size = 1.5
    f = plt.figure(figsize=(cell_size*cols,cell_size*rows*2)) # defining a figure 
    f.tight_layout()
    for i in range(rows):
        for j in range(cols): 
            f.add_subplot(rows*2,cols, (2*i*cols)+(j+1)) # adding sub plot to figure on each iteration
            plt.imshow(xtestsample[i*cols + j]) 
            plt.axis("off")
        
        for j in range(cols): 
            f.add_subplot(rows*2,cols,((2*i+1)*cols)+(j+1)) # adding sub plot to figure on each iteration
            plt.imshow(x_decoded[i*cols + j]) 
            plt.axis("off")

    f.suptitle("Autoencoder Results - Cifar10",fontsize=18)
    plt.savefig("VAE_imagesrecons_Trial5.png")

    plt.show()

 

代码获取

已经附在文章底部,自行拿取。

项目开发,相关问题咨询,欢迎交流沟通。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1567641.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Redis 主从复制,哨兵模式,集群

目录 主从复制 主从复制 作用 缺陷 主从复制流程 实现Redis主从复制 哨兵模式 主从复制切换的缺点 哨兵的核心功能 哨兵模式原理 哨兵模式的作用 哨兵结构组成 故障转移机制 主节点的选举 实现哨兵模式 集群(Cluster) redis群集有三种模式,主从复制…

解决PDF打开后显示名称与文档名称不一致的问题【不需要word模板!!!】

文章目录 简介原因解决办法参考资料 简介 最近,博主在使用Acobat打开一个PDF文件的时候发现:打开后的PDF文件标签跟该文件的存储名称不一致。这是一件令人并不十分愉快和顺心的事情,网上搜索得到的解决办法基本上都是出奇的相似,…

css心跳动画

图标引入 <img class"icon" src"heart.svg" alt"" srcset""> CSS代码 <style>.icon {animation:bpm 1s linear,pulse 0.75s 1s linear infinite;}keyframes pulse {from,75%,to {transform: scale(1);}25% {transform:…

趣学前端 | 类,我想好好继承它的知识点

背景 最近睡前习惯翻会书&#xff0c;重温了《JavaScript权威指南》。这本书&#xff0c;文字小&#xff0c;内容多。两年了&#xff0c;我才翻到第十章。因为书太厚&#xff0c;平时都充当电脑支架。 JavaScript 类 话说当年类、原型、继承&#xff0c;差点给我绕晕。 在J…

生成式AI的情感实验——AI能否产生思想和情感?

机器人能感受到爱吗&#xff1f;这是一个很好的问题&#xff0c;也是困扰了科学家们很多年的科学未解之谜。虽然我们尚未准备好向智能机器赋予情感&#xff0c;但智能机器却已经可以借助生成式人工智能&#xff08;AI&#xff09;来帮助我们表达自己的情感。 自然情感表达 AI正…

【子集回溯】Leetcode 78. 子集 90. 子集 II

【子集回溯】Leetcode 78. 子集 90. 子集 II 78. 子集90. 子集 II ---------------&#x1f388;&#x1f388;78. 子集 题目链接&#x1f388;&#x1f388;------------------- 78. 子集 class Solution {List<List<Integer>> result new ArrayList<>()…

Java 7、Java 8常用新特性

目录 Java 8 常用新特性1、Lambda 表达式2、方法引用2.1 静态方法引用2.2 特定对象的实例方法引用2.3 特定类型的任意对象的实例方法引用2.4 构造器引用 3、接口中的默认方法4、函数式接口4.1 自定义函数式接口4.2 内置函数式接口 5、Date/Time API6、Optional 容器类型7、Stre…

【随笔】Git 基础篇 -- 分支与合并 git rebase(十)

&#x1f48c; 所属专栏&#xff1a;【Git】 &#x1f600; 作  者&#xff1a;我是夜阑的狗&#x1f436; &#x1f680; 个人简介&#xff1a;一个正在努力学技术的CV工程师&#xff0c;专注基础和实战分享 &#xff0c;欢迎咨询&#xff01; &#x1f496; 欢迎大…

【Pt】马灯贴图绘制过程 05-铁丝与渲染出图

目录 效果 步骤 一、基本材质 二、浮尘 三、渲染 效果 步骤 一、基本材质 CtrlAlt鼠标右键选中指定的纹理集 在智能材质中将“Iron Forged Old”加入图层 将智能材质“Iron Forged Old”文件夹打开&#xff0c;将图层“Base”和“Edge”的基本颜色改暗一点 二、浮尘 新…

PHP+python高校教务处工作管理系统q535p

开发语言&#xff1a;php 后端框架&#xff1a;Thinkphp/Laravel 前端框架&#xff1a;vue.js 服务器&#xff1a;apache 数据库&#xff1a;mysql 运行环境:phpstudy/wamp/xammp等 系统根据现有的管理模块进行开发和扩展&#xff0c;采用面向对象的开发的思想和结构化的开发方…

21.兼容性测试

考试频率低&#xff1b; 一般考兼容性测试会结合web测试&#xff1b;&#xff08;兼容性矩阵&#xff09; 主要议题&#xff1a; 1.兼容性测试概述 2.硬件兼容性测试 最低配置不讲究工作负载&#xff0c;意思是软件能够运行的最低要求环境&#xff1b; 推荐配置&#xff0c…

【精品方案】智慧金融大数据分析平台总体架构方案

以下是部分PPT内容&#xff0c;请您参阅。如需下载完整PPTX文件&#xff0c;请前往星球获取&#xff1a; 1.实现数据共享 通过数据平台实现数据集中&#xff0c;确保金融集团各级部门均可在保证数据隐私和安全的前提下使用数据&#xff0c;充分发挥数据作为企业重要资产的业务价…

Nacos 服务发现 快速入门

Nacos 服务发现 快速入门 一、Nacos 服务发现 – 什么是服务发现 &#xff1f; 1、 Nacos 服务发现-什么是服务发现 在微服务架构中&#xff0c;整个系统会按职责能力划分为多个服务&#xff0c;通过服务之间协作来实现业务目标。 这样在我们的代码中免不了要进行服务间的远程…

[HackMyVM]靶场Zurrak

难度:medium kali:192.168.56.104 靶机:192.168.56.140 端口扫描 # nmap -sV -A 192.168.56.140 Starting Nmap 7.94SVN ( https://nmap.org ) at 2024-03-30 16:59 CST Nmap scan report for 192.168.56.140 Host is up (0.00039s latency). Not shown: 996 closed tcp po…

三相异步发电机在两相坐标系下的数学模型和状态方程

目录 1、异步发电机在两相静止坐标系下的数学模型 &#xff08;1&#xff09;磁链方程&#xff1a; &#xff08;2&#xff09;电压方程 &#xff08;3&#xff09;转矩方程 &#xff08;4&#xff09;异步电动机在两相静止坐标系&#xff08; &#xff09;上的数学模型 2、…

AcWing---转圈游戏---快速幂

太久没写快速幂了... 这是一道数学题orz&#xff0c;能看出来的话答案就是 &#xff0c;但是很大&#xff0c;同时还要mod n&#xff0c;直接用快速幂即可。 快速幂模版&#xff1a; long long int power(long long int a,long long int b,long long int mod){long long int r…

YUNBEE云贝-技术分享:PostgreSQL分区表

引言 PostgreSQL作为一款高度可扩展的企业级关系型数据库管理系统&#xff0c;其内置的分区表功能在处理大规模数据场景中扮演着重要角色。本文将深入探讨PostgreSQL分区表的实现逻辑、详细实验过程&#xff0c;并辅以分区表相关的视图查询、分区表维护及优化案例&#xff0c;…

基于深度学习的番茄成熟度检测系统(网页版+YOLOv8/v7/v6/v5代码+训练数据集)

摘要&#xff1a;在本博客中&#xff0c;我们深入探讨了基于YOLOv8/v7/v6/v5的番茄成熟度检测系统。核心技术基于YOLOv8&#xff0c;同时融合了YOLOv7、YOLOv6、YOLOv5的算法&#xff0c;对比了它们在性能指标上的差异。本文详细介绍了国内外在此领域的研究现状、数据集的处理方…

OpenHarmony实战:轻量级系统之移植验证

OpenHarmony芯片移植完成后&#xff0c;需要开展OpenHarmony兼容性测试以及芯片SDK功能性测试。除可获得测试认证之外&#xff0c;还可以在开发阶段提前发现缺陷&#xff0c;大幅提高代码质量。 OpenHarmony兼容性测试 OpenHarmony兼容性测试是XTS&#xff08;OpenHarmony生态…

基于深度学习的植物叶片病毒识别系统(网页版+YOLOv8/v7/v6/v5代码+训练数据集)

摘要&#xff1a;本文深入研究了基于YOLOv8/v7/v6/v5的植物叶片病毒识别系统&#xff0c;核心采用YOLOv8并整合了YOLOv7、YOLOv6、YOLOv5算法&#xff0c;进行性能指标对比&#xff1b;详述了国内外研究现状、数据集处理、算法原理、模型构建与训练代码&#xff0c;及基于Strea…