[oneAPI] 图像分类CIFAR-10

news2024/9/23 23:23:09

[oneAPI] 图像分类CIFAR-10

  • 图像分类
    • 参数与包
    • 加载数据
    • 模型
    • 训练过程
    • 结果
  • oneAPI

比赛:https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517
Intel® DevCloud for oneAPI:https://devcloud.intel.com/oneapi/get_started/aiAnalyticsToolkitSamples/

图像分类

使用了pytorch以及Intel® Optimization for PyTorch,通过优化扩展了 PyTorch,使英特尔硬件的性能进一步提升,让手写数字识别问题更加的快速高效
在这里插入图片描述

使用CIFAR-10数据集, 数据集是一个常用的计算机视觉数据集,包含了 60000 张 32x32 像素的彩色图像,涵盖了 10 个不同的类别,每个类别有 6000 张图像。这个数据集被广泛用于图像分类、物体识别等任务的训练和评估。

数据集被分成了训练集和测试集,其中训练集包含 50000 张图像,测试集包含 10000 张图像。
CIFAR-10 数据集包含以下 10 个类别:

飞机(airplane)
汽车(automobile)
鸟类(bird)
猫(cat)
鹿(deer)
狗(dog)
青蛙(frog)
马(horse)
船(ship)
卡车(truck)

参数与包

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

import intel_extension_for_pytorch as ipex

# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
num_epochs = 80
batch_size = 100
learning_rate = 0.001

加载数据

# Image preprocessing modules
transform = transforms.Compose([
    transforms.Pad(4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor()])

# CIFAR-10 dataset
train_dataset = torchvision.datasets.CIFAR10(root='./data/',
                                             train=True,
                                             transform=transform,
                                             download=True)

test_dataset = torchvision.datasets.CIFAR10(root='./data/',
                                            train=False,
                                            transform=transforms.ToTensor())

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

模型

# 3x3 convolution
def conv3x3(in_channels, out_channels, stride=1):
    return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                     stride=stride, padding=1, bias=False)


# Residual block
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


# ResNet
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 16
        self.conv = conv3x3(3, 16)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self.make_layer(block, 16, layers[0])
        self.layer2 = self.make_layer(block, 32, layers[1], 2)
        self.layer3 = self.make_layer(block, 64, layers[2], 2)
        self.avg_pool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64, num_classes)

    def make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if (stride != 1) or (self.in_channels != out_channels):
            downsample = nn.Sequential(
                conv3x3(self.in_channels, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels))
        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels
        for i in range(1, blocks):
            layers.append(block(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

训练过程

model = ResNet(ResidualBlock, [2, 2, 2]).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)


# For updating learning rate
def update_lr(optimizer, lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


# Train the model
total_step = len(train_loader)
curr_lr = learning_rate
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print("Epoch [{}/{}], Step [{}/{}] Loss: {:.4f}"
                  .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))

    # Decay learning rate
    if (epoch + 1) % 20 == 0:
        curr_lr /= 3
        update_lr(optimizer, curr_lr)

# Test the model
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))

# Save the model checkpoint
torch.save(model.state_dict(), 'resnet.ckpt')

结果

在这里插入图片描述

在这里插入图片描述

oneAPI

import intel_extension_for_pytorch as ipex

# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')

