深度学习练手小例子——cifar10数据集分类问题

news2025/2/2 18:58:53

CIFAR-10 是一个经典的计算机视觉数据集,广泛用于图像分类任务。它包含 10 个类别的 60,000 张彩色图像,每张图像的大小是 32x32 像素。数据集被分为 50,000 张训练图像和 10,000 张测试图像。每个类别包含 6,000 张图像,具体类别包括:

  • 飞机 (airplane)
  • 汽车 (automobile)
  • 鸟 (bird)
  • 猫 (cat)
  • 鹿 (deer)
  • 狗 (dog)
  • 青蛙 (frog)
  • 马 (horse)
  • 船 (ship)
  • 卡车 (truck)

CIFAR-10 是一个多类分类问题,目标是根据图像内容(例如,物体的形状、颜色等特征)预测图像所属的类别。图像分类模型(如卷积神经网络 CNN)常用于这个任务,通过学习图像的空间特征来做出预测。

来看看实现过程:

import torch
import torchvision.datasets
from torch.utils.data import DataLoader
from torch import nn

train_data = torchvision.datasets.CIFAR10(root="../input/cifar10-python",train=True,transform=torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root="../input/cifar10-python",train=False,transform=torchvision.transforms.ToTensor(),
                                          download=True)
print(f"train length: {len(train_data)}")
print(f"test length: {len(test_data)}")
Files already downloaded and verified
Files already downloaded and verified
train length: 50000
test length: 10000

找到了CIFAR10数据集并且导入进来,用了三个卷积层的网络模型来训练,进行了10轮训练。

train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)

class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3,32,5,1,2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*4*4,64),
            nn.Linear(64,10)
        )
    def forward(self,x):
        x = self.model(x)
        return x
mynet = CNN()
mynet = mynet.cuda()

loss_func = nn.CrossEntropyLoss().cuda()
learning_rate = 0.0001
optimizer = torch.optim.Adam(mynet.parameters(),lr=learning_rate)
total_train = 0
total_test = 0
epoch = 10

