生成对抗:少样本学习

news2025/1/12 12:27:55

GAN:少样本学习

  任何深度学习模型要获得较好结果往往需要大量的训练数据。但是,高质量的数据往往是稀缺的和昂贵的。好消息是,自从GANs问世以来,这个问题得到妥善解决,我们可以通过GAN来生成高质量的合成数据样本帮助模型训练。通过设计一个特殊的DCGAN架构,在只有一个非常小的数据集上训练分类器,仍然可以实现良好的分类效果。

模型架构:

latent vector : x
generator
fake image : y
real image : x
real label : y
discriminator
discriminant predict
classificator
classification predict

数据集

FashionMNIST 是一个替代 MNIST 手写数字集的图像数据集。 它是由 Zalando(一家德国的时尚科技公司)旗下的研究部门提供。其涵盖了来自 10 种类别的共 7 万个不同商品的正面图片。

import os
import gzip
import numpy as np
def load_mnist(path, kind='train'):
    """Load MNIST data from `path`"""
    labels_path = os.path.join(path, '%s-labels-idx1-ubyte.gz'% kind)
    images_path = os.path.join(path, '%s-images-idx3-ubyte.gz'% kind)
 
    with gzip.open(labels_path, 'rb') as lbpath:
        labels = np.frombuffer(lbpath.read(), dtype=np.uint8,
                               offset=8)
 
    with gzip.open(images_path, 'rb') as imgpath:
        images = np.frombuffer(imgpath.read(), dtype=np.uint8,
                               offset=16).reshape(len(labels), 784)
 
    return images, labels

数据采样

  • 从每个类别中随机采样相同数量的样本,构造小样本数据集。
def sampling_subset(feats, labels, n_samples=1280, n_classes=10):
    samples_per_class = int(n_samples / n_classes)
    X = []
    y = []
    for i in range(n_classes):
        class_feats = feats[labels == i]
        class_sample_idx = np.random.randint(low=0, high=len(class_feats), size=samples_per_class)
        X.extend([class_feats[j] for j in class_sample_idx])
        y.extend([i] * samples_per_class)
    return np.array(X), np.array(y)
  • 分类器训练数据批采样
def batch_sampling_for_classification(feats, labels, n_samples):
    sample_idx = np.random.randint(low=0, high=feats.shape[0], size=n_samples)
    X = np.array([feats[i] for i in sample_idx])
    y = np.array([labels[i] for i in sample_idx])
    return X, y
  • 判别器训练数据批采样
def batch_sampling_for_discrimination(feats, n_samples):
    sample_idx = np.random.randint(low=0,
    high=feats.shape[0],
    size=n_samples)
    X = np.array([feats[i] for i in sample_idx])
    y = np.ones((n_samples, 1))
    return X, y

生成器

import os
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.layers import Conv2D,LeakyReLU,Input,BatchNormalization,Flatten,Conv2DTranspose
from tensorflow.keras.layers import Activation, Dense,Lambda,Dropout,Softmax,Reshape
from tqdm import tqdm
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
latent vector : x
generator
fake image : y
def create_generator(latent_size):
    inputs = Input(shape=(latent_size,))
    x = Dense(units=128 * 7 * 7)(inputs)
    x = LeakyReLU(alpha=0.2)(x)
    x = Reshape((7, 7, 128))(x)
    for _ in range(2):
        x = Conv2DTranspose(filters=128, kernel_size=(4, 4),strides=2, padding='same')(x)
        x = LeakyReLU(alpha=0.2)(x)
    x = Conv2D(filters=1, kernel_size=(7, 7), padding='same')(x)
    output = Activation('tanh')(x)
    return Model(inputs, output)

分类器和判别器

  分类器和鉴别器共享相同的特征提取层,不同只在最终输出层,这意味着,当每次分类器训练一批标记数据时,以及当鉴别器训练真假图像时,这些共享的权值都会得到更新。

