机器学习预测-CNN手写字识别

news2025/3/15 17:34:03

介绍

这段代码是使用PyTorch实现的卷积神经网络(CNN),用于在MNIST数据集上进行图像分类。让我一步步解释:

  1. 导入库:代码导入了必要的库,包括PyTorch(torch)、神经网络模块(torch.nn)、函数模块(torch.nn.functional)、图像数据集(torchvision)以及数据处理(torch.utils.data)和可视化(matplotlib.pyplot)的工具。

  2. 设置超参数:定义了超参数,如批大小(Batch_size)、epoch数量(Epoch)和学习率(Lr)。

  3. 加载MNIST数据集:使用torchvision.datasets.MNIST加载MNIST数据集。该数据集包含了0到9的手写数字的灰度图像。transform=torchvision.transforms.ToTensor()将PIL图像转换为PyTorch张量。

  4. 可视化样本数据:打印数据集的大小,并显示数据集中的第一张图像及其相应的标签。

  5. 准备测试数据:准备测试数据与训练数据类似。加载MNIST测试数据集,并选择前2000个图像进行测试。

  6. 创建数据加载器:使用torch.utils.data.DataLoader创建训练数据的数据加载器。它有助于在训练过程中对数据进行分批和混洗。

  7. 定义CNN架构:通过子类化nn.Module来定义CNN类。该架构包括两个卷积层(self.con1self.con2),后面跟有ReLU激活函数和最大池化层。卷积层的输出被展平并馈入全连接层(self.out),产生最终输出。

  8. 初始化CNN:创建CNN类的实例。

  9. 定义损失函数和优化器:使用交叉熵损失(nn.CrossEntropyLoss)作为损失函数,使用随机梯度下降(torch.optim.SGD)作为优化器。

  10. 训练CNN:在指定的epoch数量循环内训练模型。在循环内,将训练数据通过模型,计算损失,进行梯度反向传播,并由优化器更新模型参数。

  11. 测试模型:每50次迭代训练时,对测试数据集进行评估。将测试预测与真实标签进行比较,计算准确率。

  12. 打印结果:训练结束后,打印模型预测及前10个测试样本的真实标签。

总的来说,这段代码训练了一个CNN模型,用于在MNIST数据集上对手写数字进行分类,并在单独的测试数据集上评估其性能。

代码

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.utils.data as Data
import matplotlib.pyplot as plt

# define hyper parameters
Batch_size = 100
Epoch = 1
Lr = 0.5
#DOWNLOAD_MNIST = True # 若没有数据,用此生成数据

# define train data and test data
train_data = torchvision.datasets.MNIST(
    root='./mnist',
    train=True,
    download=False,
    transform=torchvision.transforms.ToTensor()
)
print(train_data.data.size())
print(train_data.targets.size())
print(train_data.data[0])
# 画一个图片显示出来
plt.imshow(train_data.data[0].numpy(),cmap='gray')
plt.title('%i'%train_data.targets[0])
plt.show()
# print(train_data.data.shape)           # torch.Size([60000, 28, 28])
# print(train_data.targets.size())        # torch.Size([60000])
# print(train_data.data[0].size())        # torch.Size([28, 28])
# plt.imshow(train_data.data[0].numpy(), cmap='gray')
# plt.show()
test_data = torchvision.datasets.MNIST(
    root='./mnist',
    train=False,
    # transform=torchvision.transforms.ToTensor()
)
test_x = torch.unsqueeze(test_data.data, dim=1).type(torch.FloatTensor)[:2000]
test_y = test_data.targets[:2000]
# print(test_x.shape)                         # torch.Size([2000, 1, 28, 28])
# print(test_y.shape)                         # torch.Size([2000])
train_loader = Data.DataLoader(
    dataset=train_data,
    shuffle=True,
    batch_size=Batch_size,
)

# define network structure
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.con1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.con2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        x = self.con1(x)            # (batch, 16, 14, 14)
        x = self.con2(x)            # (batch, 32, 7, 7)
        x = x.view(x.size(0), -1)
        out = self.out(x)             # (batch_size, 10)
        return out