for i in range(epoch):
    print(f"----No.{i+1} training...-----")
    mynet.train()
    for data in train_dataloader:
        imgs, targets = data
        imgs = imgs.cuda()
        targets = targets.cuda()
        outputs = mynet(imgs)
        loss = loss_func(outputs,targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_train = total_train + 1
        if total_train % 100 == 0:
            print(f"训练次数:{total_train},loss:{loss.item()}")
    #测试
    mynet.eval()
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            imgs = imgs.cuda()
            targets = targets.cuda()
            outputs = mynet(imgs)
            loss = loss_func(outputs, targets)
            total_test_loss = total_test_loss + loss.item()
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy = total_accuracy + accuracy
    print(f"测试集的loss:{total_test_loss},准确率:{total_accuracy/len(test_data)}")
    torch.save(mynet, f'myCNN_{i+1}p.pth')
    print("模型保存成功")
----No.1 training...-----
训练次数:100,loss:2.0156445503234863
训练次数:200,loss:1.999146580696106
训练次数:300,loss:1.860052466392517
训练次数:400,loss:1.7510318756103516
训练次数:500,loss:1.7712416648864746
训练次数:600,loss:1.6994789838790894
训练次数:700,loss:1.7278780937194824
测试集的loss:257.74497163295746,准确率:0.41990000009536743
模型保存成功
----No.2 training...-----
训练次数:800,loss:1.515326976776123
训练次数:900,loss:1.485555648803711
训练次数:1000,loss:1.6138449907302856
训练次数:1100,loss:1.7650551795959473
训练次数:1200,loss:1.4380264282226562
训练次数:1300,loss:1.3843588829040527
训练次数:1400,loss:1.5849156379699707
训练次数:1500,loss:1.5038520097732544
测试集的loss:236.6359145641327,准确率:0.47110000252723694
模型保存成功
----No.3 training...-----
训练次数:1600,loss:1.4474828243255615
训练次数:1700,loss:1.4474865198135376
训练次数:1800,loss:1.7310973405838013
训练次数:1900,loss:1.5719612836837769
训练次数:2000,loss:1.6212022304534912
训练次数:2100,loss:1.2924069166183472
训练次数:2200,loss:1.256321907043457
训练次数:2300,loss:1.560215711593628
测试集的loss:221.27214550971985,准确率:0.5011000037193298
模型保存成功
----No.4 training...-----
训练次数:2400,loss:1.4557472467422485
训练次数:2500,loss:1.2620049715042114
训练次数:2600,loss:1.4703019857406616
训练次数:2700,loss:1.4131494760513306
训练次数:2800,loss:1.303225040435791
训练次数:2900,loss:1.4961038827896118
训练次数:3000,loss:1.2810102701187134
训练次数:3100,loss:1.337519645690918
测试集的loss:210.63251876831055,准确率:0.5252999663352966
模型保存成功
----No.5 training...-----
训练次数:3200,loss:1.1311390399932861
训练次数:3300,loss:1.2354803085327148
训练次数:3400,loss:1.2415772676467896
训练次数:3500,loss:1.4213279485702515
训练次数:3600,loss:1.4151396751403809
训练次数:3700,loss:1.2579320669174194
训练次数:3800,loss:1.201486349105835
训练次数:3900,loss:1.287066102027893
测试集的loss:202.65885722637177,准确率:0.5475999712944031
模型保存成功
----No.6 training...-----
训练次数:4000,loss:1.2759090662002563
训练次数:4100,loss:1.3534283638000488
训练次数:4200,loss:1.4388338327407837
训练次数:4300,loss:1.1126259565353394
训练次数:4400,loss:1.072700023651123
训练次数:4500,loss:1.2942607402801514
训练次数:4600,loss:1.3078550100326538
测试集的loss:195.93554836511612,准确率:0.5615000128746033
模型保存成功
----No.7 training...-----
训练次数:4700,loss:1.3510404825210571
训练次数:4800,loss:1.3887534141540527
训练次数:4900,loss:1.2628172636032104
训练次数:5000,loss:1.3063734769821167
训练次数:5100,loss:0.9366315007209778
训练次数:5200,loss:1.208983063697815
训练次数:5300,loss:1.0933520793914795
训练次数:5400,loss:1.2654058933258057
测试集的loss:190.015959918499,准确率:0.5735999941825867
模型保存成功
----No.8 training...-----
训练次数:5500,loss:1.1543941497802734
训练次数:5600,loss:1.0732381343841553
训练次数:5700,loss:1.179479718208313
训练次数:5800,loss:1.0669857263565063
训练次数:5900,loss:1.3145105838775635
训练次数:6000,loss:1.4563915729522705
训练次数:6100,loss:1.0026252269744873
训练次数:6200,loss:0.9769096374511719
测试集的loss:184.76930475234985,准确率:0.5831999778747559
模型保存成功
----No.9 training...-----
训练次数:6300,loss:1.2531676292419434
训练次数:6400,loss:1.0582406520843506
训练次数:6500,loss:1.467718482017517
训练次数:6600,loss:0.9885475635528564
训练次数:6700,loss:0.9887412190437317
训练次数:6800,loss:1.1251451969146729
训练次数:6900,loss:1.0831143856048584
训练次数:7000,loss:0.8735517263412476
测试集的loss:180.18007707595825,准确率:0.5949000120162964
模型保存成功
----No.10 training...-----
训练次数:7100,loss:1.1680148839950562
训练次数:7200,loss:0.9758849740028381
训练次数:7300,loss:1.1076891422271729
训练次数:7400,loss:0.8192071914672852
训练次数:7500,loss:1.2766807079315186
训练次数:7600,loss:1.2046217918395996
训练次数:7700,loss:0.8206453323364258
训练次数:7800,loss:1.1484739780426025
测试集的loss:176.2480058670044,准确率:0.6036999821662903
模型保存成功

拿网上下载的几张图片测试一下,注意路径

import torch
import torchvision
from PIL import Image
from torch import nn

# 10分类,分别为airplane'= 0 'automobile'= 1 'bird'= 2'cat'= 3 'deer'=  4 'dog'=  5 'frog'= 6 'horse'= 7 'ship'= 8 'truck'= 9
image_path = "/kaggle/input/testdata/bird.jpg"
image = Image.open(image_path)
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),
                                            torchvision.transforms.ToTensor()])
