【Pytorch】学习记录分享11——GAN对抗生成网络

news2024/9/30 9:24:40

PyTorch GAN对抗生成网络

      • 0. 工程实现
      • 1. GAN对抗生成网络结构
      • 2. GAN 构造损失函数(LOSS)
      • 3. GAN对抗生成网络核心逻辑
        • 3.1 参数加载:
        • 3.2 生成器:
        • 3.3 判别器:

0. 工程实现

原理解析:
论文解析:GAN:Generative Adversarial Nets

1. GAN对抗生成网络结构

在这里插入图片描述

2. GAN 构造损失函数(LOSS)

LOSS公式与含义:
在这里插入图片描述
LOSS代码实现:

import torch
from torch import autograd
input = autograd.Variable(torch.tensor([[ 1.9072,  1.1079,  1.4906],
        [-0.6584, -0.0512,  0.7608],
        [-0.0614,  0.6583,  0.1095]]), requires_grad=True)
print(input)
print('-'*100)

from torch import nn
m = nn.Sigmoid()
print(m(input))
print('-'*100)

target = torch.FloatTensor([[0, 1, 1], [1, 1, 1], [0, 0, 0]])
print(target)
print('-'*100)

import math

r11 = 0 * math.log(0.8707) + (1-0) * math.log((1 - 0.8707))
r12 = 1 * math.log(0.7517) + (1-1) * math.log((1 - 0.7517))
r13 = 1 * math.log(0.8162) + (1-1) * math.log((1 - 0.8162))

r21 = 1 * math.log(0.3411) + (1-1) * math.log((1 - 0.3411))
r22 = 1 * math.log(0.4872) + (1-1) * math.log((1 - 0.4872))
r23 = 1 * math.log(0.6815) + (1-1) * math.log((1 - 0.6815))

r31 = 0 * math.log(0.4847) + (1-0) * math.log((1 - 0.4847))
r32 = 0 * math.log(0.6589) + (1-0) * math.log((1 - 0.6589))
r33 = 0 * math.log(0.5273) + (1-0) * math.log((1 - 0.5273))

r1 = -(r11 + r12 + r13) / 3
#0.8447112733378236
r2 = -(r21 + r22 + r23) / 3
#0.7260397266631787
r3 = -(r31 + r32 + r33) / 3
#0.8292933181294807
bceloss = (r1 + r2 + r3) / 3 
print(bceloss)
print('-'*100)

loss = nn.BCELoss()
print(loss(m(input), target))
print('-'*100)

loss = nn.BCEWithLogitsLoss()
print(loss(input, target))

loss BCEloss代码逐行运行结果
在这里插入图片描述

3. GAN对抗生成网络核心逻辑

整个简单的gan网络代码由以下三个部分组成:

3.1 参数加载:
import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=100, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=128, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
opt = parser.parse_args()
print(opt)

img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = True if torch.cuda.is_available() else False

3.2 生成器:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img
3.3 判别器:

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity


# Loss function
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Configure data loader
os.makedirs("./data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "./data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))


Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# ----------
#  Training
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1363151.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LLM Agent之RAG的反思:放弃了压缩还是智能么?

已经唠了三章的RAG,是时候回头反思一下,当前的RAG是解决幻觉的终点么?我给不出直接的答案,不过感觉当前把RAG当作传统搜索框架在大模型时代下的改良,这个思路的天花板高度有限~ 反思来源于对RAG下模型回答的直观感受&…

DrGraph原理示教 - OpenCV 4 功能 - 直方图

OpenCV直方图是一种可以对整幅图的灰度分布进行整体了解的图示。它是带有像素值(从0到255,不总是)的图在X轴上,在y轴上的图像对应的像素个数。通过观察图像的直方图,我们可以直观的了解图像的对比度、亮度、亮度分布等…

数据分析——快递电商

一、任务目标 1、任务 总体目的——对账 本项目解决同时使用多个快递发货,部分隔离区域出现不同程度涨价等情形下,如何快速准确核对账单的问题。 1、在订单表中新增一列【运费差异核对】来表示订单运费实际有多少差异,结果为数值。 2、将…

HarmonyOS 开发基础(六)Slider

