机器学习实验4——CNN卷积神经网络分类Minst数据集

news2024/10/2 3:15:58

文章目录

    • 🧡🧡实验内容🧡🧡
    • 🧡🧡 原理🧡🧡
    • 🧡🧡CNN实现分类Minst🧡🧡
      • 代码
      • 数据预处理:
      • 设置基本参数:

🧡🧡实验内容🧡🧡

基于手写minst数据集,完成关于卷积网络CNN的模型训练、测试与评估。

🧡🧡 原理🧡🧡

卷积层
通过使用一组可学习的滤波器(也称为卷积核)对输入图像进行滑动窗口卷积操作,这样可以提取出不同位置的局部特征,从而捕捉到图像的空间结构信息。
激活函数
在卷积层之后,通常会应用一个非线性激活函数,如 ReLU激活函数的作用是引入非线性,使得 CNN 能够学习更复杂的特征表达。
池化层
池化层用于降低特征图的空间尺寸,同时保留最显著的特征信息(类似于人眼观物,是根据物体的主要轮廓来判断物体是什么,而对一些小细节第一眼并没有那么关注)。常见的池化方式包括最大池化和平均池化,它们可以减少计算量,并增加模型的平移不变性。
全连接层
一般在 CNN 的最后几层,全连接层被用来将先前的卷积和池化层的输出与目标类别进行关联,每个神经元在该层中与前一层的所有神经元相连,通过学习权重参数来进行分类决策。
Softmax 函数
在最后一个全连接层之后,通常会应用 Softmax 函数来将神经网络的输出转换为概率分布,用于多类别分类问题的预测。
例如p=[0.2,0.3,0.5],这表示分类为类别1、2、3的概率分别为0.2,0.3,0.5,因此预测分类结果为类别3.
最后通过反向传播算法,CNN 使用训练数据进行模型参数的优化,它通过最小化损失函数(如交叉熵)来调整网络权重,并使用梯度下降等优化算法进行迭代更新。

构建本实验的CNN网络:

  • 5 x 5的卷积核,输入通道为1,输出通道为16:此时图像矩经过卷积核后尺寸变成24 x 24。
  • 2 x 2 的最大池化层:此时图像大小缩短一半,变成 12 x 12,通道数不变;
  • 再次经过 5 x 5 的卷积核,输入通道为16,输出通道为32:此时图像尺寸经过卷积核后变成8 *8。
  • 再次经过 2 x 2 的最大池化层:此时图像大小缩短一半,变成4 x 4,通道数不变;
  • 最后将图像整型变换成向量,输入到全连接层中:输入一共有4 x 4 x 32 = 512 个元素,输出为10.
    -在这里插入图片描述

🧡🧡CNN实现分类Minst🧡🧡

代码

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.optim as optim
from torch import nn, optim
from time import time

# ======================准备数据集======================
batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.MNIST(root='../dataset/mnist/',
                               train=True,
                               download=True,
                               transform=transform)
train_loader = DataLoader(train_dataset,
                          shuffle=True,
                          batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist',
                              train=False,
                              download=True,
                              transform=transform)
test_loader = DataLoader(test_dataset,
                         shuffle=False,
                         batch_size=batch_size)


# ======================CNN net======================
class CNN_net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5)  # 卷积1
        self.pooling1 = nn.MaxPool2d(2)  # 最大池化
        self.relu1 = nn.ReLU()  # 激活

        self.conv2 = nn.Conv2d(16, 32, kernel_size=5)
        self.pooling2 = nn.MaxPool2d(2)
        self.relu2 = nn.ReLU()

        self.fc = nn.Linear(512, 10)  # 全连接

    def forward(self, x):
        batch_size = x.size(0)
        x = self.conv1(x)
        x = self.pooling1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.pooling2(x)
        x = self.relu2(x)
        x = x.view(batch_size, -1)
        x = self.fc(x)
        return x


model = CNN_net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

# ====================== train ======================
def train(epoch):
    time0 = time()  # 记录下当前时间
    loss_list = []
    for e in range(epoch):
        running_loss = 0.0
        for images, labels in train_loader:
            outputs = model(images)  # 前向传播获取预测值
            loss = criterion(outputs, labels)  # 计算损失
            loss.backward()  # 进行反向传播
            optimizer.step()  # 更新权重
            optimizer.zero_grad()  # 清空梯度
            running_loss += loss.item()  # 累加损失
        # 一轮循环结束后打印本轮的损失函数
        print("Epoch {} - Training loss: {}".format(e, running_loss / len(train_loader)))
        loss_list.append(running_loss / len(train_loader))
    # 打印总的训练时间
    print("\nTraining Time (in minutes) =", (time() - time0) / 60)

    # 绘制损失函数随训练轮数的变化图
    plt.plot(range(1, epoch + 1), loss_list)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.show()

