基于pytorch的手写数字识别

news2024/10/6 23:44:20
import pandas as pd
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader

matplotlib.use('tkAgg')

# 设置图形配置
config = {
    "font.family": 'serif',
    "mathtext.fontset": 'stix',
    "font.serif": ['SimSun'],
    'axes.unicode_minus': False
}
matplotlib.rcParams.update(config)

def mymap(labels):
    return np.where(labels < 10, labels, 0)

# 数据加载
path = "d:\\JD\\Documents\\大学等等等\\自学部分\\机器学习自学画图\\手写数字识别\\ex3data1.xlsx"
data = pd.read_excel(path)
data = np.array(data, dtype=np.float32)
x = data[:, :-1]
labels = data[:, -1]
labels = mymap(labels)

# 转换为Tensor
x = torch.tensor(x, dtype=torch.float32)
labels = torch.tensor(labels, dtype=torch.long)

# 创建Dataset和Dataloader
dataset = TensorDataset(x, labels)
train_loader = DataLoader(dataset, batch_size=20, shuffle=True)

# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 定义模型
my_nn = torch.nn.Sequential(
    torch.nn.Linear(400, 128),
    torch.nn.Sigmoid(),
    torch.nn.Linear(128, 256),
    torch.nn.Sigmoid(),
    torch.nn.Linear(256, 512),
    torch.nn.Sigmoid(),
    torch.nn.Linear(512, 10)
).to(device)

# 加载预训练模型
my_nn.load_state_dict(torch.load('model.pth'))
my_nn.eval()  # 切换至评估模式

# 准备选取数据进行预测
sample_indices = np.random.choice(len(dataset), 50, replace=False)  # 随机选择50个样本
sample_images = x[sample_indices].to(device)  # 选择样本并移动到GPU
sample_labels = labels[sample_indices].numpy()  # 真实标签

# 进行预测
with torch.no_grad():  # 禁用梯度计算
    predictions = my_nn(sample_images)
    predicted_labels = torch.argmax(predictions, dim=1).cpu().numpy()  # 获取预测的标签

# 绘制图像
plt.figure(figsize=(10, 10))
for i in range(50):
    plt.subplot(10, 5, i + 1)  # 10行5列的子图
    plt.imshow(sample_images[i].cpu().reshape(20, 20), cmap='gray')  # 还原为20x20图像
    plt.title(f'Predicted: {predicted_labels[i]}', fontsize=8)
    plt.axis('off')  # 关闭坐标轴

plt.tight_layout()  # 调整子图间距
plt.show()

Iteration 0, Loss: 0.8472495079040527
Iteration 20, Loss: 0.014742681756615639
Iteration 40, Loss: 0.00011596851982176304
Iteration 60, Loss: 9.278443030780181e-05
Iteration 80, Loss: 1.3701709576707799e-05
Iteration 100, Loss: 5.019319928578625e-07
Iteration 120, Loss: 0.0
Iteration 140, Loss: 0.0
Iteration 160, Loss: 1.2548344585638915e-08
Iteration 180, Loss: 1.700657230685465e-05
预测准确率: 100.00%

下面使用已经训练好的模型,进行再次测试:

import pandas as pd
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader

matplotlib.use('tkAgg')

# 设置图形配置
config = {
    "font.family": 'serif',
    "mathtext.fontset": 'stix',
    "font.serif": ['SimSun'],
    'axes.unicode_minus': False
}
matplotlib.rcParams.update(config)

def mymap(labels):
    return np.where(labels < 10, labels, 0)

# 数据加载
path = "d:\\JD\\Documents\\大学等等等\\自学部分\\机器学习自学画图\\手写数字识别\\ex3data1.xlsx"
data = pd.read_excel(path)
data = np.array(data, dtype=np.float32)
x = data[:, :-1]
labels = data[:, -1]
labels = mymap(labels)

# 转换为Tensor
x = torch.tensor(x, dtype=torch.float32)
labels = torch.tensor(labels, dtype=torch.long)

# 创建Dataset和Dataloader
dataset = TensorDataset(x, labels)
train_loader = DataLoader(dataset, batch_size=20, shuffle=True)

# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 定义模型
my_nn = torch.nn.Sequential(
    torch.nn.Linear(400, 128),
    torch.nn.Sigmoid(),
    torch.nn.Linear(128, 256),
    torch.nn.Sigmoid(),
    torch.nn.Linear(256, 512),
    torch.nn.Sigmoid(),
    torch.nn.Linear(512, 10)
).to(device)

# 加载预训练模型
my_nn.load_state_dict(torch.load('model.pth'))
my_nn.eval()  # 切换至评估模式