image = transform(image)
image = torch.reshape(image,(1,3,32,32))

class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3,32,5,1,2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*4*4,64),
            nn.Linear(64,10)
        )
    def forward(self,x):
        x = self.model(x)
        return x

model = torch.load("/kaggle/working/myCNN_10p.pth",map_location=torch.device('cpu'))
model.eval()
with torch.no_grad():
    output = model(image)
print(output.argmax(1))
tensor([2])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2290905.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Git】初识Git Git基本操作详解

文章目录 学习目标Ⅰ. 初始 Git💥注意事项 Ⅱ. Git 安装Linux-centos安装Git Ⅲ. Git基本操作一、创建git本地仓库 -- git init二、配置 Git -- git config三、认识工作区、暂存区、版本库① 工作区② 暂存区③ 版本库④ 三者的关系 四、添加、提交更改、查看提交日…

【JavaEE进阶】应用分层

目录 🎋序言 🍃什么是应用分层 🎍为什么需要应用分层 🍀如何分层(三层架构) 🎄MVC和三层架构的区别和联系 🌳什么是高内聚低耦合 🎋序言 通过上⾯的练习,我们学习了SpringMVC简单功能的开…

【数据结构篇】时间复杂度

一.数据结构前言 1.1 数据结构的概念 数据结构(Data Structure)是计算机存储、组织数据的⽅式,指相互之间存在⼀种或多种特定关系的数 据元素的集合。没有⼀种单⼀的数据结构对所有⽤途都有⽤,所以我们要学各式各样的数据结构, 如&#xff1a…

【数据结构】_链表经典算法OJ(力扣/牛客第二弹)

目录 1. 题目1:返回倒数第k个节点 1.1 题目链接及描述 1.2 解题思路 1.3 程序 2. 题目2:链表的回文结构 2.1 题目链接及描述 2.2 解题思路 2.3 程序 1. 题目1:返回倒数第k个节点 1.1 题目链接及描述 题目链接: 面试题 …

深度学习之“缺失数据处理”

缺失值检测 缺失数据就是我们没有的数据。如果数据集是由向量表示的特征组成,那么缺失值可能表现为某些样本的一个或多个特征因为某些原因而没有测量的值。通常情况下,缺失值由特殊的编码方式。如果正常值都是正数,那么缺失值可能被标记为-1…

MYSQL--一条SQL执行的流程,分析MYSQL的架构

文章目录 第一步建立连接第二部解析 SQL第三步执行 sql预处理优化阶段执行阶段索引下推 执行一条select 语句中间会发生什么? 这个是对 mysql 架构的深入理解。 select * from product where id 1;对于mysql的架构分层: mysql 架构分成了 Server 层和存储引擎层&a…

C++解决输入空格字符串的三种方法

一.gets和fgets char * gets ( char * str ); char * fgets ( char * str, int num, FILE * stream ); 1. gets 是从第⼀个字符开始读取,⼀直读取到 \n 停⽌,但是不会读取 \n ,也就是读取到的内容 中没有包含 \n ,但是会在读取到的内…

多模态论文笔记——NaViT

大家好,这里是好评笔记,公主号:Goodnote,专栏文章私信限时Free。本文详细解读多模态论文NaViT(Native Resolution ViT),将来自不同图像的多个patches打包成一个单一序列——称为Patch n’ Pack—…

云中漫步:精工细作铸就免费公益刷步平台

