CycleGAN深度学习项目

news2024/11/13 23:25:06

远程仓库

leftthomas/CycleGAN: A PyTorch implementation of CycleGAN based on ICCV 2017 paper "Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks" (github.com)

运行准备

Anaconda

安装需要的库

指令

pip install pandas -i https://pypi.tuna.tsinghua.edu.cn/simple

pip install torch==1.11.0 -i Simple Index

pip install torchvision==0.12.0 -i Simple Index

pip install dominate==2.4.0 -i Simple Index

pip install visdom==0.1.8.8 -i Simple Index

pip install tqdm -i https://pypi.tuna.tsinghua.edu.cn/simple

运行结果

数据集

我当前使用的数据集

leftthomas/CycleGAN: A PyTorch implementation of CycleGAN based on ICCV 2017 paper "Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks" (github.com)

因为数据集太大,训练时间慢所以删掉了很多图片

A-副本和B-副本里面是原始的数据集

A B是我自己删了图片的数据集

如果使用其他数据集也可以训练,例如:从网上随便下载图片

运行结果

程序解读

从main.py的if __name__ == '__main__':开始看

因为程序从这里开始执行

parser = argparse.ArgumentParser(description='Train Model')
# common args
parser.add_argument('--data_root', default='horse2zebra', type=str, help='Dataset root path')
# 文件放的位置
parser.add_argument('--batch_size', default=1, type=int, help='Number of images in each mini-batch')
#每个小批量中的图像数量
parser.add_argument('--epochs', default=2, type=int, help='Number of epochs over the data to train')
# 多少轮训练
parser.add_argument('--lr', default=0.0002, type=float, help='Initial learning rate')
# 开始时学习率
parser.add_argument('--decay', default=2, type=int, help='Epoch to start linearly decaying lr to 0')
# 从第几轮开始学习率逐渐减为0
parser.add_argument('--save_root', default='result', type=str, help='Result saved root path')
# 训练出来的保存在哪里
# args parse
args = parser.parse_args()
data_root, batch_size, epochs, lr = args.data_root, args.batch_size, args.epochs, args.lr
decay, save_root = args.decay, args.save_root

# data prepare
train_data = ImageDataset(data_root, 'train')
# 训练集
print("数据")
print(train_data.__len__())
# 打印出数据集的长度
test_data = ImageDataset(data_root, 'test')
# 验证集
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False, num_workers=8)
test_loader = DataLoader(test_data, batch_size=1, shuffle=False, num_workers=8)

使用通义灵码解释

# optimizer setup
optimizer_G = Adam(itertools.chain(G_A.parameters(), G_B.parameters()), lr=lr, betas=(0.5, 0.999))
optimizer_DA = Adam(D_A.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_DB = Adam(D_B.parameters(), lr=lr, betas=(0.5, 0.999))
lr_scheduler_G = LambdaLR(optimizer_G, lr_lambda=lambda eiter: 1.0 - max(0, eiter - decay) / float(decay))
lr_scheduler_DA = LambdaLR(optimizer_DA, lr_lambda=lambda eiter: 1.0 - max(0, eiter - decay) / float(decay))
lr_scheduler_DB = LambdaLR(optimizer_DB, lr_lambda=lambda eiter: 1.0 - max(0, eiter - decay) / float(decay))

这段代码是用于设置优化器和学习率调度器的。

首先,使用Adam优化器来初始化G_A和G_B的参数以及D_A和D_B的参数。Adam优化器是一种基于梯度的优化算法,它利用了动量和自适应学习率的特性。itertools.chain函数用于将G_A和G_B的参数组合在一起。

然后,使用LambdaLR学习率调度器来设置学习率的衰减。LambdaLR调度器使用给定的函数来计算每个迭代步骤的学习率。这里使用了一个lambda函数,它在迭代次数eiter超过decay后开始衰减学习率,衰减的速度由decay参数控制。

这些优化器和学习率调度器将用于训练生成器和判别器模型。

# training loop
results = {'train_g_loss': [], 'train_da_loss': [], 'train_db_loss': []}
if not os.path.exists(save_root):
    os.makedirs(save_root)
for epoch in range(1, epochs + 1):
    g_loss, da_loss, db_loss = train(G_A, G_B, D_A, D_B, train_loader, optimizer_G, optimizer_DA, optimizer_DB)
    results['train_g_loss'].append(g_loss)
    results['train_da_loss'].append(da_loss)
    results['train_db_loss'].append(db_loss)
    val(G_A, G_B, test_loader)
    lr_scheduler_G.step()
    lr_scheduler_DA.step()
    lr_scheduler_DB.step()
    # save statistics
    data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))
    data_frame.to_csv('{}/results.csv'.format(save_root), index_label='epoch')
    torch.save(G_A.state_dict(), '{}/GA.pth'.format(save_root))
    torch.save(G_B.state_dict(), '{}/GB.pth'.format(save_root))
    torch.save(D_A.state_dict(), '{}/DA.pth'.format(save_root))
    torch.save(D_B.state_dict(), '{}/DB.pth'.format(save_root))

