Preface:
The project consists of two files:
gan.py: builds the network models
main.py: generates the dataset and trains the models
Contents:
- GAN network structure
- gan.py
- main.py
- Training results
1 GAN network structure

GAN training is a minimax game between the generator G and the discriminator D over the value function

$$\min_G \max_D V(D,G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$

1.1 Training D
Fix G and maximize V(D,G): the discriminator is pushed to output D(x) close to 1 on real samples and D(G(z)) close to 0 on generated ones.

1.2 Training G
Fix D and minimize V(D,G): with D frozen, only the second term depends on G, so G minimizes $\mathbb{E}_{z}[\log(1 - D(G(z)))]$.
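For a fixed G, the inner maximization has a closed-form solution. This standard result from the original paper (Goodfellow et al., 2014), stated here for context, shows what an ideal discriminator converges to and why training D to optimality turns V into a divergence between the real and generated distributions:

$$D^*_G(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}$$

$$\max_D V(D,G) = -\log 4 + 2\,\mathrm{JSD}(p_{data} \,\|\, p_g)$$

so the minimax value is attained exactly when $p_g = p_{data}$.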
2 gan.py
Purpose:
implement the discriminator D
implement the generator G
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 28 11:10:19 2023
@author: chengxf2
"""
import torch
from torch import nn


# Generator model: maps 2-D noise z to a 2-D sample
class Generator(nn.Module):

    def __init__(self):
        super(Generator, self).__init__()
        # z: [batch, 2]
        h_dim = 400
        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 2)
        )

    def forward(self, z):
        output = self.net(z)
        return output


# Discriminator model: maps a 2-D sample to a real/fake probability
class Discriminator(nn.Module):

    def __init__(self):
        super(Discriminator, self).__init__()
        h_dim = 400
        # x: [batch, 2]
        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # x: [batch, 2] -> output: [batch, 1] -> out: [batch]
        output = self.net(x)
        out = output.view(-1)
        return out
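A quick shape check (my addition, not part of the original post) confirms the two models wire together as expected:

# Sanity check: noise -> generator -> discriminator
import torch
from gan import Generator, Discriminator

G, D = Generator(), Discriminator()
z = torch.randn(8, 2)          # [batch=8, noise_dim=2]
x_fake = G(z)                  # [8, 2] generated points
p = D(x_fake)                  # [8] probabilities in (0, 1)
print(x_fake.shape, p.shape)   # torch.Size([8, 2]) torch.Size([8])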
3 main.py
Purpose:
generate the data
train the networks
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 28 11:28:32 2023
@author: chengxf2
"""
import visdom
from gan import Discriminator
from gan import Generator
import numpy as np
import random
import torch
from torch import nn, optim
from matplotlib import pyplot as plt

h_dim = 400
batchSize = 256
viz = visdom.Visdom()
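# NOTE: Visdom plots go to a separately running server; start it first with
#   python -m visdom.server
# and open http://localhost:8097 in a browser before running this script.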
def weights_init(net):
    if isinstance(net, nn.Linear):
        # net.weight.data.normal_(0.0, 0.02)
        nn.init.kaiming_normal_(net.weight)
        net.bias.data.fill_(0)
def data_generator():
    """
    Endless generator of training batches drawn from a mixture of
    8 Gaussians arranged on a circle.

    Yields
    ------
    dataset : np.ndarray of shape [batchSize, 2], dtype float32
    """
    scale = 2
    a = np.sqrt(2.0)
    centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1 / a, 1 / a),
        (1 / a, -1 / a),
        (-1 / a, 1 / a),
        (-1 / a, -1 / a)
    ]
    centers = [(scale * x, scale * y) for x, y in centers]
    while True:
        dataset = []
        for i in range(batchSize):
            # small Gaussian noise around a randomly chosen center
            point = np.random.randn(2) * 0.02
            center = random.choice(centers)
            point[0] += center[0]
            point[1] += center[1]
            dataset.append(point)
        dataset = np.array(dataset).astype(np.float32)
        dataset /= a
        # a generator function returns an iterator; each next() yields one batch
        yield dataset
def generate_image(D, G, xr, epoch):  # xr: a batch of real samples
    """
    Plots the discriminator's decision surface together with real and
    generated samples, and sends the figure to Visdom.
    """
    N_POINTS = 128
    RANGE = 3
    plt.clf()

    # grid of points covering [-RANGE, RANGE]^2
    points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')
    points[:, :, 0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]
    points[:, :, 1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]
    points = points.reshape((-1, 2))  # (16384, 2)
    x = y = np.linspace(-RANGE, RANGE, N_POINTS)
    N = len(x)

    device = next(D.parameters()).device
    # draw contour of the discriminator output
    with torch.no_grad():
        points = torch.from_numpy(points).to(device)  # [16384, 2]
        disc_map = D(points).cpu().numpy()            # [16384]
    plt.contour(x, y, disc_map.reshape((N, N)).transpose())
    # plt.clabel(cs, inline=1, fontsize=10)
    plt.colorbar()

    # draw samples
    with torch.no_grad():
        z = torch.randn(batchSize, 2, device=device)  # [b, 2]
        samples = G(z).cpu().numpy()                  # [b, 2]
    xr = xr.cpu().numpy()
    plt.scatter(xr[:, 0], xr[:, 1], c='green', marker='.')
    plt.scatter(samples[:, 0], samples[:, 1], c='red', marker='+')

    viz.matplot(plt, win='contour', opts=dict(title='p(x):%d' % epoch))
def main():
    maxIter = 1000
    torch.manual_seed(10)
    np.random.seed(10)
    data_iter = data_generator()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    G = Generator().to(device)
    D = Discriminator().to(device)
    G.apply(weights_init)
    D.apply(weights_init)
    optim_G = optim.Adam(G.parameters(), lr=5e-4, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=5e-4, betas=(0.5, 0.9))
    K = 5  # discriminator steps per generator step

    viz.line([[0, 0]], [0], win='loss', opts=dict(title='loss', legend=['D', 'G']))

    for epoch in range(maxIter):
        # 1: train the Discriminator first: fix G, maximize V(D, G)
        for k in range(K):
            # 1.1: train on real data
            xr = next(data_iter)
            xr = torch.from_numpy(xr).to(device)
            predr = D(xr)
            # max log D(x) == min -log D(x); eps guards against log(0)
            vr = torch.log(1e-4 + predr)
            lossr = vr.mean()

            # 1.2: train on fake data
            z = torch.randn(batchSize, 2).to(device)  # [b, 2] random noise
            xf = G(z).detach()  # freeze G, do not update its parameters (cf. tf.stop_gradient())
            predf = D(xf)       # D should push predf toward 0
            vf = torch.log(1e-4 + 1.0 - predf)
            lossf = vf.mean()

            loss_D = -(lossr + lossf)
            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()
        # print("\n Discriminator training done ", loss_D.item())

        # 2: train the Generator: fix D, minimize V(D, G)
        # 2.1: train on fake data
        z = torch.randn(batchSize, 2).to(device)
        xf = G(z)
        predf = D(xf)  # G wants predf pushed toward 1
        vf = torch.log(1e-4 + 1.0 - predf)
        loss_G = vf.mean()  # minimize log(1 - D(G(z))), as in the original paper

        # optimize
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:
            viz.line([[loss_D.item(), loss_G.item()]], [epoch], win='loss', update='append')
            generate_image(D, G, xr, epoch)
            print("\n epoch: %d" % epoch, "\t lossD: %7.4f" % loss_D.item(), "\t lossG: %7.4f" % loss_G.item())
if __name__ == "__main__":
    main()
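To eyeball the training data on its own, a short snippet like the following works (my addition; note that importing main creates the module-level Visdom client, so the Visdom server should already be running):

# Visualize one batch from the 8-Gaussian mixture, without any training.
from main import data_generator
from matplotlib import pyplot as plt

batch = next(data_generator())    # [256, 2] float32
plt.scatter(batch[:, 0], batch[:, 1], s=4)
plt.title("one batch from the 8-Gaussian mixture")
plt.show()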
4 Training results

The loss here follows the original GAN paper, which differs from later variants.

Observations:
The generator's training loss eventually settles at a fixed value, and the generator stops updating.
The discriminator, facing such a weak generator, easily separates real data from fake data, so its loss also drops rapidly to 0.

Actual generated samples:
The points produced by the generator (red) are still far from the real data distribution (green); the generator remains weak.
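This stall is the classic saturation of the min log(1 - D(G(z))) objective: once D confidently rejects fakes, the generator's gradient vanishes. A common remedy, already suggested in the original paper, is the non-saturating variant in which G instead maximizes log D(G(z)). A minimal sketch of the changed generator step, written as a drop-in replacement for step 2 in main() above (my adaptation, not in the original code):

# Non-saturating generator loss: maximize log D(G(z)),
# i.e. minimize -log D(G(z)); gradients stay useful even when D wins.
z = torch.randn(batchSize, 2).to(device)
xf = G(z)
predf = D(xf)
loss_G = -torch.log(1e-4 + predf).mean()

optim_G.zero_grad()
loss_G.backward()
optim_G.step()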
References:
[1] 课时127 GAN实战-GD实现_哔哩哔哩_bilibili
[2] https://www.cnblogs.com/cxq1126/p/13538409.html