[oneAPI] 手写数字识别-GAN

news2024/10/5 12:57:17

[oneAPI] 手写数字识别-GAN

  • 手写数字识别
    • 参数与包
    • 加载数据
    • 模型
    • 训练过程
    • 结果
  • oneAPI

比赛:https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517
Intel® DevCloud for oneAPI:https://devcloud.intel.com/oneapi/get_started/aiAnalyticsToolkitSamples/

手写数字识别

使用了pytorch以及Intel® Optimization for PyTorch,通过优化扩展了 PyTorch,使英特尔硬件的性能进一步提升,让手写数字识别问题更加的快速高效
在这里插入图片描述

使用MNIST数据集,该数据集包含了一系列以黑白图像表示的手写数字,每个图像的大小为28x28像素,数据集组成如下:

  • 训练集:包含60,000个图像和标签,用于训练模型。
  • 测试集:包含10,000个图像和标签,用于测试模型的性能。

每个图像都被标记为0到9之间的一个数字,表示图像中显示的手写数字。这个数据集常常被用来验证图像分类模型的性能,特别是在计算机视觉领域。

参数与包

import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image

import intel_extension_for_pytorch as ipex

# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'samples'

加载数据

# Create a directory if not exists
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

# Image processing
# transform = transforms.Compose([
#                 transforms.ToTensor(),
#                 transforms.Normalize(mean=(0.5, 0.5, 0.5),   # 3 for RGB channels
#                                      std=(0.5, 0.5, 0.5))])
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5],  # 1 for greyscale channels
                         std=[0.5])])

# MNIST dataset
mnist = torchvision.datasets.MNIST(root='./data/',
                                   train=True,
                                   transform=transform,
                                   download=True)

# Data loader
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size,
                                          shuffle=True)

模型

# Discriminator
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid())

# Generator 
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh())

训练过程

# Device setting
D = D.to(device)
G = G.to(device)

# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
D, d_optimizer = ipex.optimize(D, optimizer=d_optimizer)
G, g_optimizer = ipex.optimize(G, optimizer=g_optimizer)


def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)


def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()


# Start training
total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        images = images.reshape(batch_size, -1).to(device)

        # Create the labels which are later used as input for the BCE loss
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # ================================================================== #
        #                      Train the discriminator                       #
        # ================================================================== #

        # Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
        # Second term of the loss is always zero since real_labels == 1
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        # Compute BCELoss using fake images
        # First term of the loss is always zero since fake_labels == 0
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        # Backprop and optimize
        d_loss = d_loss_real + d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()

        # ================================================================== #
        #                        Train the generator                         #
        # ================================================================== #

        # Compute loss with fake images
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)

        # We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))
        # For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdf
        g_loss = criterion(outputs, real_labels)

        # Backprop and optimize
        reset_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
                  .format(epoch, num_epochs, i + 1, total_step, d_loss.item(), g_loss.item(),
                          real_score.mean().item(), fake_score.mean().item()))

    # Save real images
    if (epoch + 1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))

    # Save sampled images
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch + 1)))

# Save the model checkpoints 
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')

结果

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

oneAPI

import intel_extension_for_pytorch as ipex

# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')

# Device setting
D = D.to(device)
G = G.to(device)

# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
D, d_optimizer = ipex.optimize(D, optimizer=d_optimizer)
G, g_optimizer = ipex.optimize(G, optimizer=g_optimizer)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/888922.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

高德地图1.4.15楼层处理

前因: 接入了高德1.4.15JS API的项目中使用了mapStyle: ‘amap://styles/grey’,在这个模式下楼层几近透明,方案一是升级到2.0然后加wallColor (表示建筑物墙面的颜色)和roofColor (表示建筑物屋顶的颜色)…

机器学习与模型识别1:SVM(支持向量机)

一、简介 SVM是一种二类分类模型,在特征空间中寻找间隔最大的分离超平面,使得数据得到高效的二分类。 二、SVM损失函数 SVM 的三种损失函数衡量模型的性能。 1. 0-1 损失: 当正例样本落在 y0 下方则损失为 0,否则损失为…

uniapp小程序实现上传图片功能,并显示上传进度

效果图: 实现方法: 一、通过uni.chooseMedia(OBJECT)方法,拍摄或从手机相册中选择图片或视频。 官方文档链接: https://uniapp.dcloud.net.cn/api/media/video.html#choosemedia uni.chooseMedia({count: 9,mediaType: [image,video],so…

如何最大利用WhatsApp高效拓客引流?

对话式商务盛行,WhatsApp是许多国家最受欢迎的聊天应用程序,包括巴西、德国、印尼、泰国、新加坡等。用户渗透率超过80%,作为一个无敌社交APP,自然也是跨境业务的首选。 以下是使用 WhatsApp 进行电子商务时需要记住的一些策略。…

做好需求管理的四个最佳实践

改进您的需求管理过程可以对你的开发过程产生重大影响,所带来的益处包括:提高效率、缩短上市时间,以及节省宝贵的预算和资源。需求是最能向工程师说明要构建什么,以及向测试人员说明要测试什么的信息。 需求具有三个主要功能&…

