pytorch 自编码器实现图像的降噪

news2024/9/24 8:33:16

自编码器

自动编码器是一种无监督的深度学习算法,它学习输入数据的编码表示,然后重新构造与输出相同的输入。它由编码器和解码器两个网络组成。编码器将高维输入压缩成低维潜在(也称为潜在代码或编码空间) ,以从中提取最相关的信息,而解码器则解压缩编码数据并重新创建原始输入。

自编码器的输入和输出应该尽可能的相似。

通过输入含有噪声的图像,编码器在编码的过程中会存在信息丢失,将输入和输出最相似的特征保留下来,通过解码器得到最后的输出。在这个转换的过程中实现了图像的去噪。

自编码器主要的用途其实是用于降维,将高维的数据编码为一组向量,解码器通过解码得到输出。

数据集导入可视化

import torchvision
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import numpy as np
import random
import PIL.Image as Image
import torchvision.transforms as transforms


class AddPepperNoise(object):
    """增加椒盐噪声
    Args:
        snr (float): Signal Noise Rate
        p (float): 概率值,依概率执行该操作
    """

    def __init__(self, snr, p=0.9):
        assert isinstance(snr, float) and (isinstance(p, float))
        self.snr = snr
        self.p = p

    def __call__(self, img):
        """
        Args:
            img (PIL Image): PIL Image
        Returns:
            PIL Image: PIL image.
        """
        if random.uniform(0, 1) < self.p:
            img_ = np.array(img).copy()
            h, w = img_.shape
            signal_pct = self.snr
            noise_pct = (1 - self.snr)
            mask = np.random.choice((0, 1, 2), size=(h, w), p=[signal_pct, noise_pct/2., noise_pct/2.])

            img_[mask == 1] = 255   # 盐噪声
            img_[mask == 2] = 0     # 椒噪声
            return Image.fromarray(img_.astype('uint8'))
        else:
            return img

class Gaussian_noise(object):
    """增加高斯噪声
    此函数用将产生的高斯噪声加到图片上
    传入:
        img   :  原图
        mean  :  均值
        sigma :  标准差
    返回:
        gaussian_out : 噪声处理后的图片
    """
    def __init__(self, mean, sigma):

        self.mean = mean
        self.sigma = sigma

    def __call__(self, img):
        """
        Args:
            img (PIL Image): PIL Image
        Returns:
            PIL Image: PIL image.
        """
        # 将图片灰度标准化
        img_ = np.array(img).copy()
        img_ = img_ / 255.0
        # 产生高斯 noise
        noise = np.random.normal(self.mean, self.sigma, img_.shape)
        # 将噪声和图片叠加
        gaussian_out = img_ + noise
        # 将超过 1 的置 1,低于 0 的置 0
        gaussian_out = np.clip(gaussian_out, 0, 1)
        # 将图片灰度范围的恢复为 0-255
        gaussian_out = np.uint8(gaussian_out*255)
        # 将噪声范围搞为 0-255
        # noise = np.uint8(noise*255)
        return Image.fromarray(gaussian_out)

train_datasets = torchvision.datasets.MNIST('./', train=True, download=True)
test_datasets = torchvision.datasets.MNIST('./', train=False, download=True)

print('训练集的数量', len(train_datasets))
print('测试集的数量', len(test_datasets))


train_loader = DataLoader(train_datasets, batch_size=100, shuffle=True)
test_loader = DataLoader(test_datasets, batch_size=1, shuffle=False)

transform=transforms.Compose([
    transforms.ToPILImage(),
    Gaussian_noise(0,0.1),
    AddPepperNoise(0.9)
    # transforms.ToTensor()
])

print('训练集可视化')
fig = plt.figure()
for i in range(12):
    plt.subplot(3, 4, i + 1)
    img = train_datasets.train_data[i]
    label = train_datasets.train_labels[i]
    # noise = np.random.normal(0.1, 0.1, img.shape)
    # img=transform(img)
    plt.imshow(img, cmap='gray')
    plt.title(label)
    plt.xticks([])
    plt.yticks([])
plt.show()

噪声图像 

 原始图像

模型的搭建

import torch
from torch import nn