train(5)

# ====================== test ======================
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt


def test():
    model.eval()  # 将模型设置为评估模式
    correct = 0
    total = 0
    all_predicted = []
    all_labels = []
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            all_predicted.extend(predicted.tolist())
            all_labels.extend(labels.tolist())
    print('Model Accuracy =:%.4f' % (correct / total))
    # 绘制混淆矩阵
    cm = confusion_matrix(all_labels, all_predicted)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt=".0f", cmap="Blues")
    plt.xlabel("Predicted Labels")
    plt.ylabel("True Labels")
    plt.title("Confusion Matrix")
    plt.show()

test()

数据预处理:

加载数据集:
加载torch库中自带的minst数据集
转换数据:
先转为tensor变量(相当于直接除255归一化到值域为(0,1))
然后根据std=0.5,mean=0.5,再将值域标准化到(-1,1)。
(做完实验后,上网了解发现minst最合适的的std和mean分别为0.1307, 0.3081,但是其实结果都差不多,准确率变化不大,因为数据集还是相对比较简单的)

设置基本参数:

在这里插入图片描述

构建CNN神经网络:
同上述(1)中,已经构建完毕,这里不再赘述。
模型训练:
在这里插入图片描述
可见,虽然只经过5个epoch,但是花的时间为3.3min。
模型分类:
在这里插入图片描述
准确率达98.26%。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1413022.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JVM实战(34)——内存溢出之消息队列处理不当

一、简介 本章,我们将介绍一个因为处理消息队列中的数据不当而引起的内存溢出问题,先来看下系统的背景。 1.1 系统背景 这是一个线上的数据同步系统,专门从Kafka消费其它系统送进去的数据,处理后存储到自己的数据库中&#xff1…

算法------(10)堆

例题:(1)AcWing 838. 堆排序 我们可以利用一个一维数组来模拟堆。由于堆本质上是一个完全二叉树,他的每个父节点的权值都小于左右子节点,而每个父节点编号为n时,左节点编号为2*n,右节点编号为2*…

智慧应急消防柜的作用

在现代社会,科技的不断进步带来了许多便利与改变。智能化的产品不仅给我们的生活带来了便捷,也让我们对各个领域的发展有了更高的期待。而在这种场景下,智慧应急消防柜作为智慧城市新型基础设施的必备品,正逐渐受到更多关注。 智能…

利用ADS建立MIPI D-PHY链路仿真流程

根据MIPI D-PHY v1.2规范中对于互连电气参数的定义,本次仿真实例中,需要重点关注如下的设计参数: 1. 差分信号的插入损耗Sddij和回拨损耗Sddii; 2. 模式转换损耗Sdcxx、Scdxx; 3. 数据线与时钟线之间的串扰耦合(远、近端)。 设计者还可以结合CTS中的补充…

抓包工具fiddler看完你就懂了

一、简介 fiddler是位于客户端和服务端之间的http代理 1、作用 监控浏览器所有的http/https流量查看、分析请求内容细节伪造客户端请求和服务器请求测试网站的性能解密https的web会话全局、局部断电功能第三方插件 2、使用场景 接口调试、接口测试、线上环境调试、web性能分…

驱动开发-系统移植

一、Linux系统移植概念 需要移植三部分东西,Uboot ,内核 ,根文件系统 (rootfs) ,这三个构成了一个完整的Linux系统。 把这三部分学明白,系统移植就懂点了。 二、Uboot uboot就是引导程序下载的一段代…

爬取樱花动漫名侦探柯南最新剧场版ts格式

import os import requests import zipfile from tqdm import tqdm import tkinter as tkfilename 名侦探柯南\\ if not os.path.exists(filename):os.mkdir(filename) # https://vip.ffzy-online6.com/20231129/22304_740e70d0/2000k/hls/cedd2dc1ecb000001.ts # https://vip…

开源无代码应用程序生成器Saltcorn