HarmonyOS 开发基础(六)Slider Entry Component struct Index {build() {Row() {Column() {// Slider:ArkUI 的基础组件 滑动条组件// options 参数:Slider 基础设置Slider({// 最小值min: 20,// 最大值max: 200,// 当前值value: …

文心一言 VS 讯飞星火 VS chatgpt (173)-- 算法导论13.3 2题

二、用go语言,将关键字41、38、31、12、19、8连续地插入一棵初始为空的红黑树之后,试画出该结果树。 文心一言: 首先,我们需要了解红黑树的性质。红黑树是一种自平衡二叉查找树,其中每个节点要么是红色,要…

使用sentinel作为熔断器

什么是sentinel Sentinel,中文翻译为哨兵,是为微服务提供流量控制、熔断降级的功能,它和Hystrix提供的功能一样,可以有效的解决微服务调用产生的“雪崩”效应,为微服务系统提供了稳定性的解决方案。随着Hytrxi进入了维…

7+非肿瘤+WGCNA+机器学习+诊断模型,构思巧妙且操作简单

今天给同学们分享一篇生信文章“Platelets-related signature based diagnostic model in rheumatoid arthritis using WGCNA and machine learning”,这篇文章发表在Front Immunol期刊上,影响因子为7.3。 结果解读: DEGs和血小板相关基因的…

Flink自定义Source模拟数据流

maven依赖 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://maven.apache.org/POM/4.0.…

CSS3 边框border、outline、box-shadow

1 border 语法&#xff1a;border: width style color 2 outline 语法&#xff1a;outline: width style color 2.1 outline-offet MDN解释&#xff1a;用于设置outline与一个元素边缘或边框之间的间隙 即&#xff1a;设置outline相对border外边缘的偏移&#xff0c;可以为…

Python 全栈体系【四阶】(十一)

第四章 机器学习 机器学习&#xff1a; 传统的机器学习&#xff1a;以算法为核心深度学习&#xff1a;以数据和计算为核心 感知机 perceptron&#xff08;人工神经元&#xff09; 可以做简单的分类任务掀起了第一波 AI 浪潮 感知机不能解决线性不可分问题&#xff0c;浪潮…

RouterOS L2TP安装与配置

申明&#xff1a;本文仅针对国内L2TP/PPTP&#xff0c;适用于国内的游戏加速或学术研究&#xff0c;禁止一切利用该技术的翻墙行为。 1. L2TP介绍 L2TP&#xff08;Layer 2 Tunneling Protocol&#xff09;是一种在计算机网络中广泛使用的隧道协议&#xff0c;它被设计用于通过…

java火车查询管理系统Myeclipse开发mysql数据库web结构java编程计算机网页项目

一、源码特点 java Web火车查询管理系统是一套完善的java web信息管理系统&#xff0c;对理解JSP java编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。开发环境为 TOMCAT7.0,Myeclipse8.5开发&#xff0c;数据库为Mysql…

2023年12月 C/C++(一级)真题解析#中国电子学会#全国青少年软件编程等级考试

C/C++编程(1~8级)全部真题・点这里 第1题:数的输入和输出 输入一个整数和双精度浮点数,先将浮点数保留2位小数输出,然后输出整数。 时间限制:1000 内存限制:65536 输入 一行两个数,分别为整数N(不超过整型范围),双精度浮点数F,以一个空格分开。 输出 一行两个数,分…

电子学会C/C++编程等级考试2023年12月(二级)真题解析

C/C++编程(1~8级)全部真题・点这里 第1题:统计指定范围里的数 给定一个数的序列S,以及一个区间[L, R], 求序列中介于该区间的数的个数,即序列中大于等于L且小于等于R的数的个数。 时间限制:1000 内存限制:65536 输入 第一行1个整数n,分别表示序列的长度。(0 < n ≤…

使用Enterprise Architect绘制架构图

如何使用Enterprise Architect绘制架构图 之前没有使用过Enterprise Architect软件绘制&#xff0c;目前由于工作需求&#xff0c;需要使用Enterprise Architect绘制一些架构图&#xff0c;现在只使用Enterprise Architect绘制过简单的Flow Chart&#xff0c;想请教一下大神们…

如何批量自定义视频画面尺寸

在视频制作和编辑过程中&#xff0c;对于视频画面尺寸的调整是一项常见的需求。有时候&#xff0c;为了适应不同的播放平台或满足特定的展示需求&#xff0c;我们需要对视频尺寸进行批量调整。那么&#xff0c;如何实现批量自定义视频画面尺寸呢&#xff1f;本文将为您揭示这一…

纳尼??Rabbitmq居然被一个逗号给坑了??

转载说明&#xff1a;如果您喜欢这篇文章并打算转载它&#xff0c;请私信作者取得授权。感谢您喜爱本文&#xff0c;请文明转载&#xff0c;谢谢。 前言 这个问题发生在部署一套新的环境。搭建一个单节点的Rabbitmq&#xff0c;按照小伙伴写的部署文档搭建的。其中搭建步骤和我…

JetCache源码解析——配置加载

JetCache自动化配置加载 JetCache的配置加载主要是在jetcache-autoconfigure模块中完成的&#xff0c;无论是使用内存缓存LinkedHashMap和caffeine&#xff0c;亦或是通过lettuce、redisson和spring-data-redis来操作Redis服务缓存数据&#xff0c;其自动加载配置的操作基本上…

JSON Crack数据可视化工具结合内网穿透实现公网访问

文章目录 1. 在Linux上使用Docker安装JSONCrack2. 安装Cpolar内网穿透工具3. 配置JSON Crack界面公网地址4. 远程访问 JSONCrack 界面5. 固定 JSONCrack公网地址 JSON Crack 是一款免费的开源数据可视化应用程序&#xff0c;能够将 JSON、YAML、XML、CSV 等数据格式可视化为交互…

redis介绍与数据类型(一)

redis介绍与数据类型 1、redis1.1、数据库分类1.2、NoSQL分类1.3、redis简介1.4、redis应用1.5、如何学习redis 2、redis的安装2.1、Windows安装2.2.1、客户端redis管理工具 2.2、Linux安装&#x1f525;2.2.1、redis核心文件2.2.2、启动方式2.2.3、redis桌面客户端1、redis命令…