real image
classificator
discriminator
class predict : y1
discriminant predict : y2
network body
fake image
def create_classificator_discriminators(input_shape, num_classes=10):
    def custom_activation(x):
        log_exp_sum = K.sum(K.exp(x), axis=-1, keepdims=True)
        return log_exp_sum / (log_exp_sum + 1.0)
    inputs = Input(shape=input_shape)
    x = inputs
    for _ in range(3):
        x = Conv2D(filters=128, kernel_size=(3, 3), strides=2, padding='same')(x)
        x = LeakyReLU(alpha=0.2)(x)
    x = Flatten()(x)
    x = Dropout(rate=0.4)(x)
    x = Dense(units=num_classes)(x)
    clf_output = Softmax()(x)
    clf_model = Model(inputs, clf_output)
    dis_output = Lambda(custom_activation)(x)
    dis_model = Model(inputs, dis_output)
    return clf_model, dis_model

训练模型

clf_model, dis_model = create_classificator_discriminators(input_shape=(28, 28, 1), num_classes=10)

clf_opt = Adam(learning_rate=2e-4, beta_1=0.5)
clf_model.compile(loss='sparse_categorical_crossentropy',optimizer=clf_opt,metrics=['accuracy'])
dis_opt = Adam(learning_rate=2e-4, beta_1=0.5)
dis_model.compile(loss='binary_crossentropy',optimizer=dis_opt)
gen_model = create_generator(latent_size=100)

dis_model.trainable = False
gan_model = Model(gen_model.input, dis_model(gen_model.output))

gan_opt = Adam(learning_rate=2e-4, beta_1=0.5)
gan_model.compile(loss='binary_crossentropy', optimizer=gan_opt)
def generate_fake_samples(model, latent_size, n_samples):
    z_input = tf.random.normal((n_samples, latent_size))
    images = model.predict(z_input)
    y = np.zeros((n_samples, 1))
    return images, y

加载数据

test_x, test_y = load_mnist('./SeData/fashion-mnist/',kind='t10k')

train_x, train_y = load_mnist('./SeData/fashion-mnist/',kind='train')
train_x = train_x.reshape((-1, 28, 28, 1))
train_x = (train_x.astype(np.float32) - 127.5)/ 127.5
test_x = test_x.reshape((-1, 28, 28, 1))
test_x = (test_x.astype(np.float32) - 127.5) / 127.5
# 小数据集
sub_x, sub_y = sampling_subset(train_x, train_y)
epochs=20
num_batches=64
batches_per_epoch = 100
num_steps = batches_per_epoch * epochs
num_samples = int(num_batches / 2)
for _ in tqdm(range(num_steps), ncols=60):
    real_x, real_y = batch_sampling_for_classification(sub_x, sub_y, num_samples)
    clf_model.train_on_batch(real_x, real_y)
    real_x, real_y = batch_sampling_for_discrimination(sub_x, num_samples)
    dis_model.train_on_batch(real_x, real_y)
    fake_x, fake_y = generate_fake_samples(gen_model, latent_size=100, n_samples=num_samples)
    dis_model.train_on_batch(fake_x, fake_y)
    gen_x = tf.random.normal((num_batches, 100))
    gen_y = np.ones((num_batches, 1))
    gan_model.train_on_batch(gen_x, gen_y)
100%|███████████████████| 2000/2000 [03:10<00:00, 10.49it/s]

验证模型

  • 生成器
  • 分类器
import matplotlib.pyplot as plt
gen_x = tf.random.normal((25, 100))
gen_y = gen_model(gen_x, training=False)
plt.figure(figsize=(6, 6))
for i in range(gen_y.shape[0]):
    plt.subplot(5, 5, i + 1)
    image = gen_y[i, :, :, :] *127.5 + 127.5
    image = tf.cast(image, tf.uint8)
    plt.imshow(image, cmap='Greys_r')
    plt.axis('off')

在这里插入图片描述

train_acc = clf_model.evaluate(train_x, train_y)[1]*100
print(f'Train accuracy: {train_acc:.2f}%')
test_acc = clf_model.evaluate(test_x, test_y)[1]*100
print(f'Test accuracy: {test_acc:.2f}%')
1875/1875 [==============================] - 5s 3ms/step - loss: 0.4794 - accuracy: 0.8287
Train accuracy: 82.87%
313/313 [==============================] - 1s 4ms/step - loss: 0.5083 - accuracy: 0.8191
Test accuracy: 81.91%

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/118015.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue 使用 PDF.js 浏览pdf文件

