PyTorch深度学习实战——使用卷积神经网络执行图像分类

news2024/9/22 4:19:28

PyTorch深度学习实战——使用卷积神经网络执行图像分类

    • 0. 前言
    • 1. Fashion-MNIST 数据集图像分类
    • 2. 模型测试
    • 相关链接

0. 前言

我们已经在《卷积神经网络详解》一节中介绍了传统神经网络在面对图像平移时的问题以及卷积神经网络 (Convolutional Neural Network, CNN) 的工作原理。CNN 的关键思想是通过卷积操作来提取输入数据中的特征,并使用池化操作进行降采样,以逐渐减少参数数量,从而减少计算量并提高模型的效率。在本节中,将介绍 CNN 在图像平移后如何解决错误预测的问题。

1. Fashion-MNIST 数据集图像分类

Fashion-MNIST 数据集的预处理部分与《使用PyTorch构建神经网络》一节中的代码相同,但当我们整形 (.view) 输入数据时,不是将输入展平为 28 x 28 = 784 维,而是将每个输入图像的形状重塑为 (1,28,28) (需要特别注意的是,在 PyTorch 中,首先要指定通道,然后是高度和宽度),因为卷积神经网络期望每个输入的形状为批大小 x 通道 x 高度 x 宽度

(1) 导入必要的库和数据集:

from torchvision import datasets
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
device = 'cuda' if torch.cuda.is_available() else 'cpu'

data_folder = './data/FMNIST'
fmnist = datasets.FashionMNIST(data_folder, download=True, train=True)

tr_images = fmnist.data
tr_targets = fmnist.targets

val_fmnist = datasets.FashionMNIST(data_folder, download=True, train=False)
val_images = val_fmnist.data
val_targets = val_fmnist.targets

(2) Fashion-MNIST 数据集类定义如下:

class FMNISTDataset(Dataset):
    def __init__(self, x, y):
        x = x.float()/255
        x = x.view(-1,1,28,28)
        self.x, self.y = x, y 
    def __getitem__(self, ix):
        x, y = self.x[ix], self.y[ix]        
        return x.to(device), y.to(device)
    def __len__(self): 
        return len(self.x)

(3) 定义 CNN 模型架构,并打印模型摘要:

from torch.optim import SGD, Adam
def get_model():
    model = nn.Sequential(
        nn.Conv2d(1, 64, kernel_size=3),
        nn.MaxPool2d(2),
        nn.ReLU(),
        nn.Conv2d(64, 128, kernel_size=3),
        nn.MaxPool2d(2),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(3200, 256),
        nn.ReLU(),
        nn.Linear(256, 10)
    ).to(device)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=1e-3)
    return model, loss_fn, optimizer

def train_batch(x, y, model, optimizer, loss_fn):
    prediction = model(x)
    batch_loss = loss_fn(prediction, y)
    batch_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return batch_loss.item()

@torch.no_grad()
def accuracy(x, y, model):
    model.eval()
    prediction = model(x)
    max_values, argmaxes = prediction.max(-1)
    is_correct = argmaxes == y
    return is_correct.cpu().numpy().tolist()

def get_data():     
    train = FMNISTDataset(tr_images, tr_targets)     
    trn_dl = DataLoader(train, batch_size=32, shuffle=True)
    val = FMNISTDataset(val_images, val_targets)     
    val_dl = DataLoader(val, batch_size=len(val_images), shuffle=True)
    return trn_dl, val_dl

@torch.no_grad()
def val_loss(x, y, model, loss_fn):
    prediction = model(x)
    val_loss = loss_fn(prediction, y)
    return val_loss.item()

from torch import optim
trn_dl, val_dl = get_data()
model, loss_fn, optimizer = get_model()

from torchsummary import summary
model, loss_fn, optimizer = get_model()
print(summary(model, tuple([1,28,28])))

模型架构摘要信息如下:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 26, 26]             640
         MaxPool2d-2           [-1, 64, 13, 13]               0
              ReLU-3           [-1, 64, 13, 13]               0
            Conv2d-4          [-1, 128, 11, 11]          73,856
         MaxPool2d-5            [-1, 128, 5, 5]               0
              ReLU-6            [-1, 128, 5, 5]               0
           Flatten-7                 [-1, 3200]               0
            Linear-8                  [-1, 256]         819,456
              ReLU-9                  [-1, 256]               0
           Linear-10                   [-1, 10]           2,570
