PyTorch 实现 Conditional DCGAN(条件深度卷积生成对抗网络)进行图像到图像转换的示例代码

news2025/3/18 15:18:38

以下是一个使用 PyTorch 实现 Conditional DCGAN(条件深度卷积生成对抗网络)进行图像到图像转换的示例代码。该代码包含训练和可视化部分,假设输入为图片和 4 个工艺参数,根据这些输入生成相应的图片。

1. 导入必要的库

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt

# 检查是否有可用的 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

2. 定义数据集类

class ImagePairDataset(Dataset):
    def __init__(self, image_pairs, params):
        self.image_pairs = image_pairs
        self.params = params

    def __len__(self):
        return len(self.image_pairs)

    def __getitem__(self, idx):
        input_image, target_image = self.image_pairs[idx]
        param = self.params[idx]
        return input_image, target_image, param

3. 定义生成器和判别器

# 生成器
class Generator(nn.Module):
    def __init__(self, z_dim=4, img_channels=3):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            # 输入: [batch_size, z_dim + 4, 1, 1]
            self._block(z_dim + 4, 1024, 4, 1, 0),  # [batch_size, 1024, 4, 4]
            self._block(1024, 512, 4, 2, 1),  # [batch_size, 512, 8, 8]
            self._block(512, 256, 4, 2, 1),  # [batch_size, 256, 16, 16]
            self._block(256, 128, 4, 2, 1),  # [batch_size, 128, 32, 32]
            nn.ConvTranspose2d(128, img_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size, stride, padding, bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, z, params):
        params = params.view(params.size(0), 4, 1, 1)
        x = torch.cat([z, params], dim=1)
        return self.gen(x)

# 判别器
class Discriminator(nn.Module):
    def __init__(self, img_channels=3):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # 输入: [batch_size, img_channels + 4, 64, 64]
            nn.Conv2d(img_channels + 4, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            self._block(64, 128, 4, 2, 1),  # [batch_size, 128, 16, 16]
            self._block(128, 256, 4, 2, 1),  # [batch_size, 256, 8, 8]
            self._block(256, 512, 4, 2, 1),  # [batch_size, 512, 4, 4]
            nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid()
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(
                in_channels, out_channels, kernel_size, stride, padding, bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2)
        )

    def forward(self, img, params):
        params = params.view(params.size(0), 4, 1, 1).repeat(1, 1, img.size(2), img.size(3))
        x = torch.cat([img, params], dim=1)
        return self.disc(x)

4. 训练代码

def train_conditional_dcgan(image_pairs, params, batch_size=32, epochs=10, lr=0.0002, z_dim=4):
    dataset = ImagePairDataset(image_pairs, params)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    gen = Generator(z_dim).to(device)
    disc = Discriminator().to(device)

    criterion = nn.BCELoss()
    opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_disc = optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))

    for epoch in range(epochs):
        for i, (input_images, target_images, param) in enumerate(dataloader):
            input_images = input_images.to(device)
            target_images = target_images.to(device)
            param = param.to(device)

            # 训练判别器
            opt_disc.zero_grad()
            real_labels = torch.ones((target_images.size(0), 1, 1, 1)).to(device)
            fake_labels = torch.zeros((target_images.size(0), 1, 1, 1)).to(device)

            # 计算判别器对真实图像的损失
            real_output = disc(target_images, param)
            d_real_loss = criterion(real_output, real_labels)

            # 生成假图像
            z = torch.randn(target_images.size(0), z_dim, 1, 1).to(device)
            fake_images = gen(z, param)

            # 计算判别器对假图像的损失
            fake_output = disc(fake_images.detach(), param)
            d_fake_loss = criterion(fake_output, fake_labels)

            # 总判别器损失
            d_loss = d_real_loss + d_fake_loss
            d_loss.backward()
            opt_disc.step()

            # 训练生成器
            opt_gen.zero_grad()
            output = disc(fake_images, param)
            g_loss = criterion(output, real_labels)
            g_loss.backward()
            opt_gen.step()

        print(f'Epoch [{epoch+1}/{epochs}] D_loss: {d_loss.item():.4f} G_loss: {g_loss.item():.4f}')

    return gen

5. 可视化代码

