【Pytorch项目实战】之图像分类与识别:手写数字识别(MNIST)、普适物体识别(CIFAR-10)

news2024/9/17 9:22:33

文章目录

  • 图像分类与识别
    • (一)实战:基于CNN的手写数字识别(数据集:MNIST)
    • (二)实战:基于CNN的图像分类(数据集:CIFAR-10)

图像分类与识别

分类、识别、检测的区别

  • 分类:对图像中特定的对象进行分类(不同类别)。
    • 如:CIFAR分类。
  • 识别:对图像中特定的对象进行识别(同一类别)。
    • 如:人脸识别、虹膜识别、指纹识别。
  • 检测:识别对象在图像中的位置。
    • 如:人脸检测、行人检测、车辆检测、交通标志检测、视频监控等。

(一)实战:基于CNN的手写数字识别(数据集:MNIST)

在这里插入图片描述

############################################
# 主要步骤:
#       (1)利用Pytorch内置函数mnist下载数据。
#       (2)利用torchvision对数据进行预处理,调用torch.utils建立一个数据迭代器。
#       (3)可视化原数据
#       (4)利用nn工具箱构建神经网络模型
#       (5)实例化模型,并定义损失函数及优化器。
#       (6)训练模型
#       (7)可视化结果
############################################
# (1)MNIST数据集是机器学习领域中非常经典的一个数据集, 共4个文件,训练集、训练集标签、测试集、测试集标签。由60000个训练样本和10000个测试样本组成,每个样本都是一张28 * 28像素的灰度手写数字图片。
# (2)直接下载下来的数据是无法通过解压或者应用程序打开的,因为这些文件不是任何标准的图像格式而是以字节的形式进行存储的,所以必须编写程序来打开它。
############################################
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn

from torchvision.datasets import mnist          # 导入内置的mnist数据
import torchvision.transforms as transforms     # 导入图像预处理模块
from torch.utils.data import DataLoader

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'     # "OMP: Error #15: Initializing libiomp5md.dll"
############################################
# (1)定义超参数
train_batch_size = 64
test_batch_size = 128
learning_rate = 0.01
num_epoches = 20
lr = 0.01
momentum = 0.5
############################################
# (2)下载数据,并进行数据预处理
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
#       11、transforms.Compose()方法是将多种变换组合在一起。Compose()会将transforms列表里面的transform操作进行遍历。
#       22、torchvision.transforms.Normalize(mean, std):用给定的均值和标准差分别对每个通道的数据进行正则化。
#           单通道=[0.5], [0.5]     ————     三通道=[m1,m2,m3], [n1,n2,n3]

train_dataset = mnist.MNIST('./pytorch_knowledge', train=True, transform=transform, download=True)
test_dataset = mnist.MNIST('./pytorch_knowledge', train=False, transform=transform)
# download参数控制是否需要下载。如果目录下已有MNIST,可选择False。

train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)
############################################
# (3)可视化源数据
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)

fig = plt.figure()
for i in range(6):
    plt.subplot(2, 3, i+1)
    plt.tight_layout()
    plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
    plt.title('Ground Truth:{}' .format((example_targets[i])))
    plt.xticks(([]))
    plt.yticks(([]))
plt.show()
############################################


# (4)构建网络模型
class Net(nn.Module):
    # 使用Sequential构建网络,将网络的层组合到一起
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super(Net, self).__init__()
        self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1), nn.BatchNorm1d(n_hidden_1))
        self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2), nn.BatchNorm1d(n_hidden_2))
        self.layer3 = nn.Sequential(nn.Linear(n_hidden_2, out_dim))

    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = self.layer3(x)
        return x