================================================================
Total params: 896,522
Trainable params: 896,522
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.69
Params size (MB): 3.42
Estimated Total Size (MB): 4.11
----------------------------------------------------------------
  • 1 个网络层中:有 64 个卷积核大小为 3 的滤波器,因此有 64 x 3 x 3 个权重和 64 x 1 个偏置,共有 640 个参数
  • 4 个网络层中:有 128 个卷积核大小为 3 的滤波器,因此有 128 x 64 x3 x 3 个权重和 128 x 1 个偏置,共有 73,856 个参数
  • 8 个网络层:有 3,200 个节点的网络层连接到具有 256 个节点的另一网络层,因此有 3,200 x 256 个权重和 256 个偏置,共有 819,456 个参数
  • 10 个网络层:有 256 个节点的网络层连接到具有 10 个节点的另一网络层,因此共有 256 x 10 个权重和 10 个偏置,共有 2,570 个参数

(4) 训练模型完成后,可视化训练和测试数据集的准确率和损失变化:

train_losses, train_accuracies = [], []
val_losses, val_accuracies = [], []
for epoch in range(10):
    print(epoch)
    train_epoch_losses, train_epoch_accuracies = [], []
    for ix, batch in enumerate(iter(trn_dl)):
        x, y = batch
        batch_loss = train_batch(x, y, model, optimizer, loss_fn)
        train_epoch_losses.append(batch_loss)        
    train_epoch_loss = np.array(train_epoch_losses).mean()

    for ix, batch in enumerate(iter(trn_dl)):
        x, y = batch
        is_correct = accuracy(x, y, model)
        train_epoch_accuracies.extend(is_correct)
    train_epoch_accuracy = np.mean(train_epoch_accuracies)

    for ix, batch in enumerate(iter(val_dl)):
        x, y = batch
        val_is_correct = accuracy(x, y, model)
        validation_loss = val_loss(x, y, model, loss_fn)
    val_epoch_accuracy = np.mean(val_is_correct)

    train_losses.append(train_epoch_loss)
    train_accuracies.append(train_epoch_accuracy)
    val_losses.append(validation_loss)
    val_accuracies.append(val_epoch_accuracy)

epochs = np.arange(10)+1
import matplotlib.ticker as mtick
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
plt.subplot(121)
plt.plot(epochs, train_losses, 'bo', label='Training loss')
plt.plot(epochs, val_losses, 'r', label='Validation loss')
plt.gca().xaxis.set_major_locator(mticker.MultipleLocator(1))
plt.title('Training and validation loss with CNN')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid('off')
plt.subplot(122)
plt.plot(epochs, train_accuracies, 'bo', label='Training accuracy')
plt.plot(epochs, val_accuracies, 'r', label='Validation accuracy')
plt.gca().xaxis.set_major_locator(mticker.MultipleLocator(1))
plt.title('Training and validation accuracy with CNN')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.gca().set_yticklabels(['{:.0f}%'.format(x*100) for x in plt.gca().get_yticks()]) 
plt.legend()
plt.grid('off')
plt.show()

模型性能监测

在上图中,可以看到验证数据集在前 5epoch 内的准确率可以达到 92% 左右,即时没有使用额外的正则化技术,也已经比使用了增强技术的全连接网络的准确率更好。

2. 模型测试

接下来,利用训练完成的 CNN 网络预测平移图像的类别:

preds = []
ix = 24150
for px in range(-5,6):
    img = tr_images[ix]/255.
    img = img.view(28, 28)
    plt.subplot(1, 11, px+6)
    img2 = np.roll(img, px, axis=1)
    img3 = torch.Tensor(img2).view(-1,1,28,28).to(device)
    np_output = model(img3).cpu().detach().numpy()
    pred = np.exp(np_output)/np.sum(np.exp(np_output))
    preds.append(pred)
    plt.imshow(img2)
    plt.title(fmnist.classes[pred[0].argmax()])

plt.show()

在以上代码中,对图像 (img3) 进行整形,使其形状转换为 (-1,1,28,28),以便将图像输入到 CNN 模型中。

