一种简单的自编码器PyTorch代码实现

news2025/1/19 14:17:31

1. 引言

对于许多新接触深度学习爱好者来说,玩AutoEncoder总是很有趣的,因为它具有简单的处理逻辑、简易的网络架构,方便可视化潜在的特征空间。在本文中,我将从头开始介绍一个简单的AutoEncoder模型,以及一些可视化潜在特征空间的一些的方法,以便使本文变得生动有趣。

闲话少说,我们直接开始吧!

2. 数据集介绍

在本文中,我们使用FashionMNIST数据集来完成此任务。
在这里插入图片描述

以下是Kaggle上数据集的链接:戳我。
该数据集已在torchvision库中集成;我们可以通过几行代码直接导入和处理该数据集。

为此,首先需要是编写一个collate_fn函数,将数据集从PIL图像转换为torch张量,并进行相应的pad操作:

# This function convert the PIL images to tensors then pad them
def collate_fn(batch):
    process = transforms.Compose([
                transforms.ToTensor(),
                transforms.Pad([2])]
                )
    # x - images; we process each image in the batch
    x = [process(data[0]) for data in batch]
    x = torch.concat(x).unsqueeze(1)
    # y - labels, note that we should convert the labels to LongTensor
    y = torch.LongTensor([data[1] for data in batch])
    return x, y

3. 实现DataLoader

接着,我们就可以使用以下代码来完成相应的DataLoader的实现:

labels = ["T-shirt/top", "Trouser", "Pullover", "Dress","Coat", 
          "Sandla", "Shirt", "Sneaker", "Bag", "Ankle boot"]

# download/load dataset
train_data = FashionMNIST("./MNIST_DATA", train=True, download=True)
valid_data = FashionMNIST("./MNIST_DATA", train=False, download=True)

