基于 CycleGAN 对抗网络的自定义数据集训练

news2025/1/17 5:59:31

目录

生成对抗网络(GAN)

CycleGAN模型训练

训练数据生成

下载开源项目CycleGAN

配置训练环境

开始训练

模型测试

可视化结果


生成对抗网络(GAN)

        首先介绍一下什么是GAN网络,它是由生成器(Generator)和判别器(Discriminator)组成,二者均是由神经网络构成,通过不断的博弈来提高输出数据质量。

        生成器的目的是学习真实数据的分布,从而能够生成与真实数据相似的新样本。它接收随机噪声作为输入,并通过一系列的神经网络层将其转化为具有特定特征的输出,试图欺骗判别器使其认为生成的数据是真实的。

        判别器则负责区分输入数据是来自真实数据集还是由生成器生成的。它接收数据并输出一个概率值,表示该数据为真实数据的可能性。判别器通过不断学习来提高自己区分真实数据和生成数据的能力

        在训练过程中,生成器和判别器进行对抗性的博弈。生成器努力提高生成数据的质量,以使其能够骗过判别器;而判别器则努力提高自己的鉴别能力,不被生成器欺骗。通过不断地迭代训练,双方的性能逐渐提升,最终达到一种平衡状态,此时生成器能够生成非常逼真的样本,而判别器也具有较高的鉴别能力。

CycleGAN 是由 Jun-Yan Zhu 等人于 2017 年提出的,核心思想是通过两个生成器和两个判别器来实现无监督的图像转换2。它引入了循环一致性损失,确保转换是双向的且在转换前后能够保持图像的一致性。

CycleGAN 论文:https://arxiv.org/abs/1703.10593

上面这个图是该网络实现的风格迁移,感觉这个网络还是挺有意思的,就想着训练一下自己的数据集看下效果,那下面我们直接进入正题吧。

CycleGAN模型训练

注意:目前只尝试过图像对的训练,仅支持包含src和dst的数据集

GitHub项目:CycleGAN-based-train

整体目录架构:

训练数据生成

首先准备自己需要训练的数据集,需要包含源和目标,数据集的格式如下:

其中,O-HAZY NTIRE 2018是根目录,GT是源图像存放路径,hazy是目标图像存放路径

同时请准备好测试样本文件夹test-sample(可自定义),准备的一定要是图像文件夹,暂时不会支持单张图像的测试,格式如下:

数据集准备好后运行main.py文件,需要注意参数设置,具体请查看文件说明

# main.py

import os
import shutil
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib

matplotlib.use('TkAgg')
from tqdm import tqdm

# ----------------------训练数据路径-----------------------#
#   仅支持包含src和dst的数据集(图像对)
# -------------------------------------------------------#
root = r'O-HAZY NTIRE 2018'
# --------------------------------------------------------#
#       label1:src的路径名  |  label2:dst的路径名
# --------------------------------------------------------#
label1 = 'GT'
label2 = 'hazy'
# -------------------------生成图像可视化-------------------------#
#   !!! 在训练和测试均完成后进行结果检查时仅可设置为True,否则报错  !!!
#   该部分只是对结果的可视化,预测阶段请查看README
# -------------------------------------------------------------#
test = False
# ------------------------测试样本------------------------------#
test_data_path = './test-sample'
# ------------------------测试结果图像保存路径---------------------#
# !!!   里面是已经得到的测试结果和原图     !!!
# -------------------------------------------------------------#
results_path = './results/dehaze_cyclegan/test_latest/images/'


def make_data(src_path, dst_path, label):
    src_path = src_path + f'/{label}/'
    image_files = [f for f in os.listdir(src_path) if f.endswith(('.jpg', '.jpeg', '.png', '.gif', '.bmp'))]
    
    with tqdm(total=len(image_files)) as pbar:
        for filename in image_files:
            file_path = os.path.join(src_path, filename)
            if filename.endswith(('.jpg', '.jpeg', '.png', '.gif', '.bmp')):
                image = Image.open(file_path)
                target_file = os.path.join(dst_path, filename)
                image.save(target_file)
                
            pbar.update(1)



if __name__ == '__main__':
    if not test:
        # -------------------创建CycleGAN的训练数据路径-----------------------#
        if not os.path.exists('dataset'):
            os.makedirs('dataset')
        if not os.path.exists('dataset/trainA'):
            os.makedirs('dataset/trainA')
        if not os.path.exists('dataset/trainB'):
            os.makedirs('dataset/trainB')

        # --------------------------检查图像对数量----------------------------#
        num_images = len(os.listdir(root + f'/{label1}/'))
        idx = np.arange(1, num_images + 1)
        print(f'查找到{num_images}个图像对')

        make_data(root, 'dataset/trainA/', label1)
        make_data(root, 'dataset/trainB/', label2)

    # ----------------------可视化阶段-----------------------------------#
    else:
        for f in os.listdir(test_data_path):
            fake = f.split('.')[0] + '_fake.png'
            real = f.split('.')[0] + '_real.png'

            fig = plt.figure()
            ax = plt.subplot(1, 2, 1)
            img1 = Image.open(results_path + real)
            plt.imshow(img1)

            ax = plt.subplot(1, 2, 2)
            img2 = Image.open(results_path + fake)
            plt.imshow(img2)

            plt.show()

