基于pytorch的手写数字识别-训练+使用

news2024/12/29 0:58:55
import pandas as pd
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader

matplotlib.use('tkAgg')

# 设置图形配置
config = {
    "font.family": 'serif',
    "mathtext.fontset": 'stix',
    "font.serif": ['SimSun'],
    'axes.unicode_minus': False
}
matplotlib.rcParams.update(config)

def mymap(labels):
    return np.where(labels < 10, labels, 0)

# 数据加载
path = "d:\\JD\\Documents\\大学等等等\\自学部分\\机器学习自学画图\\手写数字识别\\ex3data1.xlsx"
data = pd.read_excel(path)
data = np.array(data, dtype=np.float32)
x = data[:, :-1]
labels = data[:, -1]
labels = mymap(labels)

# 转换为Tensor
x = torch.tensor(x, dtype=torch.float32)
labels = torch.tensor(labels, dtype=torch.long)

# 创建Dataset和Dataloader
dataset = TensorDataset(x, labels)
train_loader = DataLoader(dataset, batch_size=20, shuffle=True)

# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 定义模型
my_nn = torch.nn.Sequential(
    torch.nn.Linear(400, 128),
    torch.nn.Sigmoid(),
    torch.nn.Linear(128, 256),
    torch.nn.Sigmoid(),
    torch.nn.Linear(256, 512),
    torch.nn.Sigmoid(),
    torch.nn.Linear(512, 10)
).to(device)

# 加载预训练模型
my_nn.load_state_dict(torch.load('model.pth'))
my_nn.eval()  # 切换至评估模式

# 准备选取数据进行预测
sample_indices = np.random.choice(len(dataset), 50, replace=False)  # 随机选择50个样本
sample_images = x[sample_indices].to(device)  # 选择样本并移动到GPU
sample_labels = labels[sample_indices].numpy()  # 真实标签

# 进行预测
with torch.no_grad():  # 禁用梯度计算
    predictions = my_nn(sample_images)
    predicted_labels = torch.argmax(predictions, dim=1).cpu().numpy()  # 获取预测的标签

# 绘制图像
plt.figure(figsize=(10, 10))
for i in range(50):
    plt.subplot(10, 5, i + 1)  # 10行5列的子图
    plt.imshow(sample_images[i].cpu().reshape(20, 20), cmap='gray')  # 还原为20x20图像
    plt.title(f'Predicted: {predicted_labels[i]}', fontsize=8)
    plt.axis('off')  # 关闭坐标轴

plt.tight_layout()  # 调整子图间距
plt.show()

Iteration 0, Loss: 0.8472495079040527
Iteration 20, Loss: 0.014742681756615639
Iteration 40, Loss: 0.00011596851982176304
Iteration 60, Loss: 9.278443030780181e-05
Iteration 80, Loss: 1.3701709576707799e-05
Iteration 100, Loss: 5.019319928578625e-07
Iteration 120, Loss: 0.0
Iteration 140, Loss: 0.0
Iteration 160, Loss: 1.2548344585638915e-08
Iteration 180, Loss: 1.700657230685465e-05
预测准确率: 100.00%

下面使用已经训练好的模型,进行再次测试:

import pandas as pd
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader

matplotlib.use('tkAgg')

# 设置图形配置
config = {
    "font.family": 'serif',
    "mathtext.fontset": 'stix',
    "font.serif": ['SimSun'],
    'axes.unicode_minus': False
}
matplotlib.rcParams.update(config)

def mymap(labels):
    return np.where(labels < 10, labels, 0)

# 数据加载
path = "d:\\JD\\Documents\\大学等等等\\自学部分\\机器学习自学画图\\手写数字识别\\ex3data1.xlsx"
data = pd.read_excel(path)
data = np.array(data, dtype=np.float32)
x = data[:, :-1]
labels = data[:, -1]
labels = mymap(labels)

# 转换为Tensor
x = torch.tensor(x, dtype=torch.float32)
labels = torch.tensor(labels, dtype=torch.long)

# 创建Dataset和Dataloader
dataset = TensorDataset(x, labels)
train_loader = DataLoader(dataset, batch_size=20, shuffle=True)

# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 定义模型
my_nn = torch.nn.Sequential(
    torch.nn.Linear(400, 128),
    torch.nn.Sigmoid(),
    torch.nn.Linear(128, 256),
    torch.nn.Sigmoid(),
    torch.nn.Linear(256, 512),
    torch.nn.Sigmoid(),
    torch.nn.Linear(512, 10)
).to(device)