# 准备选取数据进行预测
sample_indices = np.random.choice(len(dataset), 50, replace=False)  # 随机选择50个样本
sample_images = x[sample_indices].to(device)  # 选择样本并移动到GPU
sample_labels = labels[sample_indices].numpy()  # 真实标签

# 进行预测
with torch.no_grad():  # 禁用梯度计算
    predictions = my_nn(sample_images)
    predicted_labels = torch.argmax(predictions, dim=1).cpu().numpy()  # 获取预测的标签

plt.figure(figsize=(16, 10))
for i in range(20):
    plt.subplot(4, 5, i + 1)  # 4行5列的子图
    plt.imshow(sample_images[i].cpu().reshape(20, 20), cmap='gray')  # 还原为20x20图像
    plt.title(f'True: {sample_labels[i]}, Pred: {predicted_labels[i]}', fontsize=12)  # 标题中显示真实值和预测值
    plt.axis('off')  # 关闭坐标轴

plt.tight_layout()  # 调整子图间距
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2193195.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何在华为云服务器查看IP地址,及修改服务器登录密码!!!

1.在华为云服务器查看IP地址 (1).第一步&#xff1a; 先找到控制台 (2).第二步&#xff1a; 点击华为云Flexus云服务 (3)第三步&#xff1a; 找到公网IP&#xff0c;就找到华为云服务器IP地址啦。 注意&#xff1a;在操作以上步骤的前提是要已注册华为云账号及购买云服务器…

PPPoE协议个人理解+报文示例+典型配置-RFC2516

个人认为&#xff0c;理解报文就理解了协议。通过报文中的字段可以理解协议在交互过程中相关传递的信息&#xff0c;更加便于理解协议。 因此本文将在PPPoE协议报文的基础上进行介绍。 PPPoE协议发展 关于PPPoE基本原理&#xff0c;可参考1999年发布的《RFC2516-A Method fo…

class 031 位运算的骚操作

这篇文章是看了“左程云”老师在b站上的讲解之后写的, 自己感觉已经能理解了, 所以就将整个过程写下来了。 这个是“左程云”老师个人空间的b站的链接, 数据结构与算法讲的很好很好, 希望大家可以多多支持左程云老师, 真心推荐. 左程云的个人空间-左程云个人主页-哔哩哔哩视频…

8649 图的广度遍历

### 思路 1. **图的邻接表存储结构**&#xff1a;使用邻接表存储图的顶点和边信息。 2. **基本操作函数**&#xff1a;包括创建图、查找顶点、获取顶点值、获取第一个邻接顶点、获取下一个邻接顶点等。 3. **广度优先遍历&#xff08;BFS&#xff09;**&#xff1a;从某个顶点出…

LPDDR6 来之未远

很多朋友可能还没用上DDR5,但不好意思的是,DDR6 可能马上就要出现了。 三星和海力士较早开始DDR6 的设计,预计2025年商业化。 DDR6 速度 来源: 半导体观察 DDR6的速度将是主流的DDR4的四倍,将是现有DDR5的两倍,DDR6传输速度可达12800 Mbps。 LPDDR6 来源:快科技 L…

OpenAI董事会主席Bret Taylor的Agent公司Sierra:专注于赋能下一代企业用户体验

本文由readlecture.cn转录总结。ReadLecture专注于音、视频转录与总结&#xff0c;2小时视频&#xff0c;5分钟阅读&#xff0c;加速内容学习与传播。 视频来源 youtube: https://www.youtube.com/watch?vriWB5nPNZEM&t47s 大纲 介绍 欢迎与介绍 介绍Bret Taylor&#x…

功耗电流图的对比技巧

电流波形对比 使用系统画图工具的反色和透明设置项目&#xff0c;就可以将2张图合在一块看 方法【系统画图工具】 例如在相同的测试用例&#xff0c;可以对比电流和耗电量的差异

3.使用条件语句编写存储过程(3/10)

引言 在现代数据库管理系统中&#xff0c;存储过程扮演着至关重要的角色。它们是一组为了执行特定任务而编写的SQL语句&#xff0c;这些语句被保存在数据库中&#xff0c;可以被重复调用。存储过程不仅可以提高数据库操作的效率&#xff0c;还可以增强数据的安全性和一致性。此…

Python3 爬虫 中间人爬虫

中间人&#xff08;Man-in-the-Middle&#xff0c;MITM&#xff09;攻击是指攻击者与通信的两端分别创建独立的联系&#xff0c;并交换其所收到的数据&#xff0c;使通信的两端认为其正在通过一个私密的连接与对方直接对话&#xff0c;但事实上整个会话都被攻击者完全控制。在中…