下载开源项目CycleGAN

这一步如果下载了我上传的GitHub仓库的可以直接跳过,因为我已经将该项目放置在仓库里面,不需要重复下载。当然如果没有下载,请继续往下看

方式一:git clone GitHub - junyanz/pytorch-CycleGAN-and-pix2pix: Image-to-Image Translation in PyTorch

方式二:百度网盘:pytorch-CycleGAN-and-pix2pix

链接:https://pan.baidu.com/s/1WC-kEonwm7bFujO72GZAcQ        提取码:jsw2

配置训练环境

终端打开pytorch-CycleGAN-and-pix2pix,输入以下命令

pip install -r requirements.txt

开始训练

同样的,在终端打开该项目,输入以下指令:

python train.py --dataroot ./dataset --name dehaze_cyclegan --model cycle_gan

其中,只有 --name 是可改参数,可以自己命名模型的名称,但是修改后一定要与测试时的名称一致,请一定注意这一点

此外,如果在训练过程中出现“OSError: [WinError 1455] 页面文件太小,无法完成操作”报错信息,这是由于训练环境所在磁盘虚拟内存不足导致,调整方法如下:

最后一步选择训练环境所在的磁盘进行修改即可

训练过程截图

模型测试

在终端打开该项目,输入以下指令:

cp ./checkpoints/dehaze_cyclegan/latest_net_G_A.pth ./checkpoints/dehaze_cyclegan/latest_net_G.pth
python test.py --dataroot ./test-sample --name dehaze_cyclegan --model test --no_dropout --direction AtoB

这里需要注意的是 --dataroot 是测试样本,可以自己调整路径,同时注意模型名称是否与训练的一致,不一致请修改

生成的结果会保存在results文件夹下,目录结构如下:

其中,fake是生成图像,real是原图像,同时所有图像尺寸均会被调整为256\times 256

可视化结果

运行main.py文件,需要设置3个参数:test、test_data_path、results_path(test=True),详情请查看具体文件

我想要实现图像加雾,但是这个效果看起来一般吧,也有可能是图像数据对和训练轮次太少了。但不管怎么说,终究还是成功了嘛。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2132736.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

工具、环境等其他小问题归纳

此篇文章内容会不定期更新,仅作为学习过程中的笔记记录 一、查询Windows 10环境下python版本与安装路径 若电脑成功安装了python环境,不小心忘了版本。 I、查询版本 1、cmd窗口快捷查询 Win R 输入cmd 进入窗口; 直接输入 python --version …

2024.9.13 系统运维