可视化平移图像的类别概率:

import seaborn as sns
fig, ax = plt.subplots(1,1, figsize=(12,10))
plt.title('Probability of each class for various translations')
sns.heatmap(np.array(preds).reshape(11,10), annot=True, ax=ax, fmt='.2f', xticklabels=fmnist.classes, yticklabels=[str(i)+str(' pixels') for i in range(-5,6)], cmap='gray')
plt.show()

类别概率

在上图中可以看出,即使图像平移了 4 个像素,也可以得到正确的预测结果,而全连接网络中,当图像被平移 4 个像素时,输出了完全错误的预测结果。而当图像平移 5 个像素时,CNN 预测 “Trouser” 的概率会大幅降低。因此虽然 CNN 有助于解决图像平移的问题,但它们并不能完全解决该问题,在之后的学习中,我们将学习如何利用数据增强和 CNN 来解决此问题。

相关链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/887980.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CSS 字体修饰属性

前言 字体修饰属性 属性说明font-family指定文本显示字体font-size设置字体的大小font-weight设置字体的粗细程度font-style设置字体的倾斜样式text-decoration给文本添加装饰线text-indent设置文本的缩进text-align设置文本的对齐方式line-height设置行高color设置文本的颜色…

IDEA常用插件推荐(个人)

分享下个人在大厂工作四五年的一个常用配置插件 一、Alibaba Java Coding Guidelines 代码规范插件(必备) 阿里巴巴代码规范检查 人手必备。减少你的垃圾代码 各种不良提示代码全靠它了。 代码划线的嘎嘎 crtlenter优化得了 二、Atom Material File Icons 图标主题插件(提示…

Java学习手册——第二篇面向对象程序设计

Java学习手册——第二篇面向对象 1. 结构化程序设计2. 面向对象 第一章我们已经介绍了Java语言的基础知识,也知道他能干什么了, 那我们就从他的设计思想开始入手吧。 接触一个语言之前首先要知道他的大方向,设计思想是什么样的, 这…

【高阶数据结构】红黑树详解

文章目录 前言1. 红黑树的概念及性质1.1 红黑树的概念1.2 红黑树的性质1.3 已经学了AVL树,为啥还要学红黑树 2. 红黑树结构的定义3. 插入(仅仅是插入过程)4. 插入结点之后根据情况进行相应调整4.1 cur为红,p为红,g为黑…

Redis——哨兵模式(docker部署redis哨兵)+缓存穿透和雪崩

哨兵模式 自动选取主机的模式。 概述 主从切换技术的方法是:当主服务器宕机后,需要手动把一台从服务器切换为主服务器,这就需要人工干预,费事费力,还会造成段时间内服务不可用。这不是一种推荐的方式,更多时候&…

LabVIEW调用DLL传递结构体参数

LabVIEW 中调用动态库接口时,如果是值传递的结构体,可以根据字段拆解为多个参数;如果参数为结构体指针,可用簇(Cluster)来匹配,其内存连续相当于单字节对齐。 1.值传递 接口定义: …

交叉导轨的内部结构

相对于直线导轨,交叉导轨的知名度是没那么高的,但随着技术水平的提高,精度更高,安装高度更低的交叉导轨也慢慢走近大众的视野,得到更多厂商的青睐,使用范围也更加广泛。 交叉导轨是由两根具有V型滚道的导轨…

数据结构之动态内存管理机制

目录 数据结构之动态内存管理机制 占用块和空闲块 系统的内存管理 可利用空间表 分配存储空间的方式 空间分配与回收过程产生的问题 边界标识法管理动态内存 分配算法 回收算法 伙伴系统管理动态内存 可利用空间表中结点构成 分配算法 回收算法 总结 无用单元收…

leetcode-413. 等差数列划分(java)

等差数列划分 leetcode-413. 等差数列划分题目描述双指针 上期经典算法 leetcode-413. 等差数列划分 难度 - 中等 原题链接 - 等差数列划分 题目描述 如果一个数列 至少有三个元素 ,并且任意两个相邻元素之差相同,则称该数列为等差数列。 例如&#xff0…

【Linux操作系统】Linux系统编程实现递归遍历目录,详细讲解opendir、readdir、closedir、snprintf、strcmp等函数的使用

在Linux系统编程中,经常需要对目录进行遍历操作,以获取目录中的所有文件和子目录。递归遍历目录是一种常见的方法,可以通过使用C语言来实现。本篇博客将详细介绍如何使用C语言实现递归遍历目录的过程,并提供相应的代码示例&#x…

高阶数据结构-图

高阶数据结构-图 图的表示 图由顶点和边构成,可分为有向图和无向图 邻接表法 图的表示方法有邻接表法和邻接矩阵法,以上图中的有向图为例,邻接表法可以表示为 A->[(B,5),(C,10)] B->[(D,100)] C->[(B,3)] D->[(E,7)] E->[…

AgentBench::AI Agent 是大模型的未来

最有想象力、最有前景的方向 “Agent 是 LLM(大语言模型)的最有前景的方向。一旦技术成熟,短则几个月,长则更久,它可能就会创造出超级个体。这解释了我们为何对开源模型和 Agent 兴奋,即便投产性不高,但是我们能想象自己有了 Agent 之后就可以没日没夜地以百倍效率做现在…

Collada .dae文件格式简明教程【3D】

当你从互联网下载 3D 模型时,可能会在格式列表中看到 .dae 格式。 它是什么? 推荐:用 NSDT编辑器 快速搭建可编程3D场景。 1、Collada DAE概述 COLLADA是COLLAborative Design Activity(中文:协作设计活动&#xff09…

剑指offer43.1~n整数中1出现的次数

看到这么大的数据规模就直到用暴力法肯定会超时&#xff0c;但是还是花一分钟写了一个试一下&#xff0c;果然超时 class Solution {public int countDigitOne(int n) {int count 0;for(int i1;i<n;i){countdigitOneInOneNum(i);}return count;}public int digitOneInOneNu…

从零实战SLAM-第九课(后端优化)

在七月算法报的班&#xff0c;老师讲的蛮好。好记性不如烂笔头&#xff0c;关键内容还是记录一下吧&#xff0c;课程入口&#xff0c;感兴趣的同学可以学习一下。 --------------------------------------------------------------------------------------------------------…

字符个数统计(同类型只统计一次)

思路&#xff1a;因为题目圈定出现的字符都是 ascii 值小于等于127的字符&#xff0c;因此只需要定义一个标记数组大小为128 &#xff0c;然后将字符作为数组下标在数组中进行标记&#xff0c;若数组中没有标记过表示第一次出现&#xff0c;进行计数&#xff0c;否则表示重复字…

Layui列表复选框根据条件禁用

// 禁用客服回访id有值的复选框res.data.forEach(function (item, i) {if (item.feedbackEmpId) {let index res.data[i][LAY_TABLE_INDEX];$(".layui-table tr[data-index"index"] input[typecheckbox]").prop(disabled,true);$(".layui-table tr[d…

探索Chevereto图床:使用Docker Compose快速搭建个人图床

家人们!图片在今天的社交媒体、博客和论坛中扮演着至关重要的角色。然而&#xff0c;随着图片数量的增加&#xff0c;寻找一个可靠的图片托管解决方案变得越来越重要。Chevereto图床是一个备受赞誉的解决方案&#xff0c;而使用Docker Compose搭建它更是一种高效、可维护的方法…

【内容安全】微服务学习笔记八:使用腾讯云T-Sec天御对文本及图片内容进行安全检测

个人简介&#xff1a; > &#x1f4e6;个人主页&#xff1a;赵四司机 > &#x1f3c6;学习方向&#xff1a;JAVA后端开发 > &#x1f4e3;种一棵树最好的时间是十年前&#xff0c;其次是现在&#xff01; > ⏰往期文章&#xff1a;SpringBoot项目整合微信支付 &g…

安装paddlepadddle-gpu的正确方式

正确安装paddlepadddle-gpu的方式 1.查看系统CUDA版本2.参照飞桨官网快速pip安装 安装paddlepaddle时&#xff0c;pip install paddlepaddle是直接安装的CPU版本&#xff0c;要安装GPU版本的话&#xff0c;就要注意适配的CUDA版本&#xff0c;安装GPU版本可参照官网教程&#x…