LCD屏入门(基于ESP-IDF、SPI屏)

主要参考资料&#xff1a; ESP32-S3 开发 SPI 屏【DIY 智能手表】: https://www.bilibili.com/video/BV1Yc411y7bb/?spm_id_from333.337.search-card.all.click&vd_sourcedd284033cd0c4d1f3f59a2cd40ae4ef9 使用 SPI 屏和 I2C 触屏运行 SquareLine Studio 提供的手表 UI 示…

突触可塑性与STDP:神经网络中的自我调整机制

突触可塑性与STDP&#xff1a;神经网络中的自我调整机制 在神经网络的学习过程中&#xff0c;突触可塑性&#xff08;Synaptic Plasticity&#xff09;是指神经元之间的连接强度&#xff08;突触权重&#xff09;随着时间的推移而动态变化的能力。这种调整机制使神经网络能够通…

链动 2+1 模式 S2B2C 商城小程序:交易转化的创新引擎

摘要 在数字化商业时代&#xff0c;电商行业竞争激烈&#xff0c;交易转化成为核心问题。链动 21 模式 S2B2C 商城小程序源码作为创新电商模式&#xff0c;通过独特的推荐与分享机制、丰富奖励机制、AI 智能名片及 S2B2C 商城的个性化定制与供应链协同等&#xff0c;在交易转化…

redis+mysql数据一致性+缓存穿透解决方案

在分布式事务中我们知道有cap定理&#xff0c;即 我们保证高可用的情况下&#xff0c;必然要牺牲一些一致性&#xff0c;在保证强一致性的情况下&#xff0c;必然会牺牲一些可用性。而我们redismysql数据一致性的使用策略就是在我们保证可用性的情况下尽量保证数据的一致性。想…

MySql的基本语法操作

查看数据库和表 查看所有的数据库 show databases; 建立一个新的数据库 create database database_name; 也可以是 create database if not exists database_name; 表示这个数据库不存在才建立 而不会打断其他sql语句的执行&#xff0c;而如果没有加的话&#xff0c;创建…

神经网络及大模型科普揭秘

一、生物神经元及神经元构成的神经网络 下图是生物神经元的示意图: 生物神经元由细胞体、树突、轴突、轴突末梢四部分构成。 下图是生物神经网络的一个简单示意图: 生物神经元通过电信号在彼此间传递信号,神经元的各个树突接收输入信号,经过细胞体汇总,如果最终总和高…

【动态规划-最长公共子序列(LCS)】力扣97. 交错字符串

给定三个字符串 s1、s2、s3&#xff0c;请你帮忙验证 s3 是否是由 s1 和 s2 交错 组成的。 两个字符串 s 和 t 交错 的定义与过程如下&#xff0c;其中每个字符串都会被分割成若干 非空 子字符串&#xff1a; s s1 s2 … sn t t1 t2 … tm |n - m| < 1 交错 是 s1…

【微服务】服务注册与发现 - Eureka(day3)

CAP理论 P是分区容错性。简单来说&#xff0c;分区容错性表示分布式服务中一个节点挂掉了&#xff0c;并不影响其他节点对外提供服务。也就是一台服务器出错了&#xff0c;仍然可以对外进行响应&#xff0c;不会因为某一台服务器出错而导致所有的请求都无法响应。综上所述&…

网络安全概述:从认知到实践

一、定义 网络安全&#xff0c;即致力于保护网络系统所涵盖的硬件、软件以及各类数据&#xff0c;切实保障其免遭破坏、泄露或者篡改等不良情形的发生。 二、重要性 个人层面&#xff1a;着重于守护个人隐私以及财产安全&#xff0c;为个人在网络世界中的各项活动提供坚实的保…

分享几个做题网站------学习网------工具网;

以下是就是做题网站&#xff1b;趣IT官网-互联网求职刷题神器趣IT——互联网在线刷题学习平台&#xff0c;汇集互联网大厂面试真题&#xff0c;拥有java、C、Python、前端、产品经理、软件测试、新媒体运营等多个热门IT岗位面试笔试题库&#xff0c;提供能力测评、面试刷题、笔…

Meta 首个多模态大模型一键启动!首个多针刺绣数据集上线,含超 30k 张图片

小扎在 Meta Connect 2024 主题演讲中宣布推出首个多模态大模型 Llama 3.2 vision&#xff01;该模型有 11B 和 90B 两个版本&#xff0c;成为首批支持多模态任务的 Llama 系列模型&#xff0c;根据官方数据&#xff0c;这两个开原模型的性能已超越闭源模型。 小编已经迫不及待…