def visualize_generated_images(gen, input_images, params, z_dim=4):
    input_images = input_images.to(device)
    params = params.to(device)
    z = torch.randn(input_images.size(0), z_dim, 1, 1).to(device)
    fake_images = gen(z, params).cpu().detach()

    fig, axes = plt.subplots(1, input_images.size(0), figsize=(15, 3))
    for i in range(input_images.size(0)):
        img = fake_images[i].permute(1, 2, 0).numpy()
        img = (img + 1) / 2  # 从 [-1, 1] 转换到 [0, 1]
        axes[i].imshow(img)
        axes[i].axis('off')
    plt.show()

6. 示例使用

# 假设 image_pairs 是一个包含图像对的列表,params 是一个包含 4 个工艺参数的列表
image_pairs = []  # 这里需要替换为实际的图像对数据
params = []  # 这里需要替换为实际的工艺参数数据

# 训练模型
gen = train_conditional_dcgan(image_pairs, params)

# 可视化生成的图像
test_input_images, test_target_images, test_params = image_pairs[:5], image_pairs[:5], params[:5]
test_input_images = torch.stack([torch.tensor(img) for img in test_input_images]).float()
test_params = torch.tensor(test_params).float()
visualize_generated_images(gen, test_input_images, test_params)

代码说明

  1. 数据集类ImagePairDataset 用于加载图像对和工艺参数。
  2. 生成器和判别器GeneratorDiscriminator 分别定义了生成器和判别器的网络结构。
  3. 训练代码train_conditional_dcgan 函数用于训练 Conditional DCGAN 模型。
  4. 可视化代码visualize_generated_images 函数用于可视化生成的图像。
  5. 示例使用:最后部分展示了如何使用上述函数进行训练和可视化。

请注意,你需要将 image_pairsparams 替换为实际的数据集。此外,代码中的超参数(如 batch_sizeepochslr 等)可以根据实际情况进行调整。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2316559.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

uniapp 实现的步进指示器组件

采用 uniapp 实现的一款步进指示器组件,展示业务步骤进度等内容,对外提供“前进”、“后退”方法,让用户可高度自定义所需交互,适配 web、H5、微信小程序(其他平台小程序未测试过,可自行尝试) 可…

大模型-提示词调优

什么是提示词 提示词(Prompt)在大模型应用中扮演着关键角色,它是用户输入给模型的一段文本指令 。简单来说,就是我们向大模型提出问题、请求或描述任务时所使用的文字内容。例如,当我们想让模型写一篇关于春天的散文&a…

继承知识点—详细