cnn = CNN()
# print(cnn)
optimizer = torch.optim.SGD(cnn.parameters(), lr=Lr)
loss_fun = nn.CrossEntropyLoss()

for epoch in range(Epoch):
    for i, (x, y) in enumerate(train_loader):
        output = cnn(x)
        loss = loss_fun(output, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 50 == 0:
            test_output = torch.max(cnn(test_x), dim=1)[1]
            loss = loss_fun(cnn(test_x), test_y).item()
            accuracy = torch.sum(torch.eq(test_output, test_y)).item() / test_y.numpy().size
            print('Epoch:', Epoch, '|loss:%.4f' % loss, '|accuracy:%.4f' % accuracy)

print('real value', test_data.targets[: 10].numpy())
print('train value', torch.max(cnn(test_x)[: 10], dim=1)[1].numpy())




结果

real value [7 2 1 0 4 1 4 9 5 9]
train value [7 2 1 0 4 1 4 9 5 9]

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1695146.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

STM32H743的FDCAN使用方法(1):STM32CubeMX初始化代码生成

0 工具准备 1.STM32CubeMX1 前言 本文介绍基于STM32CubeMX,使用stm32h743xi的对FDCAN2进行配置的方法。 2 初始化代码生成 2.1 选择FDCAN引脚 本例选择PB5、PB6作为FDCAN2的RX、TX引脚。 2.2 选择FDCAN时钟源 本例选择PLL2Q作为FDCAN时钟源,频率…

Redis(1)-Jedis连接配置

问题 阿里云安装并启用Redis后,尝试在本地用Jedis调用,发现报错 public class Jedis01 {Testpublic void connect(){Jedis jedis new Jedis("101.37.31.211", 6379); // 公网ipjedis.auth("123"); // 密码String ping jedis.pin…

可转债日内自动T+0交易,行情推送+策略触发+交易接口

说明 目前这个项目已编译打包,下载即可测试,直接生成多平台可执行文件,详见运行方法。行情部分与策略弱相关,拆分解耦单独作为一个项目。行情项目请移步GitHub - freevolunteer/hangqing: A股行情订阅工具,支持股票/可转债level2/level2数据&…

Golang实现递归复制文件夹

代码 package zdpgo_fileimport ("errors""os""path/filepath""strings" )// CopyDir 复制文件夹 // param srcPath 源文件夹 // param desPath 目标文件夹 // return error 错误信息 func CopyDir(srcPath, desPath string) error {…

机器学习第十次课

前言 因为考了一次试,所以没讲太多新东西,唯一的问题是有的知识刚讲完就考了,导致我爆了...... 正文 主要讲的就是一个贝叶斯分类模型,这是属于生成式的分类器了 Bayesian decision theory 我的理解是贝叶斯公式则是利用条件概率和全概率公式计算后验概率,就这么简单 但是…

智慧农业可视化大屏,当个农民是不是小伙伴的梦想。

智慧农业可视化大屏是指通过数据可视化技术,将农业生产过程中的各种数据、指标和信息以图表、图像等形式展示在大屏上,以便农业从业者能够直观地了解农田、作物、气象、设备等方面的情况,从而进行农业生产的监控、管理和决策。以下是智慧农业…

怎么快速批量导出文本二维码?文件批量生码的方法和步骤

随着互联网的快速发展,二维码的应用也越来越广泛,现在很多二维码会用来展示物品信息,将编辑好的文字内容生成二维码之后,让其他人通过扫码的方式来获取相关内容。那么当有多条信息时,有什么方法能够一次批量生成二维码…

PHP质量工具系列之php_CodeSniffer

PHP_CodeSniffer 是一组两个 PHP 脚本:主脚本 phpcs 对 PHP、JavaScript 和 CSS 文件进行标记,以检测是否违反定义的编码标准;第二个脚本 phpcbf 自动纠正违反编码标准的行为。PHP_CodeSniffer 是一个重要的开发工具,可以确保你的…

04Django项目基本运行逻辑及模板资源套用

对应视频链接点击直达 Django项目用户管理及模板资源 对应视频链接点击直达1.基本运行逻辑Django的基本运行路线:视图views.py中的 纯操作、数据返回、页面渲染 2.模版套用1.寻找一个好的模版2.模板部署--修改适配联动 OVER,不会有人不会吧不会的加Q1394…

AI大模型到底能帮我干什么?

这周百度文心一言大模型正式发布了,不少网民拿着各种段子搞笑和玩梗。我在的其中某个微信群里,一位老兄针对当下的大模型,发出来如下的问题: 大家这么玩梗真没意思 我一直没弄明白这些大模型怎么帮助我工作 这个问题有一定的代表性…

关于sklearn决策树手动指定节点进行剪枝调整的实现

一、决策树剪枝 决策树的剪枝方式有两种,预剪枝和后剪枝,后剪枝在python的sklearn方法中提供了CCP代价复杂度剪枝法(Cost Complexity Pruning)具体实现代码如下: # -*- coding: utf-8 -*- from sklearn.datasets imp…

Java入门基础学习笔记44——String

为什么要学习String的处理呢? 开发中,对字符串的处理是非常常见的。 String是什么?可以做什么? java.lang.String 代表字符串。可以用来创建对象封装字符串数据,并对其进行处理。 1、创建对象 2、封装字符串数据 3…

超过GPT4.0?Claude3官网及国内镜像站,国内使用克劳德3的方法

近期又有一个大模型横空出世,这就是由Anthropic公司推出的Claude 3(克劳德3),在多项基准测试中得分超越了GPT-4,那么他到底是什么情况呐?其实大家在国内也是可以使用上的! 克劳德Claude3 关于…

Python 闭包的高级用法详解

所谓闭包,就是指内函数使用了外函数的局部变量,并且外函数把内函数返回出来的过程,这个内函数称之为闭包函数。可以理解为是函数式编程中的封装。 内部函数可以使用外部函数定义的属性:外部函数调用后,返回内部函数的地…

Java入门基础学习笔记36——面向对象基础

面向对象编程快速入门: 计算机是用来处理数据的。 单个变量 数组变量 对象数据 Student类: package cn.ensource.object;public class Student {String name;double chinese_score;double math_score;public void printTotalScore() {System.out.pr…

AUTOMATIC1111/stable-diffusion-webui/stable-diffusion-webui-v1.9.3

配置环境介绍 目前平台集成了 Stable Diffusion WebUI 的官方镜像,该镜像中整合如下资源: GpuMall智算云 | 省钱、好用、弹性。租GPU就上GpuMall,面向AI开发者的GPU云平台 Stable Diffusion WebUI版本:v1.9.3 Python版本:3.10.…

HCIE是什么证书?为什么要考?

每当我发一些关于HCIE的话题时,总有小伙伴过来问“啥是HCIE啊?”今天就一起来了解下,到底什么是HCIE?为什么这么多人都要考HCIE? HCIE是华为认证ICT专家的缩写,它是华为认证体系中最高级别的ICT技术认证。HCIE全称为H…

windows 设置系统字体 (win11 win10)

由于微软的字体是有版权的,所以我打算替换掉 1.下载替换工具 github的项目,看起来很多人对微软默认字体带版权深恶痛绝。 项目地址:nomeiryoUi地址 这里选取最新的版本即可 2.打开软件 这里显示标题栏不能改,确认,其…

使用Systemd 设置Python程序开机启动

在 Linux 系统中设置Python 脚本开机启动,通常可以通过以下几种方式实现: 1. 使用 systemd(推荐方式) systemd 是大多数现代 Linux 发行版使用的初始化系统和服务管理器。你可以为Python 脚本创建一个 systemd 服务文件&#xf…

鸿蒙ArkUI-X平台差异化:【运行态差异化(@ohos.deviceInfo)】

平台差异化 简介 跨平台使用场景是一套ArkTS代码运行在多个终端设备上,如Android、iOS、OpenHarmony(含基于OpenHarmony发行的商业版,如HarmonyOS Next)。当不同平台业务逻辑不同,或使用了不支持跨平台的API&#xf…