KAGGLE · GETTING STARTED CODE COMPETITION 图像风格迁移 示例代码阅读

news2025/1/23 5:58:01

本博文阅读的代码来自于I’m Something of a Painter Myself | Kaggle倾情推荐:

Monet CycleGAN Tutorial | Kaggle

数据集说明

I’m Something of a Painter Myself | Kaggle

Files

  • monet_jpg - 300 Monet paintings sized 256x256 in JPEG format
  • monet_tfrec - 300 Monet paintings sized 256x256 in TFRecord format
  • photo_jpg - 7028 photos sized 256x256 in JPEG format
  • photo_tfrec - 7028 photos sized 256x256 in TFRecord format

简单介绍一下,就是有两种类型的数据提供使用,一个是JPEG格式,一个是TFRecord,训练集的size是300,测试集的size是7028。并且每张图片的大小都是256×256。

代码阅读

首先是一些说明,这个代码使用的是TensorFlow,所以也就大概看看,后面会搬家到PyTorch写写看。使用的是CycleGAN,这个很合理,因为这里是无监督学习,不过GAN的种类有超多哎,也许会有更好的GAN可以选择呢?anyway,CycleGAN也是很经典的方法啦。

先放一张PPT:

加载数据

MONET_FILENAMES = tf.io.gfile.glob(str(GCS_PATH + '/monet_tfrec/*.tfrec'))
print('Monet TFRecord Files:', len(MONET_FILENAMES))

PHOTO_FILENAMES = tf.io.gfile.glob(str(GCS_PATH + '/photo_tfrec/*.tfrec'))
print('Photo TFRecord Files:', len(PHOTO_FILENAMES))
tf.io.gfile.glob(str(GCS_PATH + '/monet_tfrec/*.tfrec'))

tf.io.gfile.glob(pattern): Returns a list of files that match the given pattern(s).

查找匹配pattern的文件并以列表的形式返回,pattern可以是一个具体的文件名,也可以是包含通配符的正则表达式。

参考链接:

tf.io.gfile.glob  |  TensorFlow v2.15.0.post1 (google.cn)
TensorFlow函数教程:tf.io.gfile.glob_w3cschool
Tensorflow 2.0 gfile 文件操作 - 知乎 (zhihu.com)
tf.io.gfile.glob 遍历文件-CSDN博客
TensorFlow函数:tf.io.gfile.glob_tf.io.gfile.remove函数参数-CSDN博客

def decode_image(image):
    image = tf.image.decode_jpeg(image, channels=3)
    image = (tf.cast(image, tf.float32) / 127.5) - 1
    image = tf.reshape(image, [*IMAGE_SIZE, 3])
    return image
tf.image.decode_jpeg(image, channels=3)

将JPEG编码的图像解码为unit8的Tensor,channels取3表示返回的是RGB图像,channels取1表示返回的是灰度图像,channels取0表示使用JPEG编码图像中的通道数量。

参考链接:

加载和预处理图像  |  TensorFlow Core (google.cn)
TensorFlow函数:tf.image.decode_jpeg_w3cschool
tf.image.decode_jpeg函数与tf.image.encode_jpeg函数用法-CSDN博客

tf.cast(image, tf.float32)

将前面得到的Tensor数据类型从unit8改到float32

参考链接:

Tensorflow中 tf.cast()的用法_tensorflow.cast-CSDN博客
tensorflow——tf.cast()详解_tensorflow的cast-CSDN博客
tf.cast - TensorFlow Python - W3cubDocs

tf.reshape(image, [*IMAGE_SIZE, 3])

对image(也就是前面处理的Tensor)进行维度的调整。比如:

参考链接:

tf.reshape函数用法&理解-CSDN博客
【tensorflow】tf.reshape函数说明:重塑张量_tensorflow reshape 变大-CSDN博客
TensorFlow:使用tf.reshape函数重塑张量_w3cschool
tf.reshape(x, [-1, 28, 28, 1])_reshape((-1, 28, 28, 1)-CSDN博客
Python的reshape的用法:reshape(1,-1)-CSDN博客

def read_tfrecord(example):
    tfrecord_format = {
        "image_name": tf.io.FixedLenFeature([], tf.string),
        "image": tf.io.FixedLenFeature([], tf.string),
        "target": tf.io.FixedLenFeature([], tf.string)
    }
    example = tf.io.parse_single_example(example, tfrecord_format)
    image = decode_image(example['image'])
    return image
tf.io.FixedLenFeature([], tf.string)

解析每个输入样本的每一列数据

参考链接:

TFRecord 中 FixedLenFeature、VarLenFeature、FixedLenSequenceFeature 说明_tf.fixedlenfeature-CSDN博客
Tensorflow2.0之TFRecord文件的写入与读取_tensorflow2 使用tfrecord-CSDN博客
TensorFlow函数教程:tf.io.FixedLenFeature_w3cschool

tf.io.parse_single_example(example, tfrecord_format)

输入一个Tensor,输出一个dict

使用tf.parse_single_example() 按照schema解析dataset中每个样本;

schema的意义在于指定每个样本的每一列数据应该用哪一种特征解析函数去解析。

参考链接:

tensorflow2.0 环境下的tfrecord读写及tf.io.parse_example和tf.io.parse_single_example的区别-CSDN博客
TensorFlow2.0 TFrecord数据集的写入、读取和训练示例详解_tensorflow将图片数据写入tfrecord-CSDN博客
TensorFlow函数教程:tf.io.parse_single_example_w3cschool
Tensorflow之TFRecord的原理和使用心得 - 知乎 (zhihu.com)

def load_dataset(filenames, labeled=True, ordered=False):
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTOTUNE)
    return dataset
tf.data.TFRecordDataset(filenames)

该数据集以字节形式从文件中加载 TFRecord,与写入时完全相同。 TFRecordDataset 本身不进行任何解析或解码。可以通过在 TFRecordDataset 之后应用 Dataset.map 转换来完成解析和解码。

参考链接:

tensorfow学习(一) ——tf.data.TFRecordDataset的使用-CSDN博客
TensorFlow2.0 TFrecord数据集的写入、读取和训练示例详解_tensorflow将图片数据写入tfrecord-CSDN博客
tensorflow入门:tfrecord 和tf.data.TFRecordDataset-CSDN博客
TFRecord + Dataset 进行数据的写入和读取 - 知乎 (zhihu.com)TensorFlow - tf.data.TFRecordDataset (runebook.dev)

创建generator

为了创建我们的generator,首先定义downsample和upsample方法。downsample,顾名思义,通过步幅减少图像的2D维度。upsample与downsample相反,增加图像的尺寸。Conv2DTranspose基本上与Conv2D层相反。

initializer = tf.random_normal_initializer(0., 0.02)

生成一组符合标准正态分布的Tensor的初始化器,类似的也可以初始化成别的形式(按照逻辑来说可能normal distribution并不是一个最好的选择,but anyway,既然GAN都能train起来,我更愿意相信……这玩意儿就是炼丹)

参考链接:

tensorflow和pytorch中的参数初始化调用方法-CSDN博客
tf.random_normal_initializer:TensorFlow初始化器_w3cschool
Tensorflow API——tf.random_normal_initializer_python中tf.random_normal_initializer什么意思-CSDN博客

gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

初始化定义了设置 Keras 各层权重随机初始值的方法。

按照正态分布生成随机张量的初始化器。

参数

  • mean: 一个 Python 标量或者一个标量张量。要生成的随机值的平均数。
  • stddev: 一个 Python 标量或者一个标量张量。要生成的随机值的标准差。
  • seed: 一个 Python 整数。用于设置随机数种子。

参考链接:

初始化 Initializers - Keras 中文文档 (kldivergence.github.io)
Layer weight initializers (keras.io)
Keras教学(6):Keras的初始化Initializers,看这一篇就够了_bias_initializer": {"module": "keras.initializers"-CSDN博客

result = keras.Sequential()
result.add(layers.Conv2DTranspose(filters, size, strides=2,
                                  padding='same',
                                  kernel_initializer=initializer,
                                  use_bias=False))

将图像恢复到原来的尺寸(上采样)->实现图像从小分辨率到大分辨率映射的操作。(有很多上采样的方法,反卷积只是其中的一种方法)

参考链接:

深度卷积生成对抗网络  |  TensorFlow Core (google.cn)
Conv2DTranspose layer (keras.io)
反卷积操作Conv2DTranspose-CSDN博客

有了upsampling和downsampling方法之后就可以创建generator了。在这里采用了skip的方法,提到这样是为了缓解梯度消失(一些resnet开始在脑海里旋转)

layers.Concatenate()([x, skip])

Concatenates a list of inputs.

It takes as input a list of tensors, all of the same shape except for the concatenation axis, and returns a single tensor that is the concatenation of all inputs.

简要来说就是把两个Tensor拼起来(所以也不算resnet死灰复燃×)按照axis取值的不同决定拼接的方式。如果没有指定axis的值,default=-1,就是从倒数第一个维度进行拼接。

创建discriminator

discriminator接收输入图像并将其分类为真实或虚假(生成)。鉴别器不是输出单个节点,而是输出一个较小的2D图像,其中像素值较高表示真实分类,像素值较低表示虚假分类。

with strategy.scope():
    monet_generator = Generator() # transforms photos to Monet-esque paintings
    photo_generator = Generator() # transforms Monet paintings to be more like photos

    monet_discriminator = Discriminator() # differentiates real Monet paintings and generated Monet paintings
    photo_discriminator = Discriminator() # differentiates real photos and generated photos

从最开始的cycleGAN的图片可以看出,对于generator和discriminator,是双向的,所以两个方向都需要定义。

因为此刻generator尚未训练,所以自然生成出来的图片只能说是皇帝的新图×

创建CycleGAN

在训练步骤中,模型将照片转换为莫奈的画作,然后再转换回照片。原始照片和经过两次变换的照片之间的区别是循环一致性损失。我们希望原始照片和经过两次变换的照片彼此相似。

简要来说,为了好理解,就是auto-encoder

具体请参考文章开头处的链接

定义损失函数

with strategy.scope():
    def discriminator_loss(real, generated):
        real_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(tf.ones_like(real), real)

        generated_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(tf.zeros_like(generated), generated)

        total_disc_loss = real_loss + generated_loss

        return total_disc_loss * 0.5

将真实图像比作1的矩阵,将假图像比作0的矩阵。完美的discriminator将为真实图像输出所有的1,为假图像输出所有的0。discriminator损耗输出实际损耗和生成损耗的平均值。

tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(tf.ones_like(real), real)

计算真实标签和预测标签之间的交叉熵损失。将这种交叉熵损失用于二值(0或1)分类应用程序。

参考链接:

Probabilistic losses (keras.io)
TF2.0—tf.keras.losses.BinaryCrossentropy-CSDN博客
tf.keras.losses.BinaryCrossentropy函数-CSDN博客

这里延伸一下:

为什么分类要用交叉熵而不用MSE

比较直觉的解释可以是,假设我们有100个类别,假设类别1被分类成类别2,这与类别1被分成类别100其实是一样的,都是分错了,但是MSE就会觉得分成类别100错的更离谱,这是显然不合理的。

非直觉的理由(不好算)

  1. MSE作为分类的损失函数会有梯度消失的问题。
  2. MSE是非凸的,存在很多局部极小值点。

请参考链接:

分类为什么用CE而不是MSE - 知乎 (zhihu.com)