什么是 Saltcorn ? Saltcorn 是一个无需编写任何代码即可构建数据库 Web 应用程序的平台。它配备了一个吸睛的仪表板,丰富的生态系统、视图生成器以及支持主题的界面,使用直观的点击、拖放用户界面来构建整个应用程序。 软件的特点&#xff1…

【漏洞复现】艺创科技智能营销路由器后台命令执行漏洞

Nx01 产品简介 成都艺创科技有限公司,成立于2011年,位于四川省成都市,是一家以从事研究和试验发展为主的企业。企业注册资本1000万人民币,实缴资本50万人民币。 Nx02 漏洞描述 成都艺创科技有限公司智能营销路由器存在默认口令(…

QT+VS实现Kmeans聚类算法

1、Kmeans的定义 聚类是一个将数据集中在某些方面相似的数据成员进行分类组织的过程,聚类就是一种发现这种内在结构的技术,聚类技术经常被称为无监督学习。k均值聚类是最著名的划分聚类算法,由于简洁和效率使得他成为所有聚类算法中最广泛使…

数位dp,HDU 4151 The Special Number

一、题目 1、题目描述 In this problem, we assume the positive integer with the following properties are called ‘the special number’: 1) The special number is a non-negative integer without any leading zero. 2) The numbers in every digit of the special nu…

RCE 漏洞审计

Command Injection 命令注入(Command Injection)是一种安全漏洞,命令注入攻击的目的是,在易受攻击的应用程序中注入和执行攻击者指定的命令。在这种情况下,执行不需要的系统命令的应用程序就像一个伪系统外壳&#xff…

无线监测终端引领文物保护和管理新篇章

一、文物预防性保护系统大升级 随着科技的不断发展,越来越多的高科技产品进入人们的生活和工作中。在文物保护和管理行业,无线监测终端大放异彩。免布线、即插即用的特点在提供方便的同时,也为文物的长久保存和有效管理带来更好的保护环境。…

兼容树莓派扩展模块,专注工业产品开发的瑞米派强势来袭

近日,米尔电子和瑞萨电子共同定义和开发了瑞萨第一款MPU生态开发板——瑞米派(Remi Pi)正式上市了!在各种Pi板卡琳琅满目的当下,Remi Pi是一款与众不同的开发板,他兼顾了严肃产品开发和爱好者创意实现两种需…

ffmpeg和opencv一些容易影响图片清晰度的操作

ffmpeg 转视频或者图片,不指定码率清晰度会下降 ffmpeg -i xxx.png xxx.mp4 码率也叫比特率(Bit rate)(也叫数据率)是一个确定整体视频/音频质量的参数,秒为单位处理的字节数,码率和视频质量成正比,在视频…

Linux CentOs7 安装Mysql(5.7和8.0版本)密码修改 超详细教程

CSDN 成就一亿技术人! 今天出一期Centos下安装Mysql(详细教程)包括数据库密码跳过修改 CSDN 成就一亿技术人! 目录 1.获取安装包 2.安装程序 安装下载的rpm包 查看安装包 修改5.7版本(重要) 安装M…

Java七大排序详解

排序 排序的概念 所谓排序 ,就是让一串记录,按照其中某些或者某个关键字的大小,递增或递减的排列起来的操作。 稳定性:就比如在待排序的序列中,存在多个具有相同关键字的记录 ,如果经过排序这些相同的关键…

通过FileZilla配置FTP

FileZilla服务端的安装 在虚拟机里安装FileZilla服务器 FileZilla的官网 下载一个客户端和一个服务端的FileZilla 如果已经有了一个客户端,可以不下用载。 FileZilla的配置 说明一下:通过FileZilla配置FTP有两种模式,我们先用被动模式 下载…

Javaweb之SpringBootWeb案例之阿里云OSS服务集成的详细解析

2.3.3 集成 阿里云oss对象存储服务的准备工作以及入门程序我们都已经完成了,接下来我们就需要在案例当中集成oss对象存储服务,来存储和管理案例中上传的图片。 在新增员工的时候,上传员工的图像,而之所以需要上传员工的图像&…

CDR绘图软件|安装教程来了(小白福利:有红包封面领取哦!)

前言 今天给小伙伴们讲讲:如何安装CDR软件。 如果未来的你想从事平面设计/广告行业,那应该就会接触到CDR这款软件。 CorelDRAW Graphics Suite是加拿大Corel公司的平面设计软件;该软件是Corel公司出品的矢量图形制作工具软件,这…