[oneAPI] 手写数字识别-LSTM

news2024/7/6 5:26:15

[oneAPI] 手写数字识别-LSTM

  • 手写数字识别
    • 参数与包
    • 加载数据
    • 模型
    • 训练过程
    • 结果
  • oneAPI

比赛:https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517
Intel® DevCloud for oneAPI:https://devcloud.intel.com/oneapi/get_started/aiAnalyticsToolkitSamples/

手写数字识别

使用了pytorch以及Intel® Optimization for PyTorch,通过优化扩展了 PyTorch,使英特尔硬件的性能进一步提升,让手写数字识别问题更加的快速高效
在这里插入图片描述

使用MNIST数据集,该数据集包含了一系列以黑白图像表示的手写数字,每个图像的大小为28x28像素,数据集组成如下:

  • 训练集:包含60,000个图像和标签,用于训练模型。
  • 测试集:包含10,000个图像和标签,用于测试模型的性能。

每个图像都被标记为0到9之间的一个数字,表示图像中显示的手写数字。这个数据集常常被用来验证图像分类模型的性能,特别是在计算机视觉领域。

参数与包

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

import intel_extension_for_pytorch as ipex

# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.01

加载数据

# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='../../data/',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='../../data/',
                                          train=False,
                                          transform=transforms.ToTensor())

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

模型

# Recurrent neural network (many-to-one)
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Set initial hidden and cell states 
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)

        # Forward propagate LSTM
        out, _ = self.lstm(x, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size)

        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])
        return out

训练过程

model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)

# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))

# Test the model
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

结果

在这里插入图片描述

oneAPI

import intel_extension_for_pytorch as ipex

# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')