class AE(nn.Module):
    def __init__(self):
        super(AE, self).__init__()

        # [b, 784] => [b, 20]
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU()
        )

        # [b, 20] => [b, 784]
        self.decoder = nn.Sequential(
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()
        )


    def forward(self, x):
        """
        :param x: [b, 1, 28, 28]
        :return:
        """
        batchsz = x.size(0)
        # flatten(打平)
        x = x.view(batchsz, 784)
        # encoder
        x = self.encoder(x)
        # decoder
        x = self.decoder(x)
        # reshape
        x = x.view(batchsz, 1, 28, 28)

        return x

if __name__=='__main__':
    model=AE()
    input=torch.randn(1,28,28)
    input=input.view(1,-1)
    print('输入的维度',input.shape)
    encoder_out=model.encoder(input)
    print('编码器的输出',encoder_out.shape)
    out=model.decoder(encoder_out)
    print('解码器的输出',out.shape)

 

模型的训练

导入训练集训练的时候一定要将使用transforms将所有图像转换为tensor格式,这里的方法不同于tensorflow导入MNIST方法,如果不加transforms则图像的格式为列表类型,下面在训练的时候会报错。

在训练过程中添加噪声。分别添加了高斯噪声和椒盐噪声

import torchvision
from torch.utils.data import DataLoader
import numpy as np
import random,os
import PIL.Image as Image
import torchvision.transforms as transforms
from torch import nn,optim
import torch
from models import AE
from tqdm import tqdm

train_datasets = torchvision.datasets.MNIST('./', train=True, download=True,transform=transforms.ToTensor())
test_datasets = torchvision.datasets.MNIST('./', train=False, download=True,transform=transforms.ToTensor())

print('训练集的数量', len(train_datasets))
print('测试集的数量', len(test_datasets))


train_loader = DataLoader(train_datasets, batch_size=100, shuffle=True)
test_loader = DataLoader(test_datasets, batch_size=1, shuffle=False)

#模型,优化器,损失函数
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model=AE().to(device)
criteon = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

##导入预训练模型
if os.path.exists('./model.pth') :
    # 如果存在已保存的权重,则加载
    checkpoint = torch.load('model.pth',map_location=lambda storage,loc:storage)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    initepoch = checkpoint['epoch']
    loss = checkpoint['loss']
else:
    initepoch=0

#开始训练
for epoch in range(initepoch, 50):
    with tqdm(total=(len(train_datasets)-len(train_datasets)), ncols=80) as t:
        t.set_description('epoch: {}/{}'.format(epoch, 50))

        running_loss = 0.0
        for i, data in enumerate(train_loader, 0):
            # get the inputs
            true_input, _ = data
            #生成均值为0,方差为0.1的高斯分布
            gaussian_noise=torch.normal(mean=0,std=0.1,size=true_input.shape)
            image_noise=true_input+gaussian_noise
            noise_tensor = torch.rand(size=true_input.shape)
            #添加椒盐噪声
            image_noise[noise_tensor<0.1]=0 #椒噪声
            image_noise[noise_tensor > (1-0.1)] = 1 #盐噪声
            #限制像素的范围在0-1之间
            image_noise=torch.clamp(image_noise,min=0,max=1)
            optimizer.zero_grad()

            outputs = model(image_noise)
            loss = criteon(outputs, true_input)
            loss.backward()

            optimizer.step()

            running_loss += loss.item()
            t.set_postfix(trainloss='{:.6f}'.format(running_loss/len(train_loader)))
            t.update(len(true_input))

    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': running_loss/len(train_loader)
                }, 'model.pth')

模型的测试

 在导入模型的时候经常发生上面的错误。模型在导入参数的时候不需要赋值操作。如果保存的方法是torch.load(model,'model.pth'),也就是直接保存模型的所有(包括模型的结构),在导入模型参数的时候可以使用model=torch.load('./model.pth')

import numpy as np
import torchvision
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torch
import torchvision.transforms as transforms
from models import AE
from data import AddPepperNoise,Gaussian_noise

test_datasets = torchvision.datasets.MNIST('./', train=False, download=True)

print('测试集的数量', len(test_datasets))

test_loader = DataLoader(test_datasets, batch_size=1, shuffle=False)