学习关键语句: 使用 PDF.js 在网页浏览pdf vue 使用 PDF.js vue 浏览pdf文件 写在前面 很头大 , 本来网络实际地址的 pdf 文件直接放在 iframe 的 src 中就可以浏览 pdf 文件的 , 但是对于虚拟地址来说 , 这样子只会让网页当场开始下载 pdf 文件到本地 , 而并不能在网页上浏览…

C规范编辑笔记(九)

往期文章&#xff1a; C规范编辑笔记(一) C规范编辑笔记(二) C规范编辑笔记(三) C规范编辑笔记(四) C规范编辑笔记(五) C规范编辑笔记(六) C规范编辑笔记(七) C规范编辑笔记(八) 正文&#xff1a; 今天我们来分享一下C规范编辑笔记第九篇&#xff0c;话不多说&#xff0c;我…

【聆思CSK6 视觉AI开发套件试用】初体验

本篇文章来自极术社区与聆思科技组织的CSK6 视觉AI开发套件活动&#xff0c;更多开发板试用活动请关注极术社区网站。作者&#xff1a;米樂 非常幸运能有评测这次的CSK6的机会。记录使用该套件进行开发的过程和感受。 套件介绍 CSK6是聆思科技推出的一款MCUDSPNPU的SoC芯片 套件…

免费pdf合并在线,这几个神仙网站请收好

对于经常要处理PDF文档的人来说&#xff0c;pdf合并如今已经是很常见的需求了。但是这个操作对一般人来说还有点难度&#xff0c;因此很多人都在寻找好用的免费pdf合并在线网站。今天小编就为大家吐血整理了工作几年来遇到的几个免费pdf合并在线的神仙网站。 1. Pdfio 这是一…

网络故障分析助您高效网上办公(一)

前言 信息中心负责人表示&#xff0c;有用户反馈&#xff0c;在通过VPN访问某一IP的80端口时连接时断时续。同时信息中心给到的信息是通过VPN&#xff1a;XXX.XXX.253.5访问IP地址XXX.XXX.130.200的80端口出现访问时断时续问题。 需要通过分析系统看一下实际情况&#xff0c;…

【Linux修炼】11.进程的创建、终止、等待、程序替换

每一个不曾起舞的日子&#xff0c;都是对生命的辜负。 进程的创建、终止、等待、程序替换本节重点1. 进程的创建1.1 fork函数初识1.2 fork的返回值问题1.3 写时拷贝1.4 创建多个进程2. 进程终止2.1 进程退出码2.2 进程如何退出3. 进程等待3.1 进程等待的原因3.2 进程等待的方法…

Uboot中的DM驱动模型

这一篇我们学习uboot中的驱动模型的初始化&#xff0c;在uboot中&#xff0c;驱动模型被称为Driver Model&#xff0c;简称DM。这种驱动模型为uboot中的各类驱动提供了统一的接口。 1. 数据结构及概念 DM模型主要依赖于下面四种数据结构&#xff1a; udevice&#xff0c;具有…

MySQL数据库闭包 Closure Table 表实现

1、 数据库闭包表简介 像MySQL这样的关系型数据库&#xff0c;比较适合存储一些类似表格的扁平化数据&#xff0c;但是遇到像树形结构这样有深度的数据&#xff0c;就很难驾驭了。 针对这种场景&#xff0c;闭包表&#xff08;Closure Table &#xff09;是最通用的设计&…

面试官:系统需求多变时如何设计?

面试官&#xff1a;我想问个问题哈&#xff0c;项目里比较常见的问题 面试官&#xff1a;我现在有个系统会根据请求的入参&#xff0c;做出不同动作。但是&#xff0c;这块不同的动作很有可能是会发生需求变动的&#xff0c;这块系统你会怎么样设计&#xff1f; 面试官&#…

FFmpeg简单使用:视频编码 ---- YUV转H264

基本流程 从本地读取YUV数据编码为h264格式的数据&#xff0c;然后再存⼊到本地&#xff0c;编码后的数据有带startcode。 与FFmpeg 示例⾳频编码的流程基本⼀致。 函数说明&#xff1a;avcodec_find_encoder_by_name&#xff1a;根据指定的编码器名称查找注册的编码器。 av…