# 模型
model = ConvNet(num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/887851.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【C语言】自定义实现strcpy函数

大家好,我是苏貝,本篇博客带大家了解如何自定义实现strcpy函数,如果你觉得我写的还不错的话,可以给我一个赞👍吗,感谢❤️ 一. 了解strcpy函数。 函数原型:char* strcpy( char* destination , …

LLM - 大模型评估指标之 BLEU

目录 一.引言 二.BLEU 简介 1.Simple BLEU 2.Modified BLEU 3.Modified n-gram precision 4.Sentence brevity penalty 三.BLEU 计算 1.计算句子与单个 reference 2.计算句子与多个 reference 四.总结 一.引言 机器翻译的人工评价广泛而昂贵,且人工评估可…

【uni-app报错】获取用户收货地址uni.chooseAddress()报错问题

chooseAddress:fail the api need to be declared in …e requiredPrivateInf 原因: 小程序配置 / 全局配置 (qq.com) 解决: 登录小程序后台申请接口 按照流程申请即可 在项目根目录中找到 manifest.json 文件,在左侧导航栏选择源码视图&a…

代码pytorch-adda-master跑通记录

前言 最近在学习迁移学习,ADDA算法,由于嫌自己写麻烦,准备先跑通别人的代码。 代码名称:pytorch-adda-master 博客:https://www.cnblogs.com/BlairGrowing/p/17020378.html github地址:https://github.com…

Vue3 —— watchEffect 高级侦听器

该文章是在学习 小满vue3 课程的随堂记录示例均采用 <script setup>&#xff0c;且包含 typescript 的基础用法 前言 Vue3 中新增了一种特殊的监听器 watchEffect&#xff0c;它的类型是&#xff1a; function watchEffect(effect: (onCleanup: OnCleanup) > void,o…

第三届“赣政杯”网络安全大赛 | 赛宁筑牢安全应急防线

​​为持续强化江西省党政机关网络安全风险防范意识&#xff0c;提高信息化岗位从业人员基础技能&#xff0c;提升应对网络安全风险处置能力。由江西省委网信办、江西省发展改革委主办&#xff0c;江西省大数据中心、国家计算机网络与信息安全管理中心江西分中心承办&#xff0…

【负荷频率和电压控制】电力系统的组合负荷频率和电压控制模型研究(Simulink)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

CXL 寄存器介绍 (1) - 寄存器分类

&#x1f525;点击查看精选 CXL 系列文章&#x1f525; &#x1f525;点击进入【芯片设计验证】社区&#xff0c;查看更多精彩内容&#x1f525; &#x1f4e2; 声明&#xff1a; &#x1f96d; 作者主页&#xff1a;【MangoPapa的CSDN主页】。⚠️ 本文首发于CSDN&#xff0c…

【贪心】CF1822 E

Problem - 1822E - Codeforces 题意&#xff1a; 思路&#xff1a; 简单复盘一下思路 首先&#xff0c;n为奇数或有一种字符出现次数> n / 2就无解的结论是可以根据样例看出来的 然后就显然的发现&#xff0c;每2个不同的回文对有1的贡献 那么这样匹配之后会有剩余的回文…

JVM - 垃圾回收机制

JVM的垃圾回收机制(简称GC) JVM的垃圾回收机制非常强大&#xff0c;是JVM的一个很重要的功能&#xff0c;而且这也是跟对象实例息息相关的&#xff0c;如果对象实例不用了要怎么清除呢&#xff1f; 如何判断对象已经没用了 当JVM认为一个对像已经没用了&#xff0c;就会把这个…

【C++STL基础入门】string类的基础使用

文章目录 前言一、STL使用概述二、string类概述1.string类的构造函数string输出示例代码 2.string类属性属性是什么属性API示例代码 3.输出输出全部输出单个字符 总结 前言 本系列文章使用VS2022&#xff0c;C20版本 STL&#xff08;Standard Template Library&#xff09;是…

linux平台实现虚拟磁盘驱动(通用的块设备驱动和基于SCSI的磁盘驱动)

by fanxiushu 2023-08-16 转载或引用请桌面原始作者。 实现linux平台的虚拟磁盘驱动&#xff0c;是为了要实现在linux远程无盘启动的。 linux平台下的无盘启动&#xff0c;现成的办法有许多&#xff0c;比如iSCSI&#xff0c;NFS&#xff0c;NBD等都可以&#xff0c; 不过我都没…

JVM中释放内存的三种方法

判断是否需要垃圾回收可以采用分析。 1标记--清除算法 分为两个阶段&#xff0c;标记和清除&#xff0c;先利用可达性分型标记还存活的对象&#xff0c;之后将没有被标记的对象删除&#xff0c;这样容易生成空间碎片&#xff0c;而且效率不稳定 标记阶段&#xff1a; 标记阶段…

C#和Java的大端位和小端位的问题

C#代码里就是小端序,Java代码里就是大端序&#xff0c; 大端位:big endian,是指数据的高字节保存在内存的低地址中&#xff0c;而数据的低字节保存在内存的高地址中&#xff0c;也叫高尾端 小端位:little endian,是指数据的高字节保存在内存的高地址中,而数据的低字节保存在内存…

vue : 无法加载文件 F:\nodejs\vue.ps1,因为在此系统上禁止运行脚本。

报错信息如下 在使用Windows PowerShell输入指令查看版本时或者用脚手架创建vue项目时都有可能报错&#xff0c;报错信息如下&#xff1a;vue : 无法加载文件 F:\nodejs\vue.ps1&#xff0c;因为在此系统上禁止运行脚本 解决方案&#xff1a; 原因&#xff1a;因为Windows Po…

问道管理:金叉死叉十句口诀?

随着越来越多人参加股票买卖&#xff0c;关于股票商场的了解也变得越来越重要。其中一项重要的概念就是金叉死叉&#xff0c;这是一种均线穿插的现象&#xff0c;而均线穿插是技能剖析的重点之一。在本文中&#xff0c;咱们将会从多个角度剖析金叉死叉&#xff0c;并给出十句口…

Qt应用开发(基础篇)——MDI窗口 QMdiArea QMdiSubWindow

一、前言 QMdiArea类继承于QAbstractScrollArea&#xff0c;QAbstractScrollArea继承于QFrame&#xff0c;是Qt用来显示MDI窗口的部件。 滚屏区域基类 QAbstractScrollAreahttps://blog.csdn.net/u014491932/article/details/132245486 框架类 QFramehttps://blog.csdn.net/u01…

sqlite3用法

完成数据库的插入。 程序如下&#xff1a; 运行结果如下&#xff1a;

Python学习笔记_基础篇(十)_socket编程

本章内容 1、socket 2、IO多路复用 3、socketserver Socket socket起源于Unix&#xff0c;而Unix/Linux基本哲学之一就是“一切皆文件”&#xff0c;对于文件用【打开】【读写】【关闭】模式来操作。socket就是该模式的一个实现&#xff0c;socket即是一种特殊的文件&…

返回数组中最大(最小)值的位置(索引值),查找范围可以是所有元素,或者指定行列numpy.argmax()numpy.argmin()

【小白从小学Python、C、Java】 【计算机等级考试500强双证书】 【Python-数据分析】 返回数组中最大(最小)值的位置(索引值)&#xff0c; 查找范围可以是所有元素&#xff0c;或者指定行列 numpy.argmax() numpy.argmin() [太阳]选择题 关于以下代码说法错误的一项是? import…