一:普通写法 package extend_;public class Extends01 {public static void main(String[] args) {Pubil pubil new Pubil();pubil.name"小明";pubil.age18;pubil.testing();pubil.setScore(60);pubil.showInfo();System.out.println("-----------…

设备管理VTY(Telnet、SSH)

实验目的:物理机远程VTY通过telnet协议登录AR1,ssh协议登录AR2和sw 注意配置Cloud1: 注意!!博主的物理机VMnet8--IP:192.168.160.1,所以AR1路由0/0/0端口才添加IP:192.168.160.3,每个…

Linux 中 Git 使用指南:从零开始掌握版本控制

目录 1. 什么是 Git? Git 的核心功能: 2. Git 的安装 Ubuntu/Debian 系统: 验证安装: 3.gitee库 4. Git 的首次配置 配置用户名和邮箱: 查看配置: 5. Git 的基本使用 初始化仓库 添加文件到暂存区…

CSS -属性值的计算过程

目录 一、抛出两个问题1.如果我们学过优先级关系,那么请思考如下样式为何会生效2.如果我们学习过继承,那么可以知道color是可以被子元素继承使用的,那么请思考下述情景为何不生效 二、属性值计算过程1.确定声明值2.层叠冲突3.使用继承4.使用默…

百度贴吧IP和ID是什么意思?怎么查看

在百度贴吧这一充满活力的网络社区中,IP和ID是两个频繁出现的概念。它们各自承载着不同的意义和作用,对于贴吧用户而言,了解这两个概念有助于更好地参与社区互动、保护个人隐私以及维护社区秩序。本文将详细解析百度贴吧中IP和ID的含义&#…

SpiderX:专为前端JS加密绕过设计的自动化工具

SpiderX 一、工具概述 SpiderX是一款专为解决前端JS加密问题而设计的自动化绕过工具。在网络安全领域,随着前端加密技术的普及,传统的爬虫和自动化测试工具在面对复杂的JS加密时显得力不从心。SpiderX应运而生,旨在通过自动化手段高效绕过前…

基于银河麒麟系统ARM架构安装达梦数据库并配置主从模式

达梦数据库简要概述 达梦数据库(DM Database)是一款由武汉达梦公司开发的关系型数据库管理系统,支持多种高可用性和数据同步方案。在主从模式(也称为 Master-Slave 或 Primary-Secondary 模式)中,主要通过…

【AWS入门】AWS云计算简介

【AWS入门】AWS云计算简介 A Brief Introduction to AWS Cloud Computing By JacksonML 什么是云计算?云计算能干什么?我们如何利用云计算?云计算如何实现? 带着一系列问题,我将做一个普通布道者,引领广…

适合企业内训的AI工具实操培训教程(37页PPT)(文末有下载方式)

详细资料请看本解读文章的最后内容。 资料解读:适合企业内训的 AI 工具实操培训教程 在当今数字化时代,人工智能(AI)技术迅速发展,深度融入到各个领域,AIGC(人工智能生成内容)更是成…

【数据结构与算法】Java描述:第四节:二叉树

一、树的相关概念 编程中的树是模仿大自然中的树设计的,呈现倒立的结构,我们着重掌握 二叉树 。 1.1 基本概念: 结点的度:一个结点有几个子结点,度就是几; 如上图:A的度为3 树的度&#xff1…

Day5 结构体、文字显示与GDT/IDT初始化

文章目录 1. harib02b用例(使用结构体)2. harib02c用例3. harib02d用例(显示字符图案)3. harib02e用例(增加字符图案)4. harib02g用例4.1 显示字符串4.2 显示变量值 5. harib02h用例(显示鼠标&a…

系统思考全球化落地

感谢加密货币公司Bybit的再次邀请,为全球团队分享系统思考课程!虽然大家来自不同国家,线上学习的形式依然让大家充满热情与互动,思维的碰撞不断激发新的灵感。 尽管时间存在挑战,但我看到大家的讨论异常积极&#xff…

【开原宝藏】30天学会CSS - DAY1 第一课

下面提供一个由浅入深、按步骤拆解的示例教程,让你能从零开始,逐步理解并实现带有旋转及悬停动画的社交图标效果。为了更简单明了,以下示例仅创建四个图标(Facebook、Twitter、Google、LinkedIn),并在每一步…

钉钉项目报销与金蝶系统高效集成技术解析

钉钉报销【项目报销类】集成到金蝶付款单【画纤骨】的技术实现 在企业日常运营中,数据的高效流转和准确对接是提升业务效率的关键。本文将分享一个具体的系统对接集成案例:如何将钉钉平台上的项目报销数据无缝集成到金蝶云星空的付款单系统中。本次方案…

Datawhale coze-ai-assistant:Task 1 了解 AI 工作流 + Coze的介绍

学习网址:Datawhale-学用 AI,从此开始 工作流(Workflow)是指完成一项任务或目标时,按照特定顺序进行的一系列活动或步骤。它强调在计算机应用环境下的自动化,通过将复杂的任务拆分成多个简单的步骤,每一步都…

深度学习 Deep Learning 第3章 概率论与信息论

第三章 概率与信息论 概述 本章介绍了概率论和信息论的基本概念及其在人工智能和机器学习中的应用。概率论为处理不确定性提供了数学框架,使我们能够量化不确定性和推导新的不确定陈述。信息论则进一步帮助我们量化概率分布中的不确定性。在人工智能中,…

GStreamer —— 2.15、Windows下Qt加载GStreamer库后运行 - “播放教程 1:Playbin 使用“(附:完整源码)

运行效果 介绍 我们已经使用了这个元素,它能够构建一个完整的播放管道,而无需做太多工作。 本教程介绍如何进一步自定义,以防其默认值不适合我们的特定需求。将学习: • 如何确定文件包含多少个流,以及如何切换 其中。…

MYsql—1

1.mysql的安装 在windows下安装mysql,直接官网搜索即可:http://www.mysql.com/,自己找想要的版本进行download,官网长这样 安装路径需要是英文路径,设置默认即可,若安装执行内容时报错,则AltCt…