transform=transforms.Compose([
    transforms.ToPILImage(),
    Gaussian_noise(0,0.2),
    AddPepperNoise(0.9),
    transforms.ToTensor()
])
model=AE()
hh=torch.load('./model.pth',map_location=lambda storage,loc:storage)
model.load_state_dict(hh['model_state_dict'])
#错误写法
# model=model.load_state_dict(hh['model_state_dict'])

fig = plt.figure()
for i in range(12):
    plt.subplot(3, 4, i + 1)
    img = test_datasets.train_data[i]
    label = test_datasets.test_labels[i]
    img_noise=transform(img)
    out=model(img_noise)
    out=out.squeeze()
    out=transforms.ToPILImage()(out)
    #原始图像,噪声图像,去噪图像
    plt.imshow(np.hstack((np.array(img),np.array(transforms.ToPILImage()(img_noise)),np.array(out))), cmap='gray')
    plt.title(label)
    plt.xticks([])
    plt.yticks([])
plt.show()

 生成随机数看看解码器能解码出什么

生成标准正太分布

import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
from models import AE

model=AE()
model.load_state_dict(torch.load('./model.pth',map_location=lambda storage,loc:storage)['model_state_dict'])

fig = plt.figure()
for i in range(12):
    plt.subplot(3, 4, i + 1)
    input=torch.randn(1,20)
    out=model.decoder(input)
    out=out.view(28,28)
    out=transforms.ToPILImage()(out)
    plt.imshow(out, cmap='gray')
    plt.xticks([])
    plt.yticks([])
plt.show()

 生成0-1之间的均匀分布

import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
from models import AE

model=AE()
model.load_state_dict(torch.load('./model.pth',map_location=lambda storage,loc:storage)['model_state_dict'])

fig = plt.figure()
for i in range(12):
    plt.subplot(3, 4, i + 1)
    input=torch.rand(1,20)
    out=model.decoder(input)
    out=out.view(28,28)
    out=transforms.ToPILImage()(out)
    plt.imshow(out, cmap='gray')
    plt.xticks([])
    plt.yticks([])
plt.show()

可以看到随机生成的数据用解码器解码得到的数据都很乱。接下来,看看编码器编码后的数据服从什么分布。

看看编码器编码的输出服从什么分布

import numpy as np
import torchvision
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torch
import torchvision.transforms as transforms
from models import AE
from data import AddPepperNoise,Gaussian_noise

test_datasets = torchvision.datasets.MNIST('./', train=False, download=True)

print('测试集的数量', len(test_datasets))

test_loader = DataLoader(test_datasets, batch_size=1, shuffle=False)

transform=transforms.Compose([
    transforms.ToPILImage(),
    Gaussian_noise(0,0.2),
    AddPepperNoise(0.9),
    transforms.ToTensor()
])
model=AE()
hh=torch.load('./model.pth',map_location=lambda storage,loc:storage)
model.load_state_dict(hh['model_state_dict'])
#错误写法
# model=model.load_state_dict(hh['model_state_dict'])

fig = plt.figure()
for i in range(1):
    plt.subplot(1, 1, i + 1)
    img = test_datasets.test_data[i]
    label = test_datasets.test_labels[i]
    img_noise=transform(img)
    img_noise=img_noise.view(1,-1)
    out=model.encoder(img_noise)
    print('encoder的输出',out)
    #正太分布检验
    import scipy.stats as stats
    print(stats.shapiro(out.detach().numpy()))
    plt.imshow(img, cmap='gray')
    plt.title(label)
    plt.xticks([])
    plt.yticks([])
plt.show()

print('均值',torch.mean(out))
print('方差',torch.var(out))

 可以看到一张图片7是服从均值为2.5,方差为8.55的正太分布的。

然后生成一些类似的分布看看效果。

import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
from models import AE

model=AE()
model.load_state_dict(torch.load('./model.pth',map_location=lambda storage,loc:storage)['model_state_dict'])

fig = plt.figure()
for i in range(12):
    plt.subplot(3, 4, i + 1)
    input=torch.normal(mean=2.5928,std=8.5510,size=(1,20))
    out=model.decoder(input)
    out=out.view(28,28)
    out=transforms.ToPILImage()(out)
    plt.imshow(out, cmap='gray')
    plt.xticks([])
    plt.yticks([])
plt.show()

 其实效果挺差的,可能是因为一张图片的分布并不能代表所有吧。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/98240.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SpringCloud之Hystrix

