手写数字识别Minst(CNN)

news2024/12/29 10:48:37

文章目录

  • 手写数字识别
    • 网络结构
    • 加载数据集
    • 数据集可视化
    • CNN网络结构
    • 训练模型
    • 保存模型和加载模型
    • 测试模型

手写数字识别

网络结构

网上给出的基本网络结构:
在这里插入图片描述
然而在本数据集中,输入图不是1*32*32,是1*28*28。所以正确的网络结构应该是

levelinputstrideoutput
11*28*286*5*516*24*24
MaxPool6*24*24MaxPool26*12*12
26*12*1216*5*5116*8*8
MaxPool16*8*8MaxPool216*4*4
Flatten16*4*4Flatten256
3FC256FC120
4FC120FC84
5FC84FC10

加载数据集

# -*-coding =utf-8 -*-
import torch
import matplotlib.pyplot as plt
import torchvision

# 定义数据转换
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,))
])

# 加载数据集
batch_size=32
path = r'05data'
train_dataset = torchvision.datasets.MNIST(root=path, train=True,transform=transform,download =False)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = torchvision.datasets.MNIST(root=path, train=True,transform=transform,download =False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
# loader.shape=1875*[32*1*28*28,32]

最后loader.shape是1875*[32*1*28*28,32],即 number*[batch(data)*height*width, batch(label)]

数据集可视化


from sklearn.preprocessing import MinMaxScaler
# 归一化转为[0,255]
transfer=MinMaxScaler(feature_range=(0, 255)) 
def visualize_loader(batch,predicted=''): 
    # batch=[32*1*28*28,32]
    imgs=batch[0].squeeze().numpy() # 消squeeze()一维
    fig, axes = plt.subplots(4, 8, figsize=(12, 6))
    labels=batch[1].numpy()
    if str(predicted)=='':
        predicted=labels
    for i, ax in enumerate(axes.flat):
        ax.imshow(imgs[i])
        ax.set_title(predicted[i],color='black' if predicted[i]==labels[i] else 'red')
        ax.axis('off')
    plt.tight_layout()
    plt.show()

# loader.shape=1875*[32*1*28*28,32]
for batch in train_loader:       
    break
visualize_loader(batch)

在这里插入图片描述
上图是对数据集的可视化。

CNN网络结构

在PyTorch的torch.nn模块中,卷积函数Conv2d的输入张量的形状应为[batch_size, channels, height, width]对应数据集,无需修改(在一些架构中,可能是[batch_size, height, width, channels])。

# 创建模型
import torch
import torch.nn as nn
import torch.optim as optim

# 定义模型
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1)
        self.flatten=nn.Flatten()
        self.fc3 = nn.Linear(256, 120)
        self.fc4 = nn.Linear(120, 84)
        self.fc5 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        x = self.fc3(x)
        x = self.relu(x)
        x = self.fc4(x)
        x = self.relu(x)
        x = self.fc5(x)
        return x

打印模型结构

model = CNN()
print(model)
CNN(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (relu): ReLU()
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc3): Linear(in_features=256, out_features=120, bias=True)
  (fc4): Linear(in_features=120, out_features=84, bias=True)
  (fc5): Linear(in_features=84, out_features=10, bias=True)
)

训练模型

import torch.optim as optim

num_epochs=1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        
        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # 统计准确率
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        running_loss += loss.item()
    
    train_loss = running_loss / len(train_loader)
    train_accuracy = correct / total
    
    # 在测试集上评估模型
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            test_loss += loss.item()
    
    test_loss = test_loss / len(test_loader)
    test_accuracy = correct / total
    
    # 打印训练过程中的损失和准确率
    print(f"Epoch [{epoch+1}/{num_epochs}] - Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")
Epoch [1/1] - Train Loss: 0.0154, Train Accuracy: 0.9951, Test Loss: 0.0109, Test Accuracy: 0.9964

保存模型和加载模型


#torch.save(model.state_dict(), '05model.pth')

# 创建一个新的模型实例
model = CNN()
# 加载模型的参数
model.load_state_dict(torch.load('05model.pth'))

测试模型


for batch in test_loader:       
    break
