[oneAPI] 手写数字识别-BiLSTM

news2024/11/20 1:46:09

[oneAPI] 手写数字识别-BiLSTM

  • 手写数字识别
    • 参数与包
    • 加载数据
    • 模型
    • 训练过程
    • 结果
  • oneAPI

比赛:https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517
Intel® DevCloud for oneAPI:https://devcloud.intel.com/oneapi/get_started/aiAnalyticsToolkitSamples/

手写数字识别

使用了pytorch以及Intel® Optimization for PyTorch,通过优化扩展了 PyTorch,使英特尔硬件的性能进一步提升,让手写数字识别问题更加的快速高效
在这里插入图片描述

使用MNIST数据集,该数据集包含了一系列以黑白图像表示的手写数字,每个图像的大小为28x28像素,数据集组成如下:

  • 训练集:包含60,000个图像和标签,用于训练模型。
  • 测试集:包含10,000个图像和标签,用于测试模型的性能。

每个图像都被标记为0到9之间的一个数字,表示图像中显示的手写数字。这个数据集常常被用来验证图像分类模型的性能,特别是在计算机视觉领域。

参数与包

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

import intel_extension_for_pytorch as ipex

# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.003

加载数据

# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='../../data/',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='../../data/',
                                          train=False,
                                          transform=transforms.ToTensor())

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

模型

# Bidirectional recurrent neural network (many-to-one)
class BiRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(BiRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size * 2, num_classes)  # 2 for bidirection

    def forward(self, x):
        # Set initial states
        h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(device)  # 2 for bidirection
        c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(device)

        # Forward propagate LSTM
        out, _ = self.lstm(x, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size*2)

        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])
        return out

训练过程

model = BiRNN(input_size, hidden_size, num_layers, num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)

# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))

# Test the model
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

结果

在这里插入图片描述

oneAPI

import intel_extension_for_pytorch as ipex

# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')