# put datasets into dataloaders
train_loader = DataLoader(train_data, batch_size=config["batch_size"], 
                          shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(valid_data, batch_size=config["batch_size"], 
                           shuffle=False, collate_fn=collate_fn)

接着我们可以使用以下代码来检验上述代码是否符合我们的预期,测试代码如下:

print("Inspecting train data: ")
for _, data in enumerate(train_loader):
    print("Batch shape: ", data[0].shape)
    fig, ax = plt.subplots(1, 4, figsize=(10, 4))
    for i in range(4):
        # Ture 3D tensor to 2D tensor due to image's single channel
        ax[i].imshow(data[0][i].squeeze(), cmap="gray")
        ax[i].axis("off")
        ax[i].set_title(labels[data[1][i]])
    plt.show()
    # And don't forget to break
    break

运行结果如下:
在这里插入图片描述
观察上图,图像和标签一一对应关系正常,接着我们就可以进入我们的网络设计部分。

4. 实现encoder

我们知道自编码器是由编码器encoder和解码器decoder实现的,其中编码器的作用为将输入的图像编码为特征空间的特征向量,解码器的作用相反,尽可能的将上述特征向量结果恢复为原图。基于此,我们首先来一步步实现编码器。首先,我们来定义模型的基本超参数如下:

# Model parameters:
LAYERS = 3
KERNELS = [3, 3, 3]
CHANNELS = [32, 64, 128]
STRIDES = [2, 2, 2]
LINEAR_DIM = 2048

同时相应的编码器的网络结构设计如下:

class Encoder(nn.Module):
    def __init__(self, output_dim=2, use_batchnorm=False, use_dropout=False):
        super(Encoder, self).__init__()
        # bottleneck dimentionality
        self.output_dim = output_dim
        # variables deciding if using dropout and batchnorm in model
        self.use_dropout = use_dropout
        self.use_batchnorm = use_batchnorm
        # convolutional layer hyper parameters
        self.layers = LAYERS
        self.kernels = KERNELS
        self.channels = CHANNELS
        self.strides = STRIDES
        self.conv = self.get_convs()
        # layers for latent space projection
        self.fc_dim = LINEAR_DIM
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(self.fc_dim, self.output_dim)
    def get_convs(self):
        """
        generating convolutional layers based on model's hyper parameters
        """
        conv_layers = nn.Sequential()
        for i in range(self.layers):
            # The input channel of the first layer is 1
            if i == 0: conv_layers.append(nn.Conv2d(1, 
                                              self.channels[i], 
                                              kernel_size=self.kernels[i],
                                              stride=self.strides[i],
                                              padding=1))
            
            else: conv_layers.append(nn.Conv2d(self.channels[i-1], 
                                         self.channels[i],
                                         kernel_size=self.kernels[i],
                                         stride=self.strides[i],
                                         padding=1))
            
            if self.use_batchnorm:
                conv_layers.append(nn.BatchNorm2d(self.channels[i]))
            
            # Here we use GELU as activation function
            conv_layers.append(nn.GELU()) 
            if self.use_dropout:
                conv_layers.append(nn.Dropout2d(0.15))
        return conv_layers
  
    def forward(self, x):
        x = self.conv(x)
        x = self.flatten(x)
        return self.linear(x)

在Pytorch中torchsummary是一个非常方便的工具,用于检查和调试模型的网络结构;我们可以检查层、每层中的张量形状以及模型的参数。代码如下:

from torchsummary import summary
# Get the summary of autoencoder architecture
encoder = Encoder(use_batchnorm=True, use_dropout=True).to(DEVICE)
summary(encoder, (1, 32, 32))
pass

得到输出如下:
在这里插入图片描述

5. 实现decoder

在我们的例子中,解码器层decoder是编码器的反向操作;确保每一层的输入和输出形状是很重要的。此外,我们应该调整转置卷积层中的paddingoutput_pading参数,以确保输出图像和输入图像的维度相同。代码实现如下:

class Decoder(nn.Module):
    def __init__(self, input_dim=2, use_batchnorm=False, use_dropout=False):
        super(Decoder, self).__init__()
        # variables deciding if using dropout and batchnorm in model
        self.use_dropout = use_dropout
        self.use_batchnorm = use_batchnorm
        self.fc_dim = LINEAR_DIM
        self.input_dim = input_dim
        # Conv layer hypyer parameters
        self.layers = LAYERS
        self.kernels = KERNELS
        self.channels = CHANNELS[::-1] # flip the channel dimensions
        self.strides = STRIDES
        
        # In decoder, we first do fc project, then conv layers
        self.linear = nn.Linear(self.input_dim, self.fc_dim)
        self.conv =  self.get_convs()
        self.output = nn.Conv2d(self.channels[-1], 1, kernel_size=1, stride=1)

    def get_convs(self):
        conv_layers = nn.Sequential()
        for i in range(self.layers):
            if i == 0: conv_layers.append(
                            nn.ConvTranspose2d(self.channels[i],
                                               self.channels[i],
                                               kernel_size=self.kernels[i],
                                               stride=self.strides[i],
                                               padding=1,
                                               output_padding=1)
                            )
            
            else: conv_layers.append(
                            nn.ConvTranspose2d(self.channels[i-1], 
                                               self.channels[i],
                                               kernel_size=self.kernels[i],
                                               stride=self.strides[i],
                                               padding=1,
                                               output_padding=1
                                              )
                            )
            if self.use_batchnorm and i != self.layers - 1:
                conv_layers.append(nn.BatchNorm2d(self.channels[i]))
            conv_layers.append(nn.GELU())
            if self.use_dropout:
                conv_layers.append(nn.Dropout2d(0.15))
        return conv_layers
   
    def forward(self, x):
        x = self.linear(x)
        # reshape 3D tensor to 4D tensor
        x = x.reshape(x.shape[0], 128, 4, 4)
        x = self.conv(x)
        return self.output(x)

相应的解码器实现如下:

decoder = Decoder(use_batchnorm=True, use_dropout=True).to(DEVICE)
summary(decoder, (1, 2))
pass

运行后,结果如下:
在这里插入图片描述

6. 实现自编码器

接着,我们将上述编码器和解码器串联起来,代码实现如下:

class AutoEncoder(nn.Module):
    
    def __init__(self):
        super(AutoEncoder, self).__init__()
        self.encoder = Encoder(output_dim=2, use_batchnorm=True, use_dropout=False)
        self.decoder = Decoder(input_dim=2, use_batchnorm=True, use_dropout=False)
        
    def forward(self, x):
        return self.decoder(self.encoder(x))

model = AutoEncoder().to(DEVICE)
summary(model, (1, 32, 32))
pass

得到结果如下:
在这里插入图片描述

7. 可视化函数

在进入训练部分之前,让我们花一些时间编写一个函数来可视化我们模型的潜在特征空间,即编码后二维特征向量的可视化表示。

def plotting(step:int=0, show=False):
    model.eval() # Switch the model to evaluation mode
    points = []
    label_idcs = []
    path = "./ScatterPlots"
    if not os.path.exists(path): os.mkdir(path)
    for i, data in enumerate(valid_loader):
        img, label = [d.to(DEVICE) for d in data]
        # We only need to encode the validation images
        proj = model.encoder(img)
        points.extend(proj.detach().cpu().numpy())
        label_idcs.extend(label.detach().cpu().numpy())
        del img, label
    
    points = np.array(points)
    # Creating a scatter plot
    fig, ax = plt.subplots(figsize=(10, 10) if not show else (8, 8))
    scatter = ax.scatter(x=points[:, 0], y=points[:, 1], s=2.0, 
                c=label_idcs, cmap='tab10', alpha=0.9, zorder=2)
    
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    
    if show: 
        ax.grid(True, color="lightgray", alpha=1.0, zorder=0)
        plt.show()
    else: 
        # Do not show but only save the plot in training
        plt.savefig(f"{path}/Step_{step:03d}.png", bbox_inches="tight")
        plt.close() # don't forget to close the plot, or it is always in memory
        model.train()

以下是训练过程中生成的图;该过程显示了模型的潜在空间随时间的分布,可以看出尽管有个别离群点,整体不同类别的数据在特征空间呈现出聚类趋势:
在这里插入图片描述

8. 损失函数

在编写训练和验证函数之前,还有一个步骤是定义目标函数和优化方法。由于自动编码器是一个自监督模型,输入也是网络输出重建图像逼近的对象,因此我们可以使用MSE(均方误差)损失来评估输入和重建图像之间的逐像素损失。当然有很多优化器可供选择,这里我选择的是AdamW,因为我在过去几个月里经常使用它。

criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=1e-5)