数据结构——栈(C语言)

需求:无 栈的概念: 栈:一种特殊的线性表,其只允许在固定的一端进行插入和删除元素操作。进行数据插入和删除操作的一端称为栈顶,另一端为栈底。栈中的数据元素遵守后进先出(LIFO)原则。压栈&…

FPGA应用学习笔记-----布线布局优化

优化约束: 设置到最坏情况下会过多 布局和布线之间的关系: 最重要的是与处理器努力的,挂钩允许设计者调整处理器努力的程度 逻辑复制: 不能放置多个负载,只使用在关键路径钟 减少布线延时,但会增加面积&a…

大规模SFT微调指令数据的生成

前言 想要微调一个大模型,前提是得有一份高质量的SFT数据,可以这么说其多么高质量都不过分,关于其重要性已经有很多工作得以验证,感兴趣的小伙伴可以穿梭笔者之前的一篇文章: 《大模型时代下数据的重要性》&#xff…

【AI】百度AI助力开发,测试一下百度搜索的AI能力如何

百度搜索页面有个AI对话,点击进去看看: 是不是文心一言?它说不是。 测试一下辅助写代码功能: 1、写个爬虫: 代码: import requests from bs4 import BeautifulSoup# 目标网站的URL url "http:/…

Ubuntu虚拟机网络无法连接的几种解决方法

虚拟机网络无法连接的几种解决方法 问题状况描述可能的解决方案 问题状况描述 Ubuntu虚拟机没有网络,无法ping通互联网,左上角网络连接图标消失等情况可能的解决方案 1.重启虚拟机网络编辑器 2.重启虚拟机网络适配器 3.重启虚拟机网络服务器1.重启网络…

优思学院|五大工具:APQP、FMEA、MSA、SPC、PPAP

在现代制造业中,质量是企业成功的关键之一。为了确保产品和过程的质量,需要采用一系列有效的工具和方法。APQP、FMEA、MSA、SPC和PPAP被认定为质量管理体系的五大核心工具,这些工具不仅在汽车行业中得到广泛应用,还被其他制造领域…

生信豆芽菜-分组比较的表格

网址:http://www.sxdyc.com/visualsCliTableCompare 1、数据准备 两列的数据,最后比较这两组的样本分布 2、选择两个分组的颜色,有几个就选几个颜色,表头颜色,图片的宽度和高度,提交等待运行成功 3、结…

PHP入门基础教程 - 专栏导读

🏆作者简介,黑夜开发者,全栈领域新星创作者✌,CSDN博客专家,阿里云社区专家博主,2023年6月CSDN上海赛道top4。 🏆数年电商行业从业经验,历任核心研发工程师,项目技术负责…

148. 排序链表

题目描述 给你链表的头结点 head ,请将其按 升序 排列并返回 排序后的链表 。 示例 1: 输入:head [4,2,1,3] 输出:[1,2,3,4]示例 2: 输入:head [-1,5,3,4,0] 输出:[-1,0,3,4,5]示例 3&#…

【仿写tomcat】四、解析http请求信息,响应给前端,HttpServletRequest、HttpServletResponse的简单实现

思考 在解析请求之前我们要思考一个问题,我们解析的是其中的哪些内容? 对于最基本的实现,当然是请求类型,请求的url以及请求参数,我们可以根据请求的类型作出对应的处理,通过url在我们的mapstore中找到se…

计算机控制技术|17/8|11:32

目录 1. 学习计算控制系统需要的相关知识有哪些? 2. 计算机控制系统是什么? 3. 计算机控制系统的主要研究内容是什么? 4. 计算机控制系统的主要特点是什么? 5. 计算机控制系统的性能指标主要有哪些? 6. 计算机控…

学生宿舍管理系统(前端java+后端Vue)源码

完整资料下载链接 界面介绍 登录 宿舍管理 菜单管理 角色管理 ###班级管理

Nginx常见的三个漏洞

目录 $uri导致的CRLF注入漏洞 两种常见场景 表示uri的三个变量 案例 目录穿越漏洞 案例 Http Header被覆盖的问题 案例 $uri导致的CRLF注入漏洞 两种常见场景 用户访问http://example.com/aabbcc,自动跳转到https://example.com/aabbcc 用户访问http://exa…

Java中的枚举类型

一,什么是枚举 在Java中,枚举(Enumeration)是一种特殊的数据类型,它允许我们定义一个固定数量的常量集合。枚举类型在Java中是通过关键字enum来定义的。每个枚举常量都是枚举类型的实例,它们在枚举类型中以…

【RP2040】香瓜树莓派RP2040之自定义的短按、双击、长按按键

本文最后修改时间:2022年09月15日 11:02 一、本节简介 本节介绍如何编写一个可以自己选择引脚的短按、双击、长按三种方式的按键驱动。 二、实验平台 1、硬件平台 1)树莓派pico开发板 ①树莓派pico开发板*2 ②micro usb数据线*2 2)电脑…