# 加载预训练模型
my_nn.load_state_dict(torch.load('model.pth'))
my_nn.eval()  # 切换至评估模式

# 准备选取数据进行预测
sample_indices = np.random.choice(len(dataset), 50, replace=False)  # 随机选择50个样本
sample_images = x[sample_indices].to(device)  # 选择样本并移动到GPU
sample_labels = labels[sample_indices].numpy()  # 真实标签

# 进行预测
with torch.no_grad():  # 禁用梯度计算
    predictions = my_nn(sample_images)
    predicted_labels = torch.argmax(predictions, dim=1).cpu().numpy()  # 获取预测的标签

plt.figure(figsize=(16, 10))
for i in range(20):
    plt.subplot(4, 5, i + 1)  # 4行5列的子图
    plt.imshow(sample_images[i].cpu().reshape(20, 20), cmap='gray')  # 还原为20x20图像
    plt.title(f'True: {sample_labels[i]}, Pred: {predicted_labels[i]}', fontsize=12)  # 标题中显示真实值和预测值
    plt.axis('off')  # 关闭坐标轴

plt.tight_layout()  # 调整子图间距
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2195634.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Maven 高级之分模块设计与继承、聚合

在软件开发中&#xff0c;随着项目规模的扩大&#xff0c;代码量和复杂度不断增加&#xff0c;传统的一体化开发模式逐渐暴露出诸多问题。为了解决这些问题&#xff0c;模块化开发应运而生&#xff0c;而 Maven 正是模块化开发的利器&#xff0c;它提供的继承和聚合机制为构建和…

fiddler抓包20_弱网模拟

课程大纲 ① 打开CustomRules.js文件&#xff1a;Fiddler快捷键“CtrlR”(或鼠标点击&#xff0c;菜单栏 - Rules - Customize Rules)。 ② 设置速率&#xff1a;“CtrlF”&#xff0c;搜索 “m_SimulateModem”&#xff0c;定位至函数。在函数里设置上传、下载速率&#xff0c…

ESP8266模块(WIFI STM32)

目录 一、介绍 二、传感器原理 1.原理图 2.引脚描述 3.ESP8266基础AT指令介绍 4.ESP8266基础工作模式 三、程序设计 main.c文件 esp8266.h文件 esp8266.c文件 四、实验效果 五、资料获取 项目分享 一、介绍 ESP8266是一款嵌入式系统级芯片&#xff0c;它集成了Wi…

将自己写好的项目部署在自己的云服务器上

准备工作 这里呢我要下载的终端软件是Xshell 如图&#xff1a; 自己准备好服务器&#xff0c;我这里的是阿里云的服务器&#xff0c; 如图&#xff1a; 这两个准备好之后呢&#xff0c;然后对我们的项目进行打包。 如图&#xff1a; 这里双击打包就行了。 找到自己打成jar包…

零基础多图详解图神经网络(GNN/GCN)【李沐论文精读】

A Gentle Introduction to Graph Neural Networks 在上图中&#xff0c;每一层都是一层网络&#xff0c;每一层的节点都受到下一层中自身节点和邻居节点的影响。如果网络比较深&#xff0c;是可以处理到一幅图中较大范围的节点。 前言 图神经网络在应用上还只是起步阶段&…

基于SpringBoot健身房管理系统【附源码】

效果如下&#xff1a; 系统首页界面 系统注册详细页面 健身课程详细页面 后台登录界面 管理员主页面 员工界面 健身教练界面 员工主页面 健身教练主页面 研究背景 随着生活水平的提高和健康意识的增强&#xff0c;现代人越来越注重健身。健身房作为一种专业的健身场所&#x…

日期类的实现(C++)

个人主页&#xff1a;Jason_from_China-CSDN博客 所属栏目&#xff1a;C系统性学习_Jason_from_China的博客-CSDN博客 所属栏目&#xff1a;C知识点的补充_Jason_from_China的博客-CSDN博客 前言 日期类是六个成员函数学习的总结和拓展&#xff0c;是实践的体现 创建文件 构造函…

HCIP--以太网交换安全(二)

端口安全 一、端口安全概述 1.1、端口安全概述&#xff1a;端口安全是一种网络设备防护措施&#xff0c;通过将接口学习的MAC地址设为安全地址防止非法用户通信。 1.2、端口安全原理&#xff1a; 类型 定义 特点 安全动态MAC地址 使能端口而未是能Stichy MAC功能是转换的…

在VMware WorkStation上安装飞牛OS(NAS系统)