云中漫步,历经三年深度研发与优化,平台以高稳定性、零成本及公益属性为核心特色,依托前沿技术手段与多重安全防护机制,确保用户步数数据的精准修改与隐私安全。我们致力于提供无缝流畅的用户体验,让每一次步数更新都轻…

neo4j入门

文章目录 neo4j版本说明部署安装Mac部署docker部署 neo4j web工具使用数据结构图数据库VS关系数据库 neo4j neo4j官网Neo4j是用ava实现的开源NoSQL图数据库。Neo4作为图数据库中的代表产品,已经在众多的行业项目中进行了应用,如:网络管理&am…

【ts + java】古玩系统开发总结

src别名的配置 开发中文件和文件的关系会比较复杂,我们需要给src文件夹一个别名吧 vite.config.js import { defineConfig } from vite import vue from vitejs/plugin-vue import path from path// https://vitejs.dev/config/ export default defineConfig({pl…

【Docker】快速部署 Nacos 注册中心

【Docker】快速部署 Nacos 注册中心 引言 Nacos 注册中心是一个用于服务发现和配置管理的开源项目。提供了动态服务发现、服务健康检查、动态配置管理和服务管理等功能,帮助开发者更轻松地构建微服务架构。 仓库地址 https://github.com/alibaba/nacos 步骤 拉取…

SpringCloud篇 微服务架构

1. 工程架构介绍 1.1 两种工程架构模型的特征 1.1.1 单体架构 上面这张图展示了单体架构(Monolithic Architecture)的基本组成和工作原理。单体架构是一种传统的软件架构模式,其中所有的功能都被打包在一个单一的、紧密耦合的应用程序中。 …

tf.Keras (tf-1.15)使用记录4-model.fit方法及其callbacks参数

model.fit() 方法是 TensorFlow Keras 中用于训练模型的核心方法。 其中里面的callbacks参数是实现模型保存、监控、以及和tensorboard联动的重要API 1 model.fit() 方法的参数及使用 必需参数 x: 训练数据的输入。可以是 NumPy 数组、TensorFlow tf.data.Dataset、Python 生…

Easy系列PLC尺寸测量功能块ST代码(激光微距仪应用)

激光微距仪可以测量短距离内的产品尺寸,产品规格书的测量 精度可以到0.001mm。具体需要看不同的型号。 1、激光微距仪 2、尺寸测量应用 下面我们以测量高度为例子,设计一个高度测量功能块,同时给出测量数据和合格不合格指标。 3、高度测量功能块 4、复位完成信号 5、功能…

996引擎 -地图-添加安全区

996引擎 -地图-添加安全区 文件位置配置 cfg_startpoint.xls特效效果1345参考资料文件位置 文件位置服务端D:\996M2-lua\MirServer-lua\Mir200客户端D:\996M2-lua\996M2_debug\dev配置 cfg_startpoint.xls 服务端\Mir200\Envir\DATA\cfg_startpoint.xls 填歪了也有可能只画一…

[Collection与数据结构] B树与B+树

🌸个人主页:https://blog.csdn.net/2301_80050796?spm1000.2115.3001.5343 🏵️热门专栏: 🧊 Java基本语法(97平均质量分)https://blog.csdn.net/2301_80050796/category_12615970.html?spm1001.2014.3001.5482 🍕 Collection与…

redex快速体验

第一步: 2.回调函数在每次state发生变化时候自动执行

【VM】VirtualBox安装CentOS8虚拟机

阅读本文前,请先根据 VirtualBox软件安装教程 安装VirtualBox虚拟机软件。 1. 下载centos8系统iso镜像 可以去两个地方下载,推荐跟随本文的操作用阿里云的镜像 centos官网:https://www.centos.org/download/阿里云镜像:http://…

电子电气架构 --- 汽车电子拓扑架构的演进过程

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 简单,单纯,喜欢独处,独来独往,不易合同频过着接地气的生活…