这段代码是一个训练循环,用于训练深度学习模型。以下是代码的详细解释:

首先,定义一个字典results,用于存储训练过程中的损失值。

检查保存模型和结果的目录save_root是否存在,如果不存在则创建该目录。

使用for循环遍历epochs次,每次迭代都会进行一次训练和验证。

在每次迭代中,调用train函数训练生成器G_A、G_B和判别器D_A、D_B,并更新损失值。

将训练过程中的损失值分别添加到results字典中对应的列表中。

调用val函数对模型进行验证。

更新生成器和判别器的学习率。

将results字典转换为DataFrame,并将其保存为CSV文件。

保存生成器和判别器的模型参数。

这个训练循环的主要目的是在给定的训练数据集上训练生成对抗网络(GAN),并保存训练过程中的损失值和模型参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1925163.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AI时代:探索个人潜能的新视角

文章目录 Al时代的个人发展1 AI的高速发展意味着什么1.1 生产力大幅提升1.2 生产关系的改变1.3 产品范式1.4 产业革命1.5 Al的局限性1.5.1局限一:大模型的幻觉 1.5.2 局限二:Token 2 个体如何应对这种改变?2.1 职场人2.2 K12家长2.3 大学生2.4 创业者 3 人工智能发…

万界星空科技商业开源MES系统全面解析

万界星空科技商业开源MES源码可拖拽式数据大屏 开源MES系统具有定制化、节省成本、开放性和适应性等优势和特点,可以帮助企业更好地管理生产流程。万界星空MES制造执行系统的Java开源版本,为制造业企业提供了全面的生产管理解决方案。万界星空科技的目标…

从零开始做题:满屏的QR