对于NAS系统&#xff0c;小白相信很多小伙伴都不陌生&#xff0c;在许多场景下也能看得到&#xff0c;它其实可以算是文件存储服务器&#xff0c;当然&#xff0c;你如果给它加上其他服务的话&#xff0c;它也能变成网页服务器、Office协同办公服务器等等。 有许多小伙伴都拿这…

信息安全工程师(38)防火墙类型与实现技术

一、防火墙类型 按软、硬件形式分类 软件防火墙&#xff1a;通过软件实现防火墙功能&#xff0c;通常安装在个人计算机或服务器上&#xff0c;用于保护单个设备或小型网络。硬件防火墙&#xff1a;采用专门的硬件设备来实现防火墙功能&#xff0c;通常部署在企业网络边界或数据…

基于SpringBoot图书馆预约与占座小程序【附源码】

效果如下&#xff1a; 首页界面 用户登录界面 查看座位界面 管理员登录界面 管理员主界面 座位分布信息界面 预约信息界面 研究背景 随着互联网技术的不断进步和智能手机的广泛普及&#xff0c;图书馆作为知识获取和学习的重要场所&#xff0c;其管理方式也在逐步向信息化和智…

系统架构设计师论文《论企业应用系统的数据持久层架构设计》精选试读

论文真题 数据持久层&#xff08;Data Persistence Layer&#xff09;通常位于企业应用系统的业务逻辑层和数据源层之间&#xff0c;为整个项目提供一个高层、统一、安全、并发的数据持久机制&#xff0c;完成对各种数据进行持久化的编程工作&#xff0c;并为系统业务逻辑层提…

【电路基础 · 4】电路的图;KCL、KVL巩固;支路电流法

一、电路的图 1.线性电路的一般的分析方法 2.计算方法 掌握计算方法。 3.支路 branch 和 节点 node 对于支路&#xff0c;经常取电压、电流为同向。 4.KCL 巩固 巩固一下之前学习的 KCL。 但是需要注意&#xff1a; 对于一个电路&#xff0c;如果有 n 个节点&#xff0c;那…

浅学React和JSX

用antd做个人博客卡到前端了&#xff0c;迫不得已来学react&#xff0c;也是干上全栈了-- --学自尚硅谷张天禹react React就是js框架&#xff0c;可以理解为对js做了封装&#xff0c;那么封装后的肯定用起来更方便。 相关JS库 react.js&#xff1a;React核心库。react-dom.js&a…

计算机基本组成和工作原理(Basic Components and Working Principles of Computers)

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:Linux运维老纪的首页…

node配置swagger

安装swagger npm install swagger-jsdoc swagger-ui-express 创建 swagger.js 配置文件 ​ const path require(path); const express require(express); const swaggerUI require(swagger-ui-express); const swaggerJsDoc require(swagger-jsdoc); // 修改 swaggerDoc…

DAY26||669.修建二叉树 |108.将有序数组转换为二叉搜索树|538.把二叉搜索树转换为累加树

669.修剪二叉树 题目&#xff1a;669. 修剪二叉搜索树 - 力扣&#xff08;LeetCode&#xff09; 给你二叉搜索树的根节点 root &#xff0c;同时给定最小边界low 和最大边界 high。通过修剪二叉搜索树&#xff0c;使得所有节点的值在[low, high]中。修剪树 不应该 改变保留在树…

ArcGIS属性表怎么连接Excel表格?

ArcGIS中&#xff0c;属性表是存储空间要素非几何特征属性的重要工具。有时&#xff0c;我们需要将这些属性与外部数据&#xff0c;如Excel表格中的数据进行连接。以下是如何在ArcGIS中实现这一过程的步骤。 要把Excel表里的数据导入到ArcGIS里的地图数据里面&#xff0c;对数…

[单master节点k8s部署]34.ingress 反向代理(一)

ingress是k8s中的标准API资源&#xff0c;作用是定义外部流量如何进入集群&#xff0c;并根据核心路由规则将流量转发到集群内的服务。 ingress和Istio工作栈中的virtual service都是基于service之上&#xff0c;更细致准确的一种流量规则。每一个pod对应的service是四层代理&…

City Builder Urban 城市都市街道建筑场景模型

目前拥有178项优质资产。 城市建设者:Urban一个高质量的资产包,专为快速的纽约式城市建设而设计,与所有渲染管道兼容。 资产 56个带LOD的街道和屋顶道具 13个可堆叠的建筑部件与LOD混合搭配 10个不同尺寸的建筑装饰/分离器,总共40个装饰 请参阅秋季列表的技术细节 1个带有C…