# 模型
model = ConvNet(num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/890647.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JVM——JVM 垃圾回收

文章目录 写在前面本节常见面试题本文导火索 1 揭开 JVM 内存分配与回收的神秘面纱1.1 对象优先在 eden 区分配1.2 大对象直接进入老年代1.3 长期存活的对象将进入老年代1.4 动态对象年龄判定1.5 主要进行 gc 的区域 2 对象已经死亡?2.1 引用计数法2.2 可达性分析算…

【C++】set/multiset容器

1.set基本概念 #include <iostream> using namespace std;//set容器构造和赋值 #include<set>//遍历 void printSet(const set<int>& st) {for (set<int>::const_iterator it st.begin(); it ! st.end(); it){cout << *it << " …

网络编程学习

网络编程 软件结构 C/S结构&#xff1a;QQ、迅雷、百度网盘 程序员&#xff1a;开发客户端和服务端程序用户&#xff1a;需要下载升级更新客户端对网络带宽要求相对较低数据安全性相对较高 B/S结构&#xff1a;IE、谷歌、火狐 程序员&#xff1a;只需要开发服务端程序用户&am…

泛微E-Office任意文件上传漏洞复现

0x01 产品简介 泛微E-Office是一款标准化的协同 OA 办公软件&#xff0c;泛微协同办公产品系列成员之一,实行通用化产品设计&#xff0c;充分贴合企业管理需求&#xff0c;本着简洁易用、高效智能的原则&#xff0c;为企业快速打造移动化、无纸化、数字化的办公平台。 0x02 漏…

【云原生,k8s】基于Helm管理Kubernetes应用

第四阶段 时 间&#xff1a;2023年8月18日 参加人&#xff1a;全班人员 内 容&#xff1a; 基于Helm管理Kubernetes应用 目录 一、Kubernetes部署方式 &#xff08;一&#xff09;minikube &#xff08;二&#xff09;二进制包 &#xff08;三&#xff09;Kubeadm …

米尔瑞萨RZ/G2L开发板-02 ffmpeg的使用和RTMP直播

最近不知道是不是熬夜太多&#xff0c;然后记忆力减退了&#xff1f; 因为板子回来以后我就迫不及待的试了一下板子&#xff0c;然后发现板子有SSH&#xff0c;但是并没有ffmpeg&#xff0c;最近总是在玩&#xff0c;然后今天说是把板子还原一下哇&#xff0c;然后把官方的固件…

Linux系统之wget命令的基本使用

Linux系统之wget命令的基本使用 一、wget命令介绍二、本次实践环境三、wget命令的使用帮助3.1 wget命令的基本语法3.2 wget选项解释 四、安装wget工具4.1 检查yum仓库状态4.2 安装wget工具 五、wget命令的基本使用5.1 直接下载文件5.2 下载时指定文件名5.3 后台下载文件5.4 限速…

视频汇聚/视频云存储/视频监控管理平台EasyCVR如何进行CDN转推?

视频汇聚/视频云存储/集中存储/视频监控管理平台EasyCVR能在复杂的网络环境中&#xff0c;将分散的各类视频资源进行统一汇聚、整合、集中管理&#xff0c;实现视频资源的鉴权管理、按需调阅、全网分发、云存储、智能分析等&#xff0c;视频智能分析平台EasyCVR融合性强、开放度…

postgresql中基础sql查询

postgresql中基础sql查询 创建表插入数据创建索引删除表postgresql命令速查简单查询计算查询结果 利用查询条件过滤数据模糊查询 创建表 -- 部门信息表 CREATE TABLE departments( department_id INTEGER NOT NULL -- 部门编号&#xff0c;主键, department_name CHARACTE…

【【verilog典型电路设计之FIR滤波器的设计】】

verilog典型电路设计之FIR滤波器的设计 我们常用的FIR滤波器称为有限冲激响应 是一种常用的数字滤波器 &#xff0c;采用对已输入样值的加权和来形成它的输出。 对于输入序列X[n] 的FIR滤波器可用下图所示的结构示意图来表示&#xff0c;其中X[n] 是输入数据流。各级的输入连…

Idea中隐藏指定文件或指定类型文件

Setting ->Editor ->Code Style->File Types → Ignored Files and Folders输入要隐藏的文件名&#xff0c;支持*号通配符回车确认添加

微服务中间件--分布式事务

分布式事务 a.理论基础1) CAP定理2) BASE理论 b.Seata1) XA模式1.a) 实现XA模式 2) AT模式3) TCC模式3.a) 代码实现 4) Saga模式5) 四种模式对比6) TC的异地多机房容灾架构 a.理论基础 1) CAP定理 分布式系统有三个指标&#xff1a; Consistency&#xff08;一致性&#xff…

计算机视觉之三维重建(二)(摄像机标定)

标定示意图 标定目标 P ′ M P w K [ R T ] P w P^{}MP_wK[R \space T]P_w P′MPw​K[R T]Pw​ 其中 K K K为内参数&#xff0c; [ R T ] [R \space T] [R T]为外参数。该式子需要使用至少六对内外点对进行求解内外参数&#xff08;11个未知参数&#xff09;。 其中 R 3 3 …

【快应用】快应用广告学习之激励视频广告

【关键词】 快应用、激励视频广告、广告接入 【介绍】 一、关于激励视频广告 定义&#xff1a;用户通过观看完整的视频广告&#xff0c;获得应用内相关的奖励。适用场景&#xff1a;游戏/快游戏的通关、继续机会、道具获取、积分等场景中&#xff0c;阅读、影音等应用的权益体系…

二、SQL,如何实现表的创建和查询

1、新建表格&#xff08;在当前数据库中新建一个表格&#xff09;&#xff1a; &#xff08;1&#xff09;基础语法&#xff1a; create table [表名]( [字段:列标签] [该列数据类型] comment [字段注释], [字段:列标签] [该列数据类型] comment [字段注释], ……&#xff0c…

Spring Clould 服务间通信 - Feign

视频地址&#xff1a;微服务&#xff08;SpringCloudRabbitMQDockerRedis搜索分布式&#xff09; Feign-基于Feign远程调用&#xff08;P30&#xff09; 先来看我们以前利用RestTemplate发起远程调用的代码&#xff1a; 存在下面的问题&#xff1a; 代码可读性差&#xff0c…

字符设备驱动分布注册

驱动文件&#xff1a; 脑图&#xff1a; 现象&#xff1a;

初识Sentinel

目录 1.解决雪崩的方式有4种&#xff1a; 1.1.2超时处理&#xff1a; 1.1.3仓壁模式 1.1.4.断路器 1.1.5.限流 1.1.6.总结 1.2.服务保护技术对比 1.3.Sentinel介绍和安装 1.3.1.初识Sentinel 1.3.2.安装Sentinel 1.4.微服务整合Sentinel 2.流量控制 2.1.簇点链路 …

设计必备:6个免费PNG图片素材网站推荐

本文推荐了6个PNG图片素材网站,种类繁多、质量上乘,完全免费下载。这些网站收录了海量经过抠图处理的PNG图片,设计师无需操心抠图,直接搜索下载需要的图片进行创作,可以极大地提高工作效率,节约时间成本。获取这些高质量的PNG图片素材,你就可以无后顾之忧地进行创作设计了。不用…

Hive的窗口函数与行列转换函数及JSON解析函数

1. 系统内置函数 查看系统内置函数&#xff1a;show functions ; 显示内置函数的用法&#xff1a; desc function lag; – lag为函数名 显示详细的内置函数用法: desc function extended lag; 1.1 行转列 行转列是指多行数据转换为一个列的字段。 Hive行转列用到的函数 con…