# For mixed precision training
scaler = torch.cuda.amp.GradScaler()
steps = 0 # tracking the training steps

9. 训练函数

接着我们来定义训练一个epoch的函数,代码实现如下:

def train(model, dataloader, criterion, optimizer, save_distrib=False):
    # steps is used to track training progress, purely for latent space plots
    global steps 
    model.train()
    train_loss = 0.0
    
    # Process tqdm bar, helpful for monitoring training process
    batch_bar = tqdm(total=len(dataloader), dynamic_ncols=True, 
                     leave=False, position=0, desc="Train")
    for i, batch in enumerate(dataloader):
        optimizer.zero_grad()
        x = batch[0].to(DEVICE)
        
        # Here we implement the mixed precision training
        with torch.cuda.amp.autocast():
            y_recons = model(x)
            loss = criterion(y_recons, x)
        
        train_loss += loss.item()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        batch_bar.set_postfix(
            loss=f"{train_loss/(i+1):.4f}",
            lr = f"{optimizer.param_groups[0]['lr']:.4f}"
        )
        batch_bar.update()        

        # Saving latent space plots
        if steps % 10 == 0 and save_distrib and steps <= 400: plotting(steps)
        steps += 1        
        
        # remove unnecessary cache in CUDA memory
        torch.cuda.empty_cache()
        del x, y_recons
    
    batch_bar.close()
    train_loss /= len(dataloader)

    return train_loss

10 验证函数

相应的验证函数的实现稍微简单一点,代码如下:

def validate(model, dataloader, criterion):
    model.eval() # Don't forget to turn the model to eval mode
    valid_loss = 0.0
    # Progress tqdm bar
    batch_bar = tqdm(total=len(dataloader), dynamic_ncols=True,
                     leave=False, position=0, desc="Validation")
    
    for i, batch in enumerate(dataloader):
        x = batch[0].to(DEVICE)
        with torch.no_grad(): # we don't need gradients in validation
            y_recons = model(x)
            loss = criterion(y_recons, x)
        valid_loss += loss.item()
        batch_bar.set_postfix(
            loss=f"{valid_loss/(i+1):.4f}",
            lr = f"{optimizer.param_groups[0]['lr']:.4f}"
        )
        batch_bar.update()
        torch.cuda.empty_cache()
        del x, y_recons
    
    batch_bar.close()
    valid_loss /= len(dataloader)
    return valid_loss

11 训练过程

接着,我们将上述代码串起来,来实现我们模型的训练,由于FashionMNIST是一个很小的数据集,我们实际上不需要大量训练;初始训练和验证损失非常低,并且在三个epoch之后没有太大的改进空间。

for i in range(config["epochs"]):

    curr_lr = float(optimizer.param_groups[0]["lr"])
    train_loss = train(model, train_loader, criterion, 
                       optimizer, save_distrib=True)
    valid_loss = validate(model, valid_loader, criterion)

    print(f"Epoch {i+1}/{config['epochs']}\nTrain loss: {train_loss:.4f}\t Validation loss: {valid_loss:.4f}\tlr: {curr_lr:.4f}")

输出如下:
在这里插入图片描述

12 结果可视化

我们现在可以再次绘制和检查收敛后的特征空间,可视化输出如下:
在这里插入图片描述
观察上图可知,相应的聚类后的效果比训练过程中的要好,但有些个别类混合在同一集群中。这个问题可以通过增加编码器输出的特征向量的维度或使用其他损失函数函数来解决。