if __name__ == '__main__':
    ############################################
    # (5)检测是否有可用的GPU,有则使用,否则使用GPU
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # 实例化网络
    model = Net(28*28, 300, 100, 10)
    model.to(device)
    # 定义损失函数和优化器
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()
    ############################################
    # (6)训练模型
    losses = []
    acces = []
    eval_losses = []
    eval_acces = []
    for epoch in range(num_epoches):
        # 动态修改参数学习率
        if epoch % 5 == 0:
            optimizer.param_groups[0]['lr'] *= 0.1

        # 训练集 #######################################
        train_loss = 0
        train_acc = 0
        # 将模型切换为训练模式
        model.train()
        for img, label in train_loader:
            img = img.to(device)
            label = label.to(device)
            img = img.view(img.size(0), -1)

            out = model(img)                    # 前向传播
            loss = criterion(out, label)        # 损失函数
            optimizer.zero_grad()               # 梯度清零
            loss.backward()                     # 反向传播
            optimizer.step()                    # 参数更新

            # 记录误差
            train_loss += loss.item()
            # 记录分类的准确率
            _, pred = out.max(1)        # 提取分类精度最高的结果
            num_correct = (pred == label).sum().item()      # 汇总准确度
            acc = num_correct / img.shape[0]
            train_acc += acc
        train_loss_temp = train_loss / len(train_loader)        # 记录单次训练损失
        train_acc_temp = train_acc / len(train_loader)          # 记录单次训练准确度
        losses.append(train_loss / len(train_loader))
        acces.append(train_acc / len(train_loader))

        # 测试集 #######################################
        eval_loss = 0
        eval_acc = 0
        # 将模型切换为测试模式
        model.eval()
        for img, label in test_loader:
            img = img.to(device)
            label = label.to(device)
            img = img.view(img.size(0), -1)

            out = model(img)                    # 前向传播
            loss = criterion(out, label)        # 损失函数

            # 记录误差
            eval_loss += loss.item()
            # 记录分类的准确率
            _, pred = out.max(1)        # 提取分类精度最高的结果
            num_correct = (pred == label).sum().item()
            acc = num_correct / img.shape[0]
            eval_acc += acc
        eval_loss_temp = train_loss / len(train_loader)         # 记录单次测试损失
        eval_acc_temp = train_acc / len(train_loader)           # 记录单次测试准确度
        eval_losses.append(eval_loss / len(test_loader))
        eval_acces.append(eval_acc / len(test_loader))
        print('epoch:{}, Train_loss:{:.4f}, Train_Acc:{:.4f}, Test_loss:{:.4f}, Test_Acc:{:4f}'
              .format(epoch, train_loss_temp, train_acc_temp, eval_loss_temp, eval_acc_temp))

    # (7)可视化结果
    plt.title('Train Loss')
    plt.plot(np.arange(len(losses)), losses)
    plt.legend(['train loss'], loc='upper right')
    plt.xlabel('Steps')         # 设置x轴标签
    plt.ylabel('Loss')          # 设置y轴标签
    plt.ylim((0, 1.2))          # 设置y轴的数值显示范围:plt.ylim(y_min, y_max)
    plt.show()
# 备注1:model.eval()的作用是不启用 Batch Normalization 和 Dropout。
# 备注2:model.train()的作用是启用 Batch Normalization 和 Dropout。

(二)实战:基于CNN的图像分类(数据集:CIFAR-10)

在这里插入图片描述
Dataset之CIFAR-10:CIFAR-10数据集的简介、下载、使用方法之详细攻略

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'     # "OMP: Error #15: Initializing libiomp5md.dll"
###################################################################


def imshow(img):
    """显示图像"""
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


class CNNNet(nn.Module):
    """模型定义"""
    def __init__(self):
        super(CNNNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=36, kernel_size=3, stride=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(1296, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(-1, 36*6*6)
        x = F.relu(self.fc2(F.relu(self.fc1(x))))
        return x


if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")     # 检测是否有可用的GPU,有则使用,否则使用CPU。
    net = CNNNet()              # 模型实例化
    net = net.to(device)        # 将构建的张量或者模型分配到相应的设备上。
    ###################################################################
    # (1)下载数据、数据预处理、迭代器
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True, num_workers=2)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=16, shuffle=False, num_workers=2)
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    ###################################################################
    # (2)随机获取部分训练数据
    dataiter = iter(trainloader)
    images, labels = dataiter.next()
    imshow(torchvision.utils.make_grid(images))  # 显示图像
    print(' '.join('%5s' % classes[labels[j]] for j in range(16)))  # 打印标签
    print("net have {} paramerters in total".format(sum(x.numel() for x in net.parameters())))
    ###################################################################
    # (3)模型训练
    criterion = nn.CrossEntropyLoss()           # 交叉熵损失函数
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)     # 优化器
    nn.Sequential(*list(net.children())[:4])    # 取模型中的前四层
    for epoch in range(10):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # 获取训练数据
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()               # 权重参数梯度清零
            outputs = net(inputs)               # 前向传播
            loss = criterion(outputs, labels)   # 损失函数
            loss.backward()                     # 后向传播
            optimizer.step()                    # 梯度更新
            # 显示损失值
            running_loss += loss.item()
            if i % 2000 == 1999:    # print every 2000 mini-batches.    共打印:batches * epoch
                print('[%d, %5d] loss: %.3f' % (epoch+1, i+1, running_loss/2000))
                running_loss = 0.0
    print('Finished Training')
    ###################################################################
    # (4)模型验证
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)       # 前向传播
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/183091.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Lua 函数 - 可变参数