复杂分布式体系结构中的应用程序有数十个依赖关系&#xff0c;每个依赖关系在某些时候将不可避免失败&#xff01; 服务雪崩 多个微服务之间调用的时候&#xff0c;假设微服务A调用微服务B和微服务C&#xff0c;微服务B和微服务C又调用其他的微服务&#xff0c;这就是所谓的“…

HttpRunner3.x 安装与使用

HttpRunner3.x 安装与使用HttpRunner3.x 安装与使用安装使用运行脚手架项目方式一&#xff1a;录制生成用例步骤1&#xff1a;导出har文件步骤2&#xff1a;转化成测试用例文件har2casehmake步骤3&#xff1a;执行测试用例方式二&#xff1a;手工编写测试用例HttpRunner3.x 安装…

Jenkins自动部署springboot的Docker镜像,解决Status [1]问题

Jenkins凡是要指定路径的命令&#xff0c;一定要写绝对路径&#xff0c;不能写相对路径&#xff01;不要以为配置了Remote directory&#xff0c;那么命令就会在Remote directory下执行&#xff01;这种想法是错误的&#xff01;&#xff01;&#xff01; 《jenkins自动化发布到…

机器学习实战教程(五):朴素贝叶斯实战篇

一、前言 上篇文章机器学习实战教程&#xff08;四&#xff09;&#xff1a;朴素贝叶斯基础篇_M_Q_T的博客-CSDN博客讲解了朴素贝叶斯的基础知识。本篇文章将在此基础上进行扩展&#xff0c;你将看到以下内容&#xff1a; 拉普拉斯平滑垃圾邮件过滤(Python3)新浪新闻分类(skle…

毕业设计 - 基于Java的敬老院管理系统设计与实现【源码+论文】

文章目录前言一、项目设计1. 模块设计系统的主要功能性需求2. 实现效果二、部分源码项目源码前言 今天学长向大家分享一个 java web项目: 基于Java的敬老院管理系统设计与实现 一、项目设计 1. 模块设计 站在护工角度来看&#xff0c;他们迫切希望&#xff0c;在运用该系统…

【Flink】Flink Starting Offset 启动消费位置 指定时间消费

文章目录 1.概述2.测试3.源码1.概述 首先参考文章:【Flink】Flink 1.14.0 全新的 Kafka Connector Kafka Source 能够通过指定 OffsetsInitializer来消费从不同偏移量开始的消息。内置的初始值设定项包括: KafkaSource.builder()// Start from committed offset of the co…

【YOLOv7-环境搭建③】PyCharm安装和环境、解释器配置

下载链接&#xff1a; 来源&#xff1a;&#xff08;博主&#xff09;唐三. 链接:https://pan.baidu.com/s/1y6s_EScOqvraFcx7iPSy1g 提取码:m1oa 安装&#xff1a; 以管理员身份打开安装完成后&#xff0c;打开软件到达以下界面&#xff0c;框框全选到达以下界面&#xf…

【指纹识别】指纹识别匹配门禁系统【含GUI Matlab源码 587期】

⛄一、指纹识别简介 1 指纹识别的引入和原理 1.1 指纹的基本知识 指纹&#xff0c;由于其具有终身不变性、唯一性和方便性&#xff0c;已几乎成为生物特征识别的代名词。指纹是指人的手指末端正面皮肤上凸凹不平产生的纹线。纹线有规律的排列形成不同的纹型。纹线的起点、终点…

kotlin基础学习笔记第九章——泛型

实化类型参数允许你在运行时的内联函数中引用作为类型实参的具体类型&#xff08;对普通的类和函数来说&#xff0c;这样行不通&#xff0c;因为类型实参在运行时会被擦除&#xff09;。 声明点变型可以说明一个带类型参数的泛型类型&#xff0c;是否是另一个泛型类型的子类型或…

什么是MySQL插入意向锁?

Insert Intention Lock&#xff0c;中文我们也称之为插入意向锁。 这个可以算是对我们之前所讲的 Gap Lock 的一个补充&#xff0c;关于 Gap Lock&#xff0c;如果还有小伙伴不懂&#xff0c;可以参考&#xff1a;记录锁、间隙锁与 Next-Key Locks。 1. 为什么需要插入意向锁…