13 预测效果可视化

为了验证我们的解码器确实学到了东西,我们可以在随机绘制一些离散点来观察解码器重建图像的效果,代码如下:

# randomly sample x and y values
xs = [random.uniform(-6.0, 8.0) for i in range(8)]
ys = [random.uniform(-7.5, 10.0) for i in range(8)]

points = list(zip(xs, ys))
coords = torch.tensor(points).unsqueeze(1).to(DEVICE)
nrows, ncols = 2, 4
fig, axes = plt.subplots(nrows, ncols, figsize=(10, 5))
model.eval()
with torch.no_grad():
    generates = [model.decoder(coord) for coord in coords]
# plot points
idx = 0
for row in range(0, nrows):
    for col in range(0, ncols):
        ax = axes[row, col]
        im = generates[idx].squeeze().detach().cpu()
        ax.imshow(im, cmap="gray")
        ax.axis("off")
        coord = coords[idx].detach().cpu().numpy()[0]
        ax.set_title(f"({coord[0]:.3f}, {coord[1]:.3f})")
        idx += 1

plt.show()

代码输出如下:
在这里插入图片描述

14. 总结

本文重点介绍了如何利用Pytorch来实现自编码器,从数据集,到搭建网络结构,以及特征可视化和网络预测输出几个方面,分别进行了详细的阐述,并给出了相应的代码示例。

您学废了吗?

完整代码链接:戳我

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1331033.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

全渠道在线客服系统支持的沟通渠道:多渠道整合与无缝对接

我们在挑选客服系统的时候&#xff0c;经常会看到有些客服产品会强调自己是“全渠道客服系统”&#xff0c;那什么是全渠道客服系统呢&#xff1f; 1、什么是全渠道客服系统&#xff1f; 简单来讲&#xff0c;它是指能把某个客户在不同渠道的互动历史放到一起集中展现&#x…

rqt_graph使用说明

其中右边的&#xff1a;/rosout是一个topic 也就是一个话题 /rosout是一个topic 也是一个话题 可以看到凡是在rqt_graph里面用长方形标识的全都是话题 通过观察可以发现&#xff1a;凡是用椭圆标识的全都是节点 如果切换为Nodes only视图会发现&#xff1a; 所说的no…

SpringSecurity安全框架 ——认证与授权

目录 一、简介 1.1 什么是Spring Security 1.2 工作原理 1.3 为什么选择Spring Security 1.4 HttpSecurity 介绍&#x1f31f; 二、用户认证 2.1 导入依赖与配置 2.2 用户对象UserDetails 2.3 业务对象UserDetailsService 2.4 SecurityConfig配置 2.4.1 BCryptPasswo…

【数据结构入门精讲 | 第八篇】一文讲清全部排序算法(2)

在上一篇文章中我们介绍了冒泡排序、快速排序等算法&#xff0c;这一篇我们接着对排序算法的学习。 目录 归并排序堆排序选择排序计数排序基数排序排序总结 归并排序 归并排序是建立在归并操作上的一种有效&#xff0c;稳定的排序算法&#xff0c;该算法是采用分治法&#xff…

MySQL报错:1054 - Unknown column ‘xx‘ in ‘field list的解决方法

我在操作MySQL遇到1054报错&#xff0c;报错内容&#xff1a;1054 - Unknown column Cindy in field list&#xff0c;下面演示解决方法&#xff0c;非常简单。 根据箭头指示&#xff0c;Cindy对应的应该是VARCHAR文本数字类型&#xff0c;字符串要用引号&#xff0c;所以解决方…

【C语言】打印内存数据

C语言&#xff0c;用函数封装&#xff1a;16进制打印unsigned char *p指向的内存&#xff0c;长度为int l。16个字节&#xff0c;换一次行。16个字节用一个字符串缓存&#xff0c;一次打印。 以下是一个使用函数封装的C语言代码&#xff0c;用于以16进制格式打印unsigned char …

MySQL 事务的ACID特性

MySQL事务是什么&#xff0c;它就是一组数据库的操作&#xff0c;是访问数据库的程序单元&#xff0c;事务中可能包含一个或者多个 SQL 语句。这些SQL 语句要么都执行、要么都不执行。我们知道&#xff0c;在MySQL 中&#xff0c;有不同的存储引擎&#xff0c;有的存储引擎比如…

省时攻略:快速获得Creo安装包,释放创意天才!