Lua 函数 - 可变参数 参考至菜鸟教程。 Lua函数可以接收可变数目的参数,和C语言类似,在函数参数列表中使用三点...表示函数有可变的参数。 function add(...) local s 0 for i, v in ipairs{...} do --> {...} 表示一个由所有变长参数构成的数…

模拟实现C库函数(2)

"烦恼无影踪,丢宇宙~"上一篇的模拟实现了好几个库函数,strlen\strcpy\memcpy\memmove,那么这一篇又会增加几个常用C库函数的模拟实现 memset\itoa\atoi。一、memsetmemset - fill memory with a constant byte#include <string.h>void *memset(void *s, int c,…

机器自动翻译古文拼音 - 十大宋词 - 江城子·乙卯正月二十日夜记梦 苏轼

江城子乙卯正月二十日夜记梦 宋苏轼 十年生死两茫茫&#xff0c;不思量&#xff0c;自难忘。 千里孤坟&#xff0c;无处话凄凉。 纵使相逢应不识&#xff0c;尘满面&#xff0c;鬓如霜。 夜来幽梦忽还乡&#xff0c;小轩窗&#xff0c;正梳妆。 相顾无言&#xff0c;惟有泪千…

uniapp使用及踩坑项目记录

环境准备 下载 HBuilderX 使用命令行创建项目&#xff1a; 一些常识准备 响应式单位rpx 当设计稿宽度为750px的时&#xff0c;1rpx1px。 uniapp中vue文件style不用添加scoped 打包成h5端的时候自动添加上去&#xff0c;打包成 微信小程序端 不需要添加 scoped。 图片的…

SpringDataJpa set()方法自动保存失效

问题描述&#xff1a;springdatajpa支持直接操作对象设置属性进行更新数据库记录的方式&#xff0c;正常情况下&#xff0c;get()得到的对象直接进行set后&#xff0c;即使不进行save操作&#xff0c;也将自动更新数据记录&#xff0c;将改动持久化到数据库中&#xff0c;但这里…

20230126使AIO-3568J开发板在原厂Android11下跑起来

20230126使AIO-3568J开发板在原厂Android11下跑起来 2023/1/26 18:22 1、前提 2、修改dts设备树 3、适配板子的dts 4、&#xff08;修改uboot&#xff09;编译系统烧入固件验证 前提 因源码是直接使用原厂的SDK&#xff0c;没有使用firefly配套的SDK源码&#xff0c;所以手上这…

Linux安装mongodb企业版集群(分片集群)

目录 一、mongodb分片集群三种角色 二、安装 1、准备工作 2、安装 configsvr配置 router配置 shard配置 三、测试 四、整合Springboot 一、mongodb分片集群三种角色 router角色&#xff1a; mongodb的路由&#xff0c;提供入口&#xff0c;使得分片集群对外透明&…

【目标检测论文解读复现NO.27】基于改进YOLOv5的螺纹钢表面缺陷检测

前言此前出了目标改进算法专栏&#xff0c;但是对于应用于什么场景&#xff0c;需要什么改进方法对应与自己的应用场景有效果&#xff0c;并且多少改进点能发什么水平的文章&#xff0c;为解决大家的困惑&#xff0c;此系列文章旨在给大家解读最新目标检测算法论文&#xff0c;…

【工程化之路】Node require 正解

require 实现原理 流程概述 步骤1&#xff1a;尝试执行代码require("./1"). 开始调用方法require.步骤2&#xff1a;此时会得到filename&#xff0c;根据filename 会判断缓存中是否已经加载模块&#xff0c;如果加载完毕直接返回&#xff0c;反之继续执行步骤3&…

python图像处理(laplacian算子)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing @163.com】 和之前的prewitt算子、sobel算子不同,laplacian算子更适合检测一些孤立点、短线段的边缘。因此,它对噪声比较敏感,输入的图像一定要做好噪声的处理工作。同时,laplacian算子设计…