# 模型
model = ConvNet(num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/885976.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[obs] 编译记录

2023.7 obs 最近编译方式经常改。本次使用的是最新的方式编译 2023/7月份版本,记录一下遇到的坑 obs 最新版默认使用 vs2022 才能编译,如果想用 vs2019 编译,改下面这个地方就好了 CMakePresets.json 文件的obs aja 编译有问题 解决方案&a…

【方法】如何给分卷压缩文件添加密码?

在压缩文件的时候,如果文件比较大,或者网盘单个文件限制了大小,很多人会选择将文件压缩成分卷文件。 如果文件还有保密需求,那如何在压缩文件时设置分卷,又同时设置密码保护呢?下面小编来举例看看如何操作…

测试平台开发:(19)自动化测试脚本工具化 2

上一篇:测试平台开发:(18)自动化测试脚本工具化_要开朗的spookypop的博客-CSDN博客 本篇先实现创建自动化脚本的功能,将selenium代码语言转化为文字语言,如下图所示: 例1:比如下面的代码,转化为语言“用谷歌浏览器打开XX页面”: service = ChromeService(executabl…

【Python】基础语法:变量类型和动态类型

文章目录 1. 常量2. 变量3. 动态类型特性 努力经营当下 直至未来明朗 1. 常量 浮点数在内存中表示使用的是IEEE754标准,这套规则下,在内存中表示该浮点数额时候可能会存在微小的误差 在进行运算的时候,最好可以将数字(字面值常…

如何理解“I/O指令是CPU系统指令的一部分”?

I/O指令作用过程(以 I/O端口独立编址方式为例): CPU识别出当前指令是I/O指令,向I/O总线发送相应控制信号和地址信息。 之前迷惑的点:默认以为I/O指令是作用于I/O接口的,进而产生疑问,I/O接口中…

期权行权和不行权的区别

对于期权小白刚入门来说,期权是一种金融衍生品,简单理解期权就是大盘指数为标的物,可以做多和做空,在期权到期日最后一天,你面临持仓合约是选择行权呢还是不行权,下文科普期权行权和不行权的区别&#xff0…

03.Spring Security 如何保护用户密码

1. 前言 上一文我们对Spring Security中的重要用户信息主体UserDetails进行了探讨。中间例子我们使用了明文密码,规则是通过对密码明文添加{noop}前缀。那么本节将对 Spring Security 中的密码编码进行一些探讨。 2. 不推荐使用md5 首先md5 不是加密算法&#xf…

前端跨域的原因以及解决方案(vue),一文让你真正理解跨域

跨域这个问题,可以说是前端的必需了解的,但是多少人是知其然不知所以然呢? 下面我们来梳理一下vue解决跨域的思路。 什么情况会跨域? ​ 跨域的本质就是浏览器基于同源策略的一种安全手段。所谓同源就是必须有以下三个相同点:协议相同、域名…

网络套接字

网络套接字 文章目录 网络套接字认识端口号初识TCP协议初识UDP协议网络字节序 socket编程接口socket创建socket文件描述符bind绑定端口号sockaddr结构体netstat -nuap:查看服务器网络信息 代码编译运行展示 实现简单UDP服务器开发 认识端口号 端口号(port)是传输层协…

Python 3 使用HBase 总结

HBase 简介和安装 请参考文章:HBase 一文读懂 Python3 HBase API HBase 前期准备 1 安装happybase库操作hbase 安装该库 pip install happybase2 确保 Hadoop 和 Zookeeper 可用并开启 确保Hadoop 正常运行 确保Zookeeper 正常运行3 开启HBase thrift服务 使用命…

谈谈召回率(R值),准确率(P值)及F值

通俗解释机器学习中的召回率、精确率、准确率,一文让你一辈子忘不掉这两个词 赶时间的同学们看这里:提升精确率是为了不错报、提升召回率是为了不漏报 先说个题外话,暴击一下乱写博客的人,网络上很多地方分不清准确率和精确率&am…

前端实战系列:【2023酷炫前端特效】HTML蜂巢特效(附完整可执行代码 + 全网唯一!超详细注释分析 (熬夜换来的...),让你看得懂,敲的出代码!

久别重逢非昨日,万语千言不忍谈。 🎯作者主页: 追光者♂🔥 🌸个人简介: 💖[1] 计算机专业硕士研究生💖 🌿[2] 2023年城市之星领跑者TOP1(哈尔滨)🌿 🌟[3] 2022年度博客之星人工智能领域TOP4🌟 🏅[4] 阿里云社区特邀专家博主🏅 🏆…

【等保测评】等保初级测评师试题合集(3w字汇总)

【等保测评】信息安全等级保护初级测评师试题合集 一、法律法规单选多选判断 二、实施指南单选多选 三、定级指南四、基本要求五、测评准则六、信息安全等级测评模拟模拟试题1一、单选二、多选三、判断四、简答 模拟试题2一、单选二、多选三、判断四、简答 模拟试题3一、单选二…

MPLS基础知识

MPLS:多协议标签交换 多协议:可以基于多种不同的3层协议来生成2.5层的标签信息; 包交换—包为网络层的PDU,故包交换是基于IP地址进行数据转发;就是路由器的路由行为; 原始的包交换:数据包进入…

【自动化测试】接口自动化01

文章目录 一、熟悉若requests库以及底层方法的调用逻辑二、接口自动化以及正则和Jsonpath提取器的应用6. 高频面试题:9. 示例:接口关联13. 文件上传示例14. cookie关联的接口 努力经营当下 直至未来明朗 一、熟悉若requests库以及底层方法的调用逻辑 接…

on-java-8 知识总结(低频部分)

Perl简介 Perl 是 Practical Extraction and Report Language 的缩写,可翻译为 “实用报表提取语言”。最开始,Perl是一门文本处理语言,不过现在已经是通用的语言了。 作者吐槽其write-only,想必是因为其灵活性,同一目标下能写出…

Android设备通过蓝牙HID技术模拟键盘实现

目录 一,背景介绍 二,技术方案 2.1 获取BluetoothHidDevice实例 2.2 注册/解除注册HID实例 2.3 Hid report description描述符生成工具 2.4 键盘映射表 2.5 通过HID发送键盘事件 三,实例 一,背景介绍 日常生活中&#xff0…

第15集丨Vue 江湖 —— 组件

目录 一、为什么需要组件1.1 传统方式编写应用1.2 使用组件方式编写应用1.3 Vue的组件管理 二、Vue中的组件1.1 基本概念1.1.1 组件分类1.1.2 Vue中使用组件的三大步骤:1.1.3 如何定义一个组件1.1.4 如何注册组件1.1.5 如何使用组件 1.2 注意点1.2.1 关于组件名1.2.2 关于组件标…

14.Linkedin在中国市场的主要竞争对手

自Linkedin敲响了中国的大门之后,在国内市场也拥有了大量的用户。经过不断地发展了改革创新,更是成为了国内影响力比较大的职业社交平台之一。为了能够在国内市场中取得成功,在进入国内之前,Linkedin就采取了全新的模式,不仅仅是销售机构,也具备了产品技术、市场、公关等完整的…

达芬奇无法播放视频,黑屏,不能预览画面

说一下其中一个原因,是因为用了像是xdisplay,super display这类软件,他们会安装一个显卡驱动而这个驱动就会导致这个问题。 方法很简单,在设备管理器里面,显示卡一览,卸载掉除了你的电脑显卡以外的虚拟显卡…