学习目标:了解 云计算运维 “云计算是中国的骄傲!” 企业向云服务商租用云服务(省钱、省心、省力) 云计算:公有云、私有云(大公司,数据隐私性)、混合云(私有云跑重要…

前端刷新进不了登录页面

报错props.ts:15 Uncaught (in promise) SyntaxError: Unexpected token 错误截图: 原因:谷歌浏览器版本过低,升级浏览器 比如这边版本就过低了

ThinkCMF框架任意内容包含漏洞的讲解

本文来自无问社区,更多网安资料可前往查看http://www.wwlib.cn 背景描述 ThinkCMF是一款基于PHPMYSQL开发的中文内容管理框架,底层采用ThinkPHP3.2.3构建。 ThinkCMF提出灵活的应用机制,框架自身提供基础的管理功能,而开发者可…

CSP 2023 提高级第一轮单项选择题解析

CSP 2023 提高级第一轮单项选择题解析 第1题第2题第3题第4题第5题第6题第7题第8题第9题第10题第11题第12题第13题第14题第15题 第1题 在 Linux 系统终端中,以下哪个命令用于创建一个新的目录?(B) A.newdir B.mkdir C.create D.mkfold 解析:记…

部署Tomcat和抓包

部署Tomcat 复制文件到桌面 查看自己是否有java环境,下图所示是有的,若没有需另行下载 解压tomcat文件 tar -xzvf apache-tomcat-7.0.96.tar.gz 下列为tomcat文件的几个重要文件 进入到bin文件中 启动tomcat ./startup.sh 可以先用本机查看是否启动…

【PostgreSQL里的restartpoint重启点】

不知道大家有没有关注过,配置文件里archive_cleanup_command参数的注释部分有着这么一句"command to execute at every restartpoint",意思是在每个restartpoint时执行的命令。 提起checkpoint大家可能比较熟悉,对于这个restartpoint&#xff…

英文软件汉化中文软件教程asi exe dll 等汉化教程

相信大家在使用国际软件的时候,会经常碰到英文类型的软件 或者玩一些游戏使用一些工具,也基本都是外网的,那么对于用户来讲 就会非常的不方便! 小编为大家整理了一些国内大佬出的的英文软件汉化中文软件的视频教程 教程分为EX…

HarmonyOS开发实战( Beta5.0)滑动视频自动播放案例实践

鸿蒙HarmonyOS开发往期文章必看: HarmonyOS NEXT应用开发性能实践总结 最新版!“非常详细的” 鸿蒙HarmonyOS Next应用开发学习路线!(从零基础入门到精通) 介绍 本示例主要介绍视频列表滑动到屏幕中间自动播放场景&…

[项目] -登录框

前言 各位师傅大家好,我是qmx_07,今天来给大家讲解登录框的小练习,就此SDK的相关学习就此结束 登录框 对话框绘制 通过添加DIaLog对话框,添加 static test文本、Edit Control输入框、Button按钮,制作登录框passwor…

快速入门编写一个Java程序

一、jdk配置 下载完jdk后需要配置环境变量 以下是其步骤 1、我的电脑-属性-高级系统设置-环境变量 2、在系统变量中新建JAVA_HOME环境变量,指向jdk的安装目录 3、编辑path环境变量,新建%JAVA_HOME%\bin 4、打开Dos命令行,任意目录下敲入j…

CGAL and the Boost Graph Library

CGAL and the Boost Graph Library 许多几何数据结构都可以解释为图,因为它们由顶点和边组成。对于halfedge数据结构、多面体曲面、arrangement以及二维三角剖分类来说,情况都是如此。利用对偶性,人们也可以将面解释为顶点,相邻面…

AcWing119 袭击

目录 AcWing119 袭击题目描述背景输入输出数据范围 题解解法优化 打赏 AcWing119 袭击 题目描述 背景 特工进入据点突袭发电站,已知所有发电站的位置和所有特工的降落位置,求任意特工距离任意核电站的最短距离 输入 第一行一个整数 T T T&#xff0…

基于SpringBoot实现SpringMvc上传下载功能实现

SpringMvc上传下载功能实现 1.创建新的项目 1)项目信息填写 Spring Initializr (单击选中)Name(填写项目名字)Language(选择开发语言)Type(选择工具Maven)Group()JDK(jdk选择17 &…

深度学习——D1(环境配置)

课程内容 W-H-W 资源 AI地图 物体检测和分割 样式迁移 人脸合成 文字生成图片 预测与训练 本地安装

【IPV6从入门到起飞】5-2 IPV6+Home Assistant(ESP32+MQTT+DHT11+BH1750)传感器采集上传监测

IPV6Home Assistant[ESP32MQTTDHT11BH1750]传感器采集上传监测 1 背景2 实现效果3 Home Assistant配置3-1 MQTT配置3-2 yaml 配置3-3 加载配置 4 ESP32搭建4-1 开发环境4-2 工程代码 5 实现效果 1 背景 在上一小节【IPV6从入门到起飞】5-1 IPV6Home Assistant(搭建基本环境)我…

luogu基础课题单 入门 上

【深基2.例5】苹果采购 题目描述 现在需要采购一些苹果,每名同学都可以分到固定数量的苹果,并且已经知道了同学的数量,请问需要采购多少个苹果? 输入格式 输入两个不超过 1 0 9 10^9 109 正整数,分别表示每人分到…

chapter1-项目搭建

文章目录 序章1. 项目开发基础概念1.1 企业开发中常见的web项目类型1.2 企业项目开发流程1.3 立项申请阶段 2. 需求分析2.1 首页2.2 登录注册2.3 课程列表2.4 课程详情2.5 购物车2.6 商品结算2.7 购买成功2.8 个人中心2.9 我的课程及课程学习 3. 环境搭建3.1 创建虚拟环境3.2 相…

2024.9.13 Python与图像处理新国大EE5731课程大作业,索贝尔算子计算边缘,高斯核模糊边缘,Haar小波计算边缘

1.编写一个图像二维卷积程序。它应该能够处理任何灰度输入图像,并使用以下内核进行操作: %matplotlib inline import numpy as np import matplotlib.pyplot as plt from scipy import linalg import random as rm import math import cv2# import and …

基于云端的SIEM解决方案

最近的一项市场研究爆出了一组惊人的数字,在2024年,网络攻击增加了600%!更加令人担忧的是,这恐怕只是冰山一角。世界各地的组织都已经认识到了这一威胁,并正在采取多重措施来抵御来自线下和远程混合式办公模式带来的网…