Leetcode 03. 无重复字符的最长子串 [C语言]

目录题目思路1代码1结果1思路2代码2结果2该文章只是用于记录考研复试刷题题目 Leetcode 03: 给定一个字符串 s &#xff0c;请你找出其中不含有重复字符的 最长子串 的长度。 示例 1: 输入: s “abcabcbb” 输出: 3 解释: 因为无重复字符的最长子串是 “abc”&#xff0c;所…

尚医通-OAuth2-微信登录接口开发(三十一)

目录&#xff1a; &#xff08;1&#xff09;微信登录-OAuth2介绍 &#xff08;2&#xff09;前台用户系统-微信登录-准备工作 &#xff08;3&#xff09;微信登录-生成微信二维码-接口开发 &#xff08;4&#xff09;微信登录-生成验证码-前端整合 &#xff08;5&#xf…

Telerik DevCraft Ultimate R1 2023

Telerik DevCraft Ultimate R1 2023 Kendo UI R1 2023-添加新的Chip和ChipList组件。 KendoReact R1 2023&#xff08;v5.11.0&#xff09;-新的PDFViewer组件允许用户直接在应用程序中查看PDF文档。 Telerik JustLock R1 2023-Visual Studio快速操作菜单现在可以在创建通用…

蓝桥杯重点(C/C++)(随时更新,更新时间:2023.1.29)

点关注不迷路&#xff0c;欢迎推荐给更多人 目录 1 技巧 1.1 取消同步&#xff08;节约时间&#xff0c;甚至能多骗点分&#xff0c;最好每个程序都写上&#xff09; 1.2 万能库&#xff08;可能会耽误编译时间&#xff0c;但是省脑子&#xff09; 1.3 蓝桥杯return 0…

【数据库-通用知识系列-01】数据库规范化设计之范式,让数据库表看起来更专业

我们在设计数据库时考虑的因素包括读取性能&#xff0c;数据一致性&#xff0c;数据冗余度&#xff0c;可扩展性等&#xff0c;好好学习数据库规范化的知识&#xff0c;设计的数据库表看起来才专业。 范式一览 “键”理解&#xff1a; 超键&#xff1a;在关系中能唯一标识元组…

送什么礼物给小学生比较有纪念意义?适合送小学生的小礼物

送给小学生的礼物哪种比较有意义呢&#xff1f;送给学生的礼物&#xff0c;基本上是对学习有所帮助的&#xff0c;但是像送钢笔、练习册这些&#xff0c;有一部分学生是抗拒的&#xff0c;作为大人就是希望对视力、对成长有用的东西&#xff0c;我认为保护视力是现在许多家庭的…

isNotEmpty() 和 isNotBlank() 的区别,字符串判空, StringUtils工具包 StringUtil工具类,isEmpty() 和 isBlank() 的区别

目录1.StringUtils 和 StringUtilStringUtils 的依赖&#xff1a;StringUtils 的用法&#xff1a;StringUtil 工具类2. isNotEmpty() 和 isNotBlank()1.StringUtils 和 StringUtil 注&#xff1a;StringUtils 和 StringUtil 的区别&#xff08;StringUtil为自定义工具类&#…

以表达式作为template参数

目录 一.template参数的分类&#xff1a; 二.非类型参数与默认参数值一起使用 三.应用 一.template参数的分类&#xff1a; ①.某种类型&#xff1a; template<typename T>; ②.表达式(非类型)&#xff1a; template<int length,int position>; 其中length…

Liunx中shell命令行和权限的理解

文章目录前言1.shell外壳的理解2.关于权限理解1.Linux下的用户2.角色划分3.文件和目录的权限3.粘滞位3.总结前言 Linux中的操作都是通过在命令行上敲指令来实现的&#xff0c;本文将简单的介绍Linux中的外壳程序shell以及浅谈一下对Linux中的权限理解。 1.shell外壳的理解 Lin…

微信小程序开发(一)

1. 微信小程序的开发流程 2. 注册小程序 小程序注册页&#xff1a;https://mp.weixin.qq.com/wxopen/waregister?actionstep1 如已注册&#xff0c;直接登录 小程序后台 https://mp.weixin.qq.com/ 即可。 在小程序后台的 【开发管理】→ 【开发设置】下可以查看AppID&…