为什么分类问题不能使用mse损失函数_为什么分类不用mse-CSDN博客

训练和可视化

——————————————————————————

后续就是会用PyTorch自己写写,看情况放不放链接吧

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1403149.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Redis相关面试题大全

📕作者简介: 过去日记,致力于Java、GoLang,Rust等多种编程语言,热爱技术,喜欢游戏的博主。 📗本文收录于java面试题系列,大家有兴趣的可以看一看 📘相关专栏Rust初阶教程、go语言基…

如何快速制作动态gif图?制作gif动图就这么简单

静图和动图是图像的两种不同形式。静图是一张静止不动的图片,没有任何动作或变化。而动图则是由一系列静止的图像组成,通过快速连续播放这些图像,可以形成看起来像是有动作的效果。简单来说,静图是静止的,而动图是具有…

Copilot安装和使用最全教程

背景 Copilot 是一款由 GitHub 和 OpenAI 合作开发的代码辅助工具。它基于 OpenAI 的大型语言模型 GPT-3.5,专为帮助软件开发者提升编程效率而设计 Copilot的主要功能是通过理解用户输入的代码注释或部分代码片段、自动生成或补全代码,本文主要介绍copi…

常规二分查找中遇到的问题

以前我们写二分查找的时候&#xff0c;是这么写的&#xff1a; public static int binarySearch2(int []a,int target){int i0,ja.length-1;while(i<j){int mid(ij)/2;if(a[mid]target){return mid;}else if(a[mid]<target){imid1;}else {jmid-1;}}return -1;} 这么写&…

Conda python管理环境environments 一 从入门到精通

Conda系列&#xff1a; 翻译: Anaconda 与 miniconda的区别Miniconda介绍以及安装Conda python运行的包和环境管理 入门 使用 conda&#xff0c;可以创建、导出、列出、删除和更新 具有不同 Python 版本和/或 安装在其中的软件包。在两者之间切换或移动 环境称为激活环境。您…

如何在CentOS8使用宝塔面板本地部署Typecho个人网站并实现公网访问【内网穿透】

文章目录 前言1. 安装环境2. 下载Typecho3. 创建站点4. 访问Typecho5. 安装cpolar6. 远程访问Typecho7. 固定远程访问地址8. 配置typecho 前言 Typecho是由type和echo两个词合成的&#xff0c;来自于开发团队的头脑风暴。Typecho基于PHP5开发&#xff0c;支持多种数据库&#…

Java 面向对象案例01(黑马)

文字版格斗游戏 在Javabean类中定义方法的形参的数据类型可以是什么&#xff1f; 在JavaBean类中&#xff0c;方法的形参的数据类型可以是任何合法的Java数据类型&#xff0c;包括基本数据类型&#xff08;如int、char、boolean等&#xff09;、引用数据类型&#xff08;如Str…

【文件处理】spring boot 文件处理