吃透Chisel语言.40.Chisel实战之单周期RISC-V处理器实现(下)——具体实现和最终测试

Chisel实战之单周期RISC-V处理器实现&#xff08;下&#xff09;——具体实现和最终测试 上一篇文章中我们对本项目的需求进行了分析&#xff0c;并得到了初步的设计&#xff0c;这一篇文章我们就可以基于该设计来实现我们的单周期RISC-V处理器了。实现之后也必须用实际代码来…

[ 数据结构 -- 手撕排序算法第三篇 ] 希尔排序

文章目录前言一、常见的排序算法二、希尔排序2.1 希尔排序(缩小增量排序)2.1.1 预排序阶段2.1.2 插入排序阶段2.2 单趟希尔排序2.2.1 思路分析三、希尔排序实现代码四、希尔排序的时间复杂度五、希尔排序和直接插入排序效率测试5.1 测试5.2 结论5.2.1 随机数比较5.2.2 有序数组…

【二维码识别】灰度+二值化+校正二维码生成与识别【含GUI Matlab源码 635期】

⛄一、二维码生成与识别简介 如今,移动互联网技术日新月异,随着5G时代的来临,广泛应用于数据处理过程中的二维码信息安全日益成为人们越来越关注的问题。以QR码为代表的二维码,以其在信息存储、传输和识别技术领域优异的表现,成为信息共享、移动支付等领域的宠儿。不可避免地,…

利用深度学习生成数据的时间序列预测(Matlab代码实现)

目录 &#x1f4a5;1 概述 &#x1f4da;2 运行结果 &#x1f389;3 参考文献 &#x1f468;‍&#x1f4bb;4 Matlab代码 &#x1f4a5;1 概述 数据分析研究目前仍是行业热点,相关学者从数据分析关键技术中的异常检测、入侵检测、时间序列预测等角度展开研究。然而,现有研…

Go环境搭建与IDE开发工具配置

安装Go语言编译器 Go语言编译器》编译器将源代码编译为可执行程序》源代码程序员使用高级语言所书写的代码文件》高级语言c/c/go…》机器语言0和1构成&#xff0c;机器能直接识别》汇编语言比机器语言稍微可读一点点的指令集 编译器下载地址 根据系统下载对应的go编译器版本…

微信小程序保存相册授权全过程:第一次授权、已授权、拒绝后再授权

微信小程序部分功能需要使用授权&#xff08;也就是需要用户显式同意&#xff0c;系统会阻止开发者任何静默获取授权行为&#xff09;&#xff0c;以存储相册为例&#xff0c;用户需要获得"scope.writePhotosAlbum"权限 微信系统接口wx.getSetting可以获取已经获得的…

MySQL连接数据库

①MySQLpymysql ②django开发操作数据库&#xff0c;orm框架 安装第三方模块&#xff1a;orm pip install mysqlclient ORM Django链接数据库 在settings.py中修改 查看创建的数据库的端口号和用户名&#xff1a; Django操作表&#xff1a; 创建表 models.py from django…

[附源码]Python计算机毕业设计Django新冠疫苗接种预约系统

项目运行 环境配置&#xff1a; Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术&#xff1a; django python Vue 等等组成&#xff0c;B/S模式 pychram管理等等。 环境需要 1.运行环境&#xff1a;最好是python3.7.7&#xff0c;…

PDF怎么插入页?将页面添加到 PDF 文档的 3 种简单方法

得益于现代技术&#xff0c;我们现在可以轻松地合并、创建、编辑 PDF 并执行更多操作。使用专业的PDF程序在PDF文档中插入一页问题不大。这篇文章将介绍如何使用 奇客PDF编辑 和其他四个桌面和在线程序向 PDF 添加页面。 如何使用桌面程序将页面添加到 PDF 毫无疑问&#…

Simulink基础【2】- PID控制器

Simulink基础【2】- PID控制器1. Simulink作用回顾1.1 模块化1.2 常用模块1.2.1 输入信号源模块库&#xff08;Sources&#xff09;1.2.2 接收模块库&#xff08;Sinks&#xff09;1.2.3 系统模块1.2.4 数学运算模块1.3 界面布局与使用1.4 自定义模块2. PID算法仿真2.1 PID算法…