GAN实例基于神经网络

news2025/1/19 20:31:59

目录

1.前言

2.实验


1.前言

        需要了解GAN的原理查看对抗生成网络(GAN),DCGAN原理。

        采用手写数字识别数据集

2.实验

import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

os.makedirs("images", exist_ok=True)

opt = parser.parse_args()
print(opt)

img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = True if torch.cuda.is_available() else False


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity


# Loss function
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Configure data loader
os.makedirs("./data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "./data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))


Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# ----------
#  Training
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1682215.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

创维汽车总经理培训正式开展,打造新能源汽车销售的精英战队

在新能源汽车市场竞争日益激烈的背景下,创维汽车为加强核心竞争力,于2024年5月15日至17日在河南省安阳市举办了为期三天的总经理岗位认证培训。此次培训旨在强化经销商店端负责人们在新能源汽车销售与运营方面的能力,指明未来发展思路&#x…

(5.4–5.10)投融资周报|共38笔公开投融资事件,基础设施领跑,游戏融资活跃

5月4日至5月10日期间,加密市场共发生38笔投融资事件,其中基础设施18笔、游戏5 笔、其他4 笔、DeFi 3笔、Depin 3 笔、CeFi 2笔、NFT2笔、 RWA1笔。 本周千万美金以上融资有5笔: 加密货币交易公司Arbelos完成了一轮2800 万美元的种子轮融资&…

【极简】docker常用操作

镜像images是静态的 容器container是动态的,是基于镜像的,类似于一个进程。 查看docker images: docker images 或者docker image ls 查看docker container情况:docker ps -a,-a意思是--all 运行一个container: doc…

Python程序设计 文件处理(二)

实验十二 文件处理 第1关:读取宋词文件,根据词人建立多个文件 读取wjcl/src/step1/宋词.txt文件, 注意:宋词文件的标题行的词牌和作者之间是全角空格(" ")可复制该空格 在wjcl/src/step3/cr文件夹下根据每…

【WEEK12】 【DAY3】整合MyBatis框架【中文版】

2024.5.15 Wednesday 目录 13.整合MyBatis框架13.1.整合测试13.1.1.新建springboot-05-mybatis项目13.1.2.导入MyBatis需要的依赖13.1.3.配置数据库连接信息13.1.3.1.修改application.properties13.1.3.2.修改Springboot05MybatisApplicationTests.java并测试 13.1.4.新建pojo文…

Spring使用小技巧--排除bean无法被调用问题

我们在项目中可能由于项目的复杂性,创建了个spring的bean,但是调用却出现报错,显示无法找到该bean的异常。 这个时候我们就需要找到出错的原因,很多人往往会忽略的一点就是,你所创建的bean有可能并没有被加载到ioc容器…

【Linux系统编程】第十九弹---进程状态(下)

​​​​​​​ ✨个人主页: 熬夜学编程的小林 💗系列专栏: 【C语言详解】 【数据结构详解】【C详解】【Linux系统编程】 目录 1、僵尸进程 2、孤儿进程 3、运行状态 4、阻塞状态 5、挂起状态 6、进程切换 总结 1、僵尸进程 上一弹…

算法练习第22天|39. 组合总和、40.组合总和II

39. 组合总和 39. 组合总和 - 力扣(LeetCode)https://leetcode.cn/problems/combination-sum/description/ 题目描述: 给你一个 无重复元素 的整数数组 candidates 和一个目标整数 target ,找出 candidates 中可以使数字和为目标数…

Unity Mirror 从入门到入神(一)

Mirror从入门到成神 文章目录 Mirror从入门到成神简介NetworkClientRegisterPrefabConnect (string address)Disconnect ()activeactiveHost NetworkServerSpawn 简介 Mirror是一个unity网络同步框架,基于MonoBehaviour生命周期的回调的基础上进行数值的同步&#…