接收文件 PostMappingpublic result<String> add(MultipartFile file) throws IOException {// 得到目标文件夹File directory new File("file");//如果文件夹不存在就创建if(!directory.exists()){directory.mkdirs();}//文件名称String fileName file.getO…

分子生成工具应用案例+流程 - Pocket Crafter

2023年10月9日&#xff0c;诺华公司的Lingling Shen和He Wang在Chemrxiv上发表了文章《Pocket Crafter: A 3D Generative Modeling Based Workflow for the Rapid Generation of Hit Molecules in Drug Discovery》&#xff0c;介绍了他们分子生成在hit finding项目应用中的pip…

python+appium自动化测试-Appium并发测试之python启动appium服务

&#x1f525; 交流讨论&#xff1a;欢迎加入我们一起学习&#xff01; &#x1f525; 资源分享&#xff1a;耗时200小时精选的「软件测试」资料包 &#x1f525; 教程推荐&#xff1a;火遍全网的《软件测试》教程 &#x1f4e2;欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1…

【加速计算】从硬件、软件到网络互联,AI时代下的加速计算技术

AI、元宇宙、大模型…每一个火爆名词的背后都代表着巨大的算力需求。据了解,AI模型所需的算力每100天就要翻一倍,远超摩尔定律的18-24个月。5年后,AI所需的算力规模将是今天的100万倍以上。 在这种背景下,加速计算提供了必要的计算能力和内存,其解决方案涉及硬件、软件和…

ChatGPT:关于 OpenAI 的 GPT-4工具,你需要知道的一切

ChatGPT&#xff1a;关于 OpenAI 的 GPT-4工具&#xff0c;你需要知道的一切 什么是GPT-3、GPT-4 和 ChatGPT&#xff1f;ChatGPT 可以做什么&#xff1f;ChatGPT-4 可以做什么&#xff1f;ChatGPT 的费用是多少&#xff1f;GPT-4 与 GPT-3.5 有何不同&#xff1f;ChatGPT 如何…

红黑树(超详解)

文章目录 前言红黑树的概念红黑树的实现红黑树的结构 insert 前言 上一篇文章我们讲了AVL树,但是AVL树只是一个过渡,我们实际当中用的更多另外一颗树还是红黑树. 也不能说红黑树就是AVL树的改进,它是用另外一种方式来控制. 这棵树更抽象一些,下一步我们来看一下. 红黑树的概…

气膜建筑助力体育场馆智能化升级

随着科技的不断进步和人们对健康生活的日益重视&#xff0c;体育馆作为体育活动的主要场所也面临着智能化升级的时刻。在这个背景下&#xff0c;气膜建筑以其轻巧、灵活的特性正成为推动体育馆智能化升级的创新力量。 气膜建筑的独特优势 气膜建筑采用特殊的薄膜材料&#xff…

每日一题——LeetCode1299.将每个元素替换为右侧最大元素

方法一 个人方法&#xff1a; 题目意思就是求在i1;i的循环条件下&#xff0c;arr[i]-arr[arr.length-1]的最大值分别为多少&#xff0c;最后一项默认为-1 用slice方法可以每次把数组第一位去除&#xff0c;得到求最大值的目标数组 Math的max方法可以直接返回数组里的最大值 …

archlinux安装软件

用 pacman 安装 sudo pacman -S XXXX xxx 中填写要安装的软件就可以了 搜索的命令是 pacman -Ss 搜索的话不需要管理员权限 查看已经安装的程序 pacman -Q 可以通过 | 将前面的信息传给后面&#xff0c;相当于传参 pacman -Q | grep XXXX 删除软件 sudo pacman -Rs…

Python中的函数(二)

1 闭包与装饰器 1.1 闭包 闭包&#xff08;Closure&#xff09;是指在一个函数内部定义的函数&#xff0c;并且该内部函数可以访问外部函数作用域中的变量。闭包可以在外部函数执行完毕后&#xff0c;仍然保持对外部函数作用域的引用&#xff0c;从而可以继续访问和操作外部函…

银河麒麟桌面桌面操作系统v10保姆级安装

目录 一、下载ISO映像文件 1.产品试用申请 2.试用版下载 二、虚拟机搭建 1.新建虚拟机 2. 选择虚拟机硬件兼容性 3.选择安装客户机操作系统 4.选择客户机操作系统 5.命名虚拟机 6.处理器配置 7.虚拟机内存 8.网络类型 9.硬件 10.指定磁盘容量 三、修改虚拟…

「优选算法刷题」:在排序数组中查找元素的第一个和最后个位置

一、题目 给你一个按照非递减顺序排列的整数数组 nums&#xff0c;和一个目标值 target。请你找出给定目标值在数组中的开始位置和结束位置。 如果数组中不存在目标值 target&#xff0c;返回 [-1, -1]。 你必须设计并实现时间复杂度为 O(log n) 的算法解决此问题。 示例 1&a…