imgs=batch[0]
outputs = model(imgs)
_, predicted = torch.max(outputs.data, 1)
predicted=predicted.numpy()

print(predicted)

visualize_loader(batch,predicted)

在这里插入图片描述

上图中可视化了其中的32次预测,只有第三行第四列的“8”被预测为“5”,其余均是正确。
在测试集的总体预测准确度为99.64%,正确率挺高的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/771002.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

实现关注公众号以后自动推送小程序

准备好小程序的APPID和跳转路径 然后一行代码搞定&#xff1a; <a data-miniprogram-appid"小程序APPID" data-miniprogram-path"跳转路径">点我跳转到小程序</a>

Shuffle简单理解

map的结果本身是无序的&#xff0c;但是map输出的结果有序 mapper和reduce是不同的机器&#xff0c;进行了网络传输&#xff0c;所以存在数据拷贝 第二次排序&#xff0c;是将每个reduce对应的task进行排序&#xff0c;然后再进入reduce maptask运行结束&#xff0c;每个mask块…

被字节拷打了~基础还是太重要了...

今天分享一篇一位同学去字节面试的实习面经&#xff0c;技术栈是java&#xff0c;投了go后端岗位&#xff0c;主要拷打了 redismysql网络系统java算法&#xff0c;面试问题主要集中在 mysql、redis、网络这三部门&#xff0c;因为面试官是搞 go 的&#xff0c;java 只是随便问了…

55 # 实现可写流