一个强大的在线解析网站,无需登录,只用把视频链接粘贴进去就能免费解析下载视频。

TiQu.cc是什么? TiQu.cc是一个强大的在线工具,让用户可以从包括Facebook、VK、Twitter、Tiktok、Instagram等在内的100多个平台下载他们喜爱的视频。不论是音乐、电视节目、电影、短片还是个人上传的内容,TiQu.cc都可以帮助您随时随地以离线…

什么是检索增强生成(Retrieval Augmented Generation)?RAG 架构如何实现?

检索增强生成(Retrieval Augmented Generation)时代 在不断发展的生成人工智能世界中,检索增强生成 (RAG) 标志着一项重大进步,它将检索模型的准确性与生成模型的创造性相结合,达到了准确&创新的更高层级。 这种…

摸鱼大数据——Linux搭建大数据环境(Hadoop高可用环境搭建)六

Hadoop高可用环境搭建 确定提前安装好了hadoop和zookeeper 1.删除原有数据文件 三台机器都要进行删除 可以使用CRT发送交互到所有会话 rm -rf /export/data/hadoop-3.3.0 2.安装软件 三台机器都要进行安装 注意: 如果网络较慢安装失败,那就重复安装即可 # 实现多个服务的通讯 …

数字水印 | 奇异值分解 SVD 的 Python 代码实现

🥑原理:数字水印 | 奇异值分解 SVD 的定义、原理及性质 🥑参考:Python 机器学习笔记:奇异值分解(SVD)算法 正文 对于一个图像矩阵,我们总可以将其分解为以下形式: 通过…

Halcon 根据XYZ生成3D模型

Halcon 根据XYZ生成3D模型 x_points := [a_x_points, b_x_points, c_x_points]y_points := [a_y_points, b_y_points, c_y_points]z_points := [a_z_points, b_z_points, c_z_points]stop()gen_object_model_3d_from_points

某单位Oracle数据库性能优化方案参考

内容分析: 本文是一篇关于XX市XX单位中心数据库优化方案的详细报告。文章首先描述了数据库的现状,包括其运行的软件环境、硬件环境、数据存储情况以及与检测点的连接方式。接着,文章列出了信息系统优化的常用策略,并具体解释了每一…

线性回归模型之套索回归

概述 本案例是基于之前的岭回归的案例的。之前案例的完整代码如下: import numpy as np import matplotlib.pyplot as plt from sklearn.linear_model import Ridge, LinearRegression from sklearn.datasets import make_regression from sklearn.model_selectio…

对抗生成网络(GAN),DCGAN原理

目录 1. GAN基础原理 1.1 生成器 1.2 判别器 1.3 整体架构 2. 损失函数 3. DCGAN 3.1 问题 3.2 解决 1. GAN基础原理 GAN(Generative Adversarial Nets)是一种深度神经网络架构。它由生成器和判别器组成,生成器学习真实样本&#x…

微信公众号自定义分销商城小程序源码系统 带完整的安装代码吧以及系统部署搭建教程

系统概述 微信公众号自定义分销商城小程序源码系统是一款功能强大的电商解决方案,它集成了商品管理、订单处理、支付接口、分销管理等多种功能。该系统支持自定义界面设计,商家可根据自身需求调整商城的页面布局和风格,打造独特的品牌形象。…

月薪20K+的策划人简历应该怎么写?

一般咱们大多数策划在写简历前,都是先直接找模板,然后按照模板的框架直接往里面填内容。 最后草草收场,直接拿去海投简历,结果发现没有拿到任何面试邀约。 策划写简历前的第一件事要梳理自己的能力模型和岗位JD。 因为只有先梳…

解决谷歌浏览器无法登陆网站的问题,左下角弹出JavaScript(void:0)

破釜沉舟,全都试一遍: 如果还不行,那就关闭GPU加速:关了瞬间就好了 关闭之后,再打开GPU加速还是行的(咱也不知道为啥呀)