题目 给出一张png图片 解题 import os import re import cv2 import argparse import itertools import numpy as npparser argparse.ArgumentParser() parser.add_argument(-f, typestr, defaultNone, requiredTrue,help输入文件名称) parser.add_argument(-p, typestr, d…

[Vulnhub] Stapler wp-videos+ftp+smb+bash_history权限提升+SUID权限提升+Kernel权限提升

信息收集 IP AddressOpening Ports192.168.8.106TCP:21,22,53,80,123,137,138,139,666,3306, Using Nmap for scanning: $ nmap -p- 192.168.8.106 --min-rate 1000 -sC -sV The results are as follows: PORT STATE SERVICE VERSION 20/tcp closed ftp-data…

昇思25天学习打卡营第20天 | 基于MindNLP+MusicGen生成自己的个性化音乐

基于MindNLPMusicGen生成个性化音乐 实验简介 MusicGen是Meta AI提出的音乐生成模型,能够根据文本描述或音频提示生成高质量音乐。该模型基于Transformer结构,分为三个阶段:文本编码、音频token预测和音频解码。此实验将演示如何使用MindSpo…

Java常用排序算法

冒泡排序(Bubble Sort) arr[0] 与 arr[1]比较,如果前面元素大就交换,如果后边元素大就不交换。然后依次arr[1]与arr[2]比较,第一轮将最大值排到最后一位。 第二轮arr.length-1个元素进行比较,将第二大元素…

高速数据采集与图像传输对带宽需求的对比分析

对于120MHz高速采集的数据,直接传输原始数据和将数据计算生成1024x1024的图像后再传输图像,这两种方法对带宽的影响会有显著不同。为了进行详细分析,我们需要考虑以下因素:数据采样率、数据量、图像生成算法、图像压缩和传输带宽需…

Spark调度底层执行原理详解(第35天)

系列文章目录 一、Spark应用程序启动与资源申请 二、DAG(有向无环图)的构建与划分 三、Task的生成与调度 四、Task的执行与结果返回 五、监控与容错 六、优化策略 文章目录 系列文章目录前言一、Spark应用程序启动与资源申请1. SparkContext的创建2. 资…

python:绘制一元四次函数的曲线

编写 test_x4_x2_4x.py 如下 # -*- coding: utf-8 -*- """ 绘制函数 y x^4x^24x-3 在 -2<x<2 的曲线 """ import numpy as np from matplotlib import pyplot as plt# 用于正常显示中文标题&#xff0c;负号 plt.rcParams[font.sans-s…

值得关注的数据资产入表

不错的讲解视频&#xff0c;来自&#xff1a;第122期-杜海博士-《数据资源入表及数据资产化》-大数据百家讲坛-厦门大学数据库实验室主办第122期-杜海博士-《数据资源入表及数据资产化》-大数据百家讲坛-厦门大学数据库实验室主办-20240708_哔哩哔哩_bilibili

《昇思25天学习打卡营第20天|onereal》

应用实践/LLM原理和实践/基于MindSpore的GPT2文本摘要 基于MindSpore的GPT2文本摘要 数据集加载与处理 数据集加载 本次实验使用的是nlpcc2017摘要数据&#xff0c;内容为新闻正文及其摘要&#xff0c;总计50000个样本。 数据预处理 原始数据格式&#xff1a; article: [CLS…

java框架-springmvc

文章目录 2. Springmvc概述3. springmvc与struts2不同5. springmvc入门6. springmvc 配置7. Handler配置8. 异常处理器9. ssm整合思路10. 上传图片11. RESTful支持12. 拦截器总结 2. Springmvc概述 Spring web mvc和Struts2都属于表现层的框架,它是Spring框架的一部分 3. sp…

QML 鼠标和键盘事件

学习目标&#xff1a;Qml 鼠标和键盘事件 学习内容 1、QML 鼠标事件处理QML 直接提供 MouseArea 来捕获鼠标事件&#xff0c;该操作必须配合Rectangle 获取指定区域内的鼠标事件, 2、QML 键盘事件处理&#xff0c;并且获取对OML直接通过键盘事件 Keys 监控键盘任意按键应的消…

防御第二次作业完成接口配置实验

一、实验括扑图 二、实验要求 1.防火墙向下使用子接口分别对应生产区和办公区 2.所有分区设备可以ping通网关 三、实验思路 1、配置各设备的IP地址 2、划分VLAN及VLAN的相关配置 3、配置路由及安全策略 四、实验步骤 1、配置PC跟Client还有server配置&#xff0…

Hive表【汇总】

提前必备 1、内部表和外部表的区别 概念讲解&#xff1a; 外部表&#xff1a;1、存放他人给予自己的数据2、当我们删除表操作时&#xff0c;会将表的元数据删除&#xff0c;保留数据文件 内部表&#xff1a;1、存放已有的数据2、当我们删除表操作时&#xff0c;会将表的元数据…

LeetCode Day8|● 344.反转字符串(原地) ● 541. 反转字符串II(i可以大步跨越) ● 卡码网:54.替换数字(ACM模式多熟悉熟悉)

字符串part01 day8-1 ● 344.反转字符串整体思路代码实现总结 day8-2 ● 541. 反转字符串II整体思路代码实现总结 day8-3 ● 卡码网&#xff1a;54.替换数字题目解题思路代码实现总结 day8-1 ● 344.反转字符串 整体思路 字符串和数组的思路差不多 原地操作 代码实现 class…

递归解决换零钱问题--代码实现

在上一篇中, 经过深入分析, 已经得出一个能够递归的形式化的结果, 现在则准备给出一个具体实现. 结果回顾 前述结果如下: caseOfChange(amount, cashList) { // base caseif (amount.isNegative()) { // 负数 return 0; } if (amount.isZero()) { // 0元 return 1; }if (cas…

vscode终端(控制台打印乱码)

乱码出现的两种可能&#xff08;重点是下面标题2&#xff09; 1、文件中的汉字本来就是乱码&#xff0c;输出到控制台(终端)那就当然是乱码 在vscode中设置文件的编码格式为UTF-8&#xff0c; 2、输出到控制台(终端)之前的汉字不是乱码&#xff0c;针对此种情况如下设置 原因…

MySQL卸载 - Windows版

MySQL卸载 - Windows版 1. 停止MySQL服务 winR 打开运行&#xff0c;输入 services.msc 点击 “确定” 调出系统服务。 2. 卸载MySQL相关组件 打开控制面板 —> 卸载程序 —> 卸载MySQL相关所有组件 3. 删除MySQL安装目录 4. 删除MySQL数据目录 数据存放目录是在 …

C++从入门到起飞之——缺省参数/函数重载/引用全方位剖析!

目录 1.缺省参数 2. 函数重载 3.引⽤ 3.1 引⽤的概念和定义 3.2 引⽤的特性 3.3 引⽤的使⽤ 3.4 const引⽤ 3.5 指针和引⽤的关系 4.完结散花 个人主页&#xff1a;秋风起&#xff0c;再归来~ C从入门到起飞 个人格言&#xff1a;悟已往之不谏…