第二十九章 数论——中国剩余定理与线性同余方程组

第二十九章 数论——中国剩余定理与线性同余方程组一、中国剩余定理1、作用&#xff1a;2、内容&#xff1a;3、证明&#xff1a;&#xff08;1&#xff09;逆元的存在性&#xff08;2&#xff09;验证定理的正确性4、代码实现&#xff1a;&#xff08;1&#xff09;步骤&#…

国产操作系统openEuler22.03配置yum源

作者&#xff1a;IT圈黎俊杰 本文选用的操作系统版本是openEuler22.03-LTS。openEuler是指操作系统的品牌英文名&#xff0c;中文名叫“欧拉”&#xff1b;22.03是指版本号&#xff08;openEuler以年月为版本号&#xff0c;22.03表示2022年03月发布的版本&#xff09;&#xff…

sonarqube——前端vue本地代码审查code review查看代码行数和注释率

目录一、环境二、操作1.启动2.中文3.使用三、过程踩坑1.sonarqube启动闪退2.解析报错 node 14.17一、环境 windows 64位 环境压缩包下载&#xff08;sonar9.8&#xff0c;jdk11&#xff0c;sonar-scanner&#xff09; 下载完成解压后&#xff0c;将 sonar-scanner-4.7.0.2747-…

curl 指令

勿以恶小而为之&#xff0c;勿以善小而不为---- 刘备 curl 是常用的命令行工具&#xff0c;用来请求 Web 服务器。 它的名字就是客户端&#xff08;client&#xff09;的 URL 工具的意思。 它的功能非常强大&#xff0c;命令行参数多达几十种 我们后端开发者&#xff0c; 可以…

MyISAM索引解析、InnoDB索引解析

我们经常说到的存储引擎是说数据库级别还是说表级别&#xff1f; 答&#xff1a;表级别。&#xff08;数据库级别也可以设置&#xff0c;但是最终它的级别生效是在表级别&#xff09; 1、MylSAM存储引擎索引实现 MylSAM索引文件和数据文件是分离的&#xff08;非聚集&#xf…

大数据开发中级练习题目(python超详细)

给定长度为m的非重复数组p&#xff0c;以及从其中取n&#xff08;n<m&#xff09;个数字组成新的子数组q。现要对p进行排序&#xff0c;要求&#xff1a;q在数组的最前方&#xff0c;其余数字按从小到大的顺序依次排在后面 输入样例&#xff1a; q [3, 5, 4] p [5, 4, 3…

37. 解数独

37. 解数独 编写一个程序&#xff0c;通过填充空格来解决数独问题。 数独的解法需 遵循如下规则&#xff1a; 数字 1-9 在每一行只能出现一次。 数字 1-9 在每一列只能出现一次。 数字 1-9 在每一个以粗实线分隔的 3x3 宫内只能出现一次。&#xff08;请参考示例图&#xff…

SAP 财务年结操作宝典

目录 一 、后台操作篇 1.1 维护会计凭证编号范围 2.2 维护CO版本 1.3 维护利润中心版本 1.4 维护物料分类账文档的编号范围 (如 1.5 复制合并凭证编号范围(如果公司没有这个业务的) 1.6 维护发票凭证的编号范围间隔 (如果不针对年度则不用维护) 1.7 维护发票凭证的编号范…

MCU-51:单片机串口详解

目录一、计算机通信简介二、串口通信简介2.1 同步通信2.2 异步通信三、串行通信的传输方式四、串口通信硬件电路五、常见接口介绍六、串口相关寄存器详解6.1 特殊功能寄存器SCON6.2 PCON寄存器6.3 TMOD寄存器七、代码演示-单片机和电脑通信7.1 串口向电脑发送数据7.2 电脑通过串…

YOLO-V5 算法和代码解析系列(二)—— 【train.py】核心内容

文章目录调试设置整体结构代码解析ModelTrainloader分布式训练FreezeOptimizerSchedulerEMA调试设置 调试平台&#xff1a;Ubuntu&#xff0c;VSCode 调试设置&#xff0c;打开【/home/slam/kxh-1/2DDection/yolov5/.vscode/launch.json】&#xff0c;操作如下图所示&#xff…