先在 LinkedList.js 给链表添加一个移除方法 class Node {constructor(element, next) {this.element element;this.next next;} }class LinkedList {constructor() {this.head null; // 链表的头this.size 0; // 链表长度}// 可以直接在尾部添加内容&#xff0c;或者根据…

数据库小白看这里,这个Oracle数据库知识图谱你值得拥有

2022年前后&#xff0c;墨天轮社区曾陆续推出PostgreSQL知识图谱、MySQL知识图谱&#xff0c;并得到了大家的广泛好评。此后&#xff0c;便有众多朋友对Oracle知识图谱发起不断“催更“。经过近期的内容搜集整合、专家复审与打磨&#xff0c;墨天轮社区正式推出Oracle知识图谱&…

MySQL五种约束类型(普通 /自增主键,外键等) + 进阶查询(聚合查询,内 /外连接查询,自连接查询,子查询,合并查询)

文章目录 前言一、五种约束NOT NULL 约束UNIQUE 约束DEFAULT 约束PRIMARY KEY 主键约束(重点)普通主键自增主键 FOREIGN KEY 外键约束(重点) 二、进阶查询聚合查询聚合函数GROUP BY子句HAVING 联合查询笛卡尔积内连接外连接自连接子查询单行子查询&#xff1a;返回一行记录的子…

AI时代图像安全“黑科技”如何助力人工智能与科技发展?

〇、前言 7月7日下午&#xff0c;2023世界人工智能大会&#xff08;WAIC&#xff09;“聚焦大模型时代AIGC新浪潮—可信AI”论坛在上海世博中心红厅举行。人工智能等技术前沿领域的著名专家与学者、投资人和领军创业者汇聚一堂&#xff0c;共同探索中国科技创新的驱动力量。 在…

搭载下一代人工智能技术,微软推出Power Automate流程挖掘产品

在近日的Microsoft Inspire大会中&#xff0c;微软揭晓了他们即将推出的Power Automate流程挖掘产品&#xff0c;并计划在8月1日正式对外开放。 试用地址&#xff1a;https://powerautomate.microsoft.com/zh-cn/#home-signup 这款产品搭载了下一代人工智能技术&#xff0c;有…

好用的思维导图软件有哪些?这几款简单好用

好用的思维导图软件有哪些&#xff1f;思维导图是一种非常有用的思维工具&#xff0c;可以帮助我们组织和理清复杂的信息。在如今的数字时代&#xff0c;有很多软件可以帮助我们创建和编辑思维导图。下面介绍几款简单好用的思维导图软件。 第一款&#xff1a;迅捷画图 这是一款…

多个信贷范围时客户主数据界面的定制(套头和信用缴纳范围=信贷范围)

客户主数据-销售范围-开票的界面有信贷范围&#xff0c;叫贷方控制范围。 但是默认是看不到的。需要进行配置。 但是SAP的配置里面的名字很奇怪&#xff0c;在客户账户组里面的销售数据中(OVT0)定制 双击后处理的这个界面&#xff0c;和界面的“”开票凭证“”对不上&#x…

云原生微服务应用的平台工程实践

作者&#xff1a;纳海 01 微服务应用云原生化 微服务是一个广泛使用的应用架构&#xff0c;而如何使得微服务应用云原生化却是近些年一直在演进的课题。国内外云厂商对云原生概念的诠释大同小异&#xff0c;基本都会遵循 CNCF 基金会的定义&#xff1a; 云原生技术有利于各组…

Linux内核源代码的目录结构包括部分:

内核核心代码&#xff1a;这部分代码包括内核的各个子系统和模块&#xff0c;如进程管理、内存管理、文件系统、网络协议栈等。这些代码构成了Linux内核的核心功能。 非核心代码&#xff1a;除了核心代码之外&#xff0c;还包括一些非核心的代码和文件&#xff0c;如库文件、固…

【网站搭建】1安装Hexo

1.前期准备工作 安装node.js和git Node.js (nodejs.org) Git - Downloads (git-scm.com) 安装好后验证是否完成安装 2.打开Git安装配置Hexo 由于国内的镜像源速度较慢&#xff0c;所以我们利用 npm 来安装 cnpm &#xff0c;在命令行中输入npm install -g cnpm --registry…

一文详解 requests 库中 json 参数和 data 参数的用法

在requests库当中&#xff0c;requests请求方法&#xff0c;当发送post/put/delete等带有请求体的请求时&#xff0c;有json和data2个参数可选。 众所周知&#xff0c;http请求的请求体格式主要有以下4种&#xff1a; application/jsonapplicaiton/x-www-from-urlencodedmult…

291. 单词规律 II(plus题)

给你一种规律 pattern 和一个字符串 s&#xff0c;请你判断 s 是否和 pattern 的规律相匹配。 如果存在单个字符到 非空 字符串的 双射映射 &#xff0c;那么字符串 s 匹配 pattern &#xff0c;即&#xff1a;如果 pattern 中的每个字符都被它映射到的字符串替换&#xff0c;那…

python发送邮件zmail库

第三方库“zmail”和“yagmail”可实现邮件发送。在实际使用对比zmail比yagmail更简洁。使用zmail&#xff0c;无需登录OA邮箱&#xff0c;便可完成邮件的发送及附件的自动加载。 import zmaildef send_zmail(sender, sender_password, addressee, host, port465, inspect_smtp…

<C语言> 自定义类型

1.结构体 结构体是一种用户自定义的数据类型&#xff0c;允许将不同类型的数据项组合在一起&#xff0c;形成一个更大的数据结构。结构体可以包含多个成员变量&#xff0c;每个成员变量可以是不同的数据类型&#xff0c;如整数、字符、浮点数等&#xff0c;甚至可以包含其他结构…

师承AI世界新星|7天获新加坡南洋理工大学访学邀请函

能够拜师在“人工智能10大新星”名下&#xff0c;必定可以学习到前沿技术&#xff0c;受益良多&#xff0c;本案例中的C老师无疑就是这个幸运儿。我们只用了7天时间就取得了这位AI新星导师的邀请函&#xff0c;最终C老师顺利获批CSC&#xff0c;如愿出国。 C老师背景&#xff1…

线程与信号

1.子线程会继承主线程信号处理配置&#xff0c;故信号配置可以全部放在主线程内。 2.同一信号多次触发或者嵌套触发不会嵌套执行。 3.不同信号可以嵌套触发执行。 4.kill()触发的信号由进程&#xff08;主线程&#xff09;执行&#xff0c;pthread_kill()触发的信号由参数指…

数据结构-单链表

#include<stdio.h> #include<stdlib.h>typedef struct Node {int data;struct Node* next; }Node;//创建一个头结点&#xff0c;数据域保存链表节点数 Node* init_single_list() {Node* node (Node*)malloc(sizeof(Node));node->next NULL;node->data 0; …