不要再在网上浪费时间寻找Creo的安装包了&#xff0c;一键下载安装&#xff0c; 你要的一切都可以在这里找到&#xff01;我们深知在海量的信息中寻找合适的软件包并非易事&#xff0c;而且往往还伴随着繁琐的安装过程。然而&#xff0c;现在有了我们&#xff0c;一切变得轻松简…

【飞凌 OK113i-C 全志T113-i开发板】一些有用的常用的命令测试

一些有用的常用的命令测试 一、系统信息查询 可以查询板子的内核信息、CPU处理器信息、环境变量等 二、CPU频率 从上面的系统信息查询到&#xff0c;这是一颗具有两个ARMv7结构A7内核的处理器&#xff0c;主频最高1.2GHz 可以通过命令查看当前支持的频率以及目前所使用主频 …

爬虫工作量由小到大的思维转变---<第二十三章 Scrapy开始很快,越来越慢(医病篇)>

诊断篇https://blog.csdn.net/m0_56758840/article/details/135170994?ops_request_misc%257B%2522request%255Fid%2522%253A%2522170333243316800180644102%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fall.%2522%257D&request_id1703332433168001806441…

更改WiseAlign软件界面图标方法

更改WiseAlign软件界面图标方法 未替换时 首先将图片转换为BMP格式&#xff0c;在搜索栏处输入画图&#xff0c;点击打开画图工具 按住图标拖动到画布内&#xff0c;或是直接CtrlV将图标复制到画布内 点击文件&#xff0c;再点击另存为 保存类型选择“24位位图&#xff08;*.bm…

SpringBoot3-基础特性

文章目录 自定义 banner自定义 SpringApplicationFluentBuilder APIProfiles指定环境环境激活环境包含Profile 分组Profile 配置文件 外部化配置配置优先级 外部配置导入配置属性占位符 单元测试-JUnit5测试组件测试注解断言嵌套测试参数化测试 自定义 banner banner 就是启动…

MySQL数据库 触发器

目录 触发器概述 语法 案例 触发器概述 触发器是与表有关的数据库对象&#xff0c;指在insert/update/delete之前(BEFORE)或之后(AFTER)&#xff0c;触发并执行触发器中定义的soL语句集合。触发器的这种特性可以协助应用在数据库端确保数据的完整性&#xff0c;日志记录&am…

idea多光标无法取消

通常按住alt 鼠标左键。是多光标操作 但是不知道怎么按照了导致一直多光标 使用 altshiftinsert 取消多光标

【优质书籍推荐】LoRA微调的技巧和方法

大家好&#xff0c;我是爱编程的喵喵。双985硕士毕业&#xff0c;现担任全栈工程师一职&#xff0c;热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。现为CSDN博客专家、人工智能领域优质创作者。…

【Unity基础】9.地形系统Terrain

【Unity基础】9.地形系统Terrain 大家好&#xff0c;我是Lampard~~ 欢迎来到Unity基础系列博客&#xff0c;所学知识来自B站阿发老师~感谢 &#xff08;一&#xff09;地形编辑器Terrain &#xff08;1&#xff09;创建地形 游戏场景中大多数的山川河流地表地貌都是基…

vue微乾坤子应用开发及ele组件开发时问题记录

一. 微乾坤 1. 新增page页面路由,pmi权限中心配置正常&#xff0c;跳转链接正确&#xff0c;但路由未找到403. 解决&#xff1a; 新增的配置是page类型&#xff0c;transformQianKunRoute方法转换微前端路由数据 时&#xff0c;过滤未兼容page型的路由&#xff0c; 解决 [menu,…

Git的总体认知与具体实现

GIt概念 是一种分布式控制管理器 tips:敏捷开发 -> 先上线&#xff0c;后续开发再继续开发 集中式和分布式 集中式的版本控制系统每次在写代码时都需要从服务器中拉取一份下来&#xff0c;并且如果服务器丢失了&#xff0c;那么所有的就都丢失了&#xff0c;你本机客户端仅…

神经网络:深度学习基础

1.反向传播算法&#xff08;BP&#xff09;的概念及简单推导 反向传播&#xff08;Backpropagation&#xff0c;BP&#xff09;算法是一种与最优化方法&#xff08;如梯度下降法&#xff09;结合使用的&#xff0c;用来训练人工神经网络的常见算法。BP算法对网络中所有权重计算…

MySQL代码笔记

欢迎来到Cefler的博客&#x1f601; &#x1f54c;博客主页&#xff1a;那个传说中的man的主页 &#x1f3e0;个人专栏&#xff1a;题目解析 &#x1f30e;推荐文章&#xff1a;题目大解析&#xff08;3&#xff09; 目录 &#x1f449;&#x1f3fb;表的增删查改创建表格&…