以下是一个使用 PyTorch 实现 Conditional DCGAN(条件深度卷积生成对抗网络)进行图像到图像转换的示例代码。该代码包含训练和可视化部分,假设输入为图片和 4 个工艺参数,根据这些输入生成相应的图片。
1. 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
# 检查是否有可用的 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
2. 定义数据集类
class ImagePairDataset(Dataset):
def __init__(self, image_pairs, params):
self.image_pairs = image_pairs
self.params = params
def __len__(self):
return len(self.image_pairs)
def __getitem__(self, idx):
input_image, target_image = self.image_pairs[idx]
param = self.params[idx]
return input_image, target_image, param
3. 定义生成器和判别器
# 生成器
class Generator(nn.Module):
def __init__(self, z_dim=4, img_channels=3):
super(Generator, self).__init__()
self.gen = nn.Sequential(
# 输入: [batch_size, z_dim + 4, 1, 1]
self._block(z_dim + 4, 1024, 4, 1, 0), # [batch_size, 1024, 4, 4]
self._block(1024, 512, 4, 2, 1), # [batch_size, 512, 8, 8]
self._block(512, 256, 4, 2, 1), # [batch_size, 256, 16, 16]
self._block(256, 128, 4, 2, 1), # [batch_size, 128, 32, 32]
nn.ConvTranspose2d(128, img_channels, kernel_size=4, stride=2, padding=1),
nn.Tanh()
)
def _block(self, in_channels, out_channels, kernel_size, stride, padding):
return nn.Sequential(
nn.ConvTranspose2d(
in_channels, out_channels, kernel_size, stride, padding, bias=False
),
nn.BatchNorm2d(out_channels),
nn.ReLU(True)
)
def forward(self, z, params):
params = params.view(params.size(0), 4, 1, 1)
x = torch.cat([z, params], dim=1)
return self.gen(x)
# 判别器
class Discriminator(nn.Module):
def __init__(self, img_channels=3):
super(Discriminator, self).__init__()
self.disc = nn.Sequential(
# 输入: [batch_size, img_channels + 4, 64, 64]
nn.Conv2d(img_channels + 4, 64, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2),
self._block(64, 128, 4, 2, 1), # [batch_size, 128, 16, 16]
self._block(128, 256, 4, 2, 1), # [batch_size, 256, 8, 8]
self._block(256, 512, 4, 2, 1), # [batch_size, 512, 4, 4]
nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=0),
nn.Sigmoid()
)
def _block(self, in_channels, out_channels, kernel_size, stride, padding):
return nn.Sequential(
nn.Conv2d(
in_channels, out_channels, kernel_size, stride, padding, bias=False
),
nn.BatchNorm2d(out_channels),
nn.LeakyReLU(0.2)
)
def forward(self, img, params):
params = params.view(params.size(0), 4, 1, 1).repeat(1, 1, img.size(2), img.size(3))
x = torch.cat([img, params], dim=1)
return self.disc(x)
4. 训练代码
def train_conditional_dcgan(image_pairs, params, batch_size=32, epochs=10, lr=0.0002, z_dim=4):
dataset = ImagePairDataset(image_pairs, params)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
gen = Generator(z_dim).to(device)
disc = Discriminator().to(device)
criterion = nn.BCELoss()
opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))
for epoch in range(epochs):
for i, (input_images, target_images, param) in enumerate(dataloader):
input_images = input_images.to(device)
target_images = target_images.to(device)
param = param.to(device)
# 训练判别器
opt_disc.zero_grad()
real_labels = torch.ones((target_images.size(0), 1, 1, 1)).to(device)
fake_labels = torch.zeros((target_images.size(0), 1, 1, 1)).to(device)
# 计算判别器对真实图像的损失
real_output = disc(target_images, param)
d_real_loss = criterion(real_output, real_labels)
# 生成假图像
z = torch.randn(target_images.size(0), z_dim, 1, 1).to(device)
fake_images = gen(z, param)
# 计算判别器对假图像的损失
fake_output = disc(fake_images.detach(), param)
d_fake_loss = criterion(fake_output, fake_labels)
# 总判别器损失
d_loss = d_real_loss + d_fake_loss
d_loss.backward()
opt_disc.step()
# 训练生成器
opt_gen.zero_grad()
output = disc(fake_images, param)
g_loss = criterion(output, real_labels)
g_loss.backward()
opt_gen.step()
print(f'Epoch [{epoch+1}/{epochs}] D_loss: {d_loss.item():.4f} G_loss: {g_loss.item():.4f}')
return gen
5. 可视化代码
def visualize_generated_images(gen, input_images, params, z_dim=4):
input_images = input_images.to(device)
params = params.to(device)
z = torch.randn(input_images.size(0), z_dim, 1, 1).to(device)
fake_images = gen(z, params).cpu().detach()
fig, axes = plt.subplots(1, input_images.size(0), figsize=(15, 3))
for i in range(input_images.size(0)):
img = fake_images[i].permute(1, 2, 0).numpy()
img = (img + 1) / 2 # 从 [-1, 1] 转换到 [0, 1]
axes[i].imshow(img)
axes[i].axis('off')
plt.show()
6. 示例使用
# 假设 image_pairs 是一个包含图像对的列表,params 是一个包含 4 个工艺参数的列表
image_pairs = [] # 这里需要替换为实际的图像对数据
params = [] # 这里需要替换为实际的工艺参数数据
# 训练模型
gen = train_conditional_dcgan(image_pairs, params)
# 可视化生成的图像
test_input_images, test_target_images, test_params = image_pairs[:5], image_pairs[:5], params[:5]
test_input_images = torch.stack([torch.tensor(img) for img in test_input_images]).float()
test_params = torch.tensor(test_params).float()
visualize_generated_images(gen, test_input_images, test_params)
代码说明
- 数据集类:
ImagePairDataset
用于加载图像对和工艺参数。 - 生成器和判别器:
Generator
和Discriminator
分别定义了生成器和判别器的网络结构。 - 训练代码:
train_conditional_dcgan
函数用于训练 Conditional DCGAN 模型。 - 可视化代码:
visualize_generated_images
函数用于可视化生成的图像。 - 示例使用:最后部分展示了如何使用上述函数进行训练和可视化。
请注意,你需要将 image_pairs
和 params
替换为实际的数据集。此外,代码中的超参数(如 batch_size
、epochs
、lr
等)可以根据实际情况进行调整。