PyTorch深度学习实战——交通标志识别

news2024/12/23 10:02:06

PyTorch深度学习实战——交通标记识别

    • 0. 前言
    • 1. 交通标志识别
      • 1.1 数据集介绍
      • 1.2 数据增强和批归一化
    • 3. 交通标志检测
    • 相关链接

0. 前言

在道路交通场景中,交通标志识别作为驾驶辅助系统与无人驾驶车辆中不可缺少的技术,为车辆行驶中提供了安全保障。在道路上行驶的车辆,道路周围的环境包括许多重要的交通标志信息,根据交通标志信息在道路上做出正确的驾驶行为,通常能够避免发生交通事故。交通标志识别可以检测并识别当前行驶道路上的交通标志,然后得出有关道路的必要信息。
但交通标志会受到车辆的运动状态、光照以及遮挡等环境因素的影响,因此如何使车辆在道路交通中快速准确地帮助驾驶员识别交通标志已经成为智能交通领域的热点问题之一。鉴于交通标志识别在自动驾驶等应用中具有重要作用,在节中,我们将学习使用卷积神经网络实现交通标志识别。

1. 交通标志识别

1.1 数据集介绍

德国交通标志识别基准 (German Traffic Sign Recognition Benchmark, GTSRB) 是高级驾驶辅助系统和自动驾驶领域的交通标志图像分类基准。其中共包含 43 种不同类别的交通标志。可以在官方网页中下载相关数据集。
每张图片包含一个交通标志,图像包含实际交通标志周围的环境像素,大约为交通标志尺寸的 10% (至少为 5 个像素),图像以 PPM 格式存储,图像尺寸在 15x15250x250 像素之间。

1.2 数据增强和批归一化

在介绍神经网络时,我们已经了解了利用数据增强可以提高模型准确性。在现实世界中,我们会遇到具有不同特征的图像,例如,某些图像可能更亮,某些图像中的感兴趣对象可能在图像边缘附近,而某些图像可能较为模糊。在本节中,我们将介绍如何使用数据增强和批归一化提高模型的准确率。
为了了解数据增强和批归一化对模型性能的影响,我们将使用交通标志数据集训练交通标志识别模型,并评估以下三种情况:

  • 不使用批归一化和数据增强
  • 只使用批归一化,但不使用数据增强
  • 同时使用批归一化和数据增强

除了以上不同外,数据集及其预处理方法完全相同。

3. 交通标志检测

首先考虑不使用批归一化和数据增强的情况,使用 PyTorch 实现交通标志识别。

(1) 下载数据集并导入相关库:

import os
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
from glob import glob
from random import randint
import cv2
from pathlib import Path
import torch.nn as nn
from torch import optim
from matplotlib import pyplot as plt
import pandas as pd

device = 'cuda' if torch.cuda.is_available() else 'cpu'

(2) 指定与输出类别对应的索引:

from torchvision import transforms as T
classIds = pd.read_csv('signnames.csv')
classIds.set_index('ClassId', inplace=True)
classIds = classIds.to_dict()['SignName']
classIds = {f'{k:05d}':v for k,v in classIds.items()}
id2int = {v:ix for ix,(k,v) in enumerate(classIds.items())}

(3) 定义图像转换管道,执行图像转换操作(不使用数据增强):

trn_tfms = T.Compose([
    T.ToPILImage(),
    T.Resize(32),
    T.CenterCrop(32),
    # T.ColorJitter(brightness=(0.8,1.2), 
    # contrast=(0.8,1.2), 
    # saturation=(0.8,1.2), 
    # hue=0.25),
    # T.RandomAffine(5, translate=(0.01,0.1)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], 
                std=[0.229, 0.224, 0.225]),
])

val_tfms = T.Compose([
    T.ToPILImage(),
    T.Resize(32),
    T.CenterCrop(32),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], 
                std=[0.229, 0.224, 0.225]),
])

在以上代码中,对输入图像进行了一系列转换——首先将图像尺寸调整为 128 (最小边为 128),然后从图像中心进行裁剪。此外,使用 .ToTensor() 方法对图像进行缩放(使像素值位于 01 之间),最后对图像进行归一化处理,以便使用预训练模型。
取消以上代码中的注释并重新运行即可执行数据增强。此外,我们并不会对 val_tfms 执行数据增强,因为在模型训练期间没有使用这些图像。但是,val_tfms 图像需要通过与 trn_tfms 相同的转换管道。

(4) 定义数据集类 GTSRB

class GTSRB(Dataset):
    """Road Sign Detection dataset."""

    def __init__(self, files, transform=None):
        self.files = files
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, ix):
        fpath = self.files[ix]
        clss = os.path.basename(Path(fpath).parent)
        img = cv2.imread(fpath)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return img, classIds[clss]

    def choose(self):
        return self[randint(len(self))]

    def collate_fn(self, batch):
        imgs, classes = list(zip(*batch))
        if self.transform:
            imgs = [self.transform(img)[None] for img in imgs]
        classes = [torch.tensor([id2int[clss]]) for clss in classes]
        imgs, classes = [torch.cat(i).to(device) for i in [imgs, classes]]
        return imgs, classes

(5) 创建训练、验证数据集和数据加载器:

all_files = glob('GTSRB/Final_Training/Images/*/*.ppm')
np.random.shuffle(all_files)

from sklearn.model_selection import train_test_split
trn_files, val_files = train_test_split(all_files, random_state=1)

trn_ds = GTSRB(trn_files, transform=trn_tfms)
val_ds = GTSRB(val_files, transform=val_tfms)
trn_dl = DataLoader(trn_ds, 32, shuffle=True, collate_fn=trn_ds.collate_fn)
val_dl = DataLoader(val_ds, 32, shuffle=False, collate_fn=val_ds.collate_fn)

(6) 定义模型 SignClassifier

def convBlock(ni, no):
    return nn.Sequential(
        nn.Dropout(0.2),
        nn.Conv2d(ni, no, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        #nn.BatchNorm2d(no),
        nn.MaxPool2d(2),
    )
    
class SignClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            convBlock(3, 64),
            convBlock(64, 64),
            convBlock(64, 128),
            convBlock(128, 64),
            nn.Flatten(),
            nn.Linear(256, 256),
            nn.Dropout(0.2),
            nn.ReLU(inplace=True),
            nn.Linear(256, len(id2int))
        )
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def compute_metrics(self, preds, targets):
        ce_loss = self.loss_fn(preds, targets)
        acc = (torch.max(preds, 1)[1] == targets).float().mean()
        return ce_loss, acc

当需要在模型中使用 BatchNormalization (批归一化)时,需要取消注释以上代码中注释行。

(7) 定义使用批数据对模型进行训练和验证的函数:

def train_batch(model, data, optimizer, criterion):
    ims, labels = data
    _preds = model(ims)
    optimizer.zero_grad()
    loss, acc = criterion(_preds, labels)
    loss.backward()
    optimizer.step()
    return loss.item(), acc.item()

@torch.no_grad()
def validate_batch(model, data, criterion):
    ims, labels = data
    _preds = model(ims)
    loss, acc = criterion(_preds, labels)
    return loss.item(), acc.item()

(8) 定义模型并对其进行训练:

model = SignClassifier().to(device)
criterion = model.compute_metrics
optimizer = optim.Adam(model.parameters(), lr=1e-3)
n_epochs = 50

train_loss_epochs_no_aug_no_bn = []
train_acc_epochs_no_aug_no_bn = []
val_loss_epochs_no_aug_no_bn = []
val_acc_epochs_no_aug_no_bn = []
for ex in range(n_epochs):
    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    N = len(trn_dl)
    for bx, data in enumerate(trn_dl):
        loss, acc = train_batch(model, data, optimizer, criterion)
        train_loss.append(loss)
        train_acc.append(acc)

    N = len(val_dl)
    for bx, data in enumerate(val_dl):
        loss, acc = validate_batch(model, data, criterion)
        val_loss.append(loss)
        val_acc.append(acc)
    train_loss_epochs_no_aug_no_bn.append(np.average(train_loss))
    train_acc_epochs_no_aug_no_bn.append(np.average(train_acc))
    val_loss_epochs_no_aug_no_bn.append(np.average(val_loss))
    val_acc_epochs_no_aug_no_bn.append(np.average(val_acc))
    if ex == 10:
        optimizer = optim.Adam(model.parameters(), lr=1e-4)

epochs = np.arange(50)+1
import matplotlib.pyplot as plt
plt.subplot(121)
plt.plot(epochs, train_loss_epochs_no_aug_no_bn, 'bo', label='Training loss')
plt.plot(epochs, val_loss_epochs_no_aug_no_bn, 'r', label='Test loss')
plt.title('Training and Test loss over increasing epochs \n with no batchnormalization and augmentation')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid('off')
plt.subplot(122)
plt.plot(epochs, train_acc_epochs_no_aug_no_bn, 'bo', label='Training accuracy')
plt.plot(epochs, val_acc_epochs_no_aug_no_bn, 'r', label='Test accuracy')
plt.title('Training and Test accuracy over increasing epochs \n with no batchnormalization and augmentation')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.grid('off')
plt.show()

在三种不同实验设定中,模型在训练过程中的训练和验证准确率如下:

模型性能变化

根据以上结果,我们可以看出:

  • 当没有使用批归一化时,模型的准确率较低
  • 只使用批归一化但未使用数据增强时,模型的准确性会大大提高,但模型在训练数据上出现过拟合现象
  • 同时使用批归一化和数据增强的模型具有很高的准确性和较小的过拟合

相关链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1038975.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【操作系统】24王道考研笔记——第五章 IO管理

第五章 IO管理 一、IO设备 1.1 基本概念与分类 1.2 IO控制器 电子部件 IO控制器组成 值得注意的小细节:①一个I/O控制器可能会对应多个设备; ②数据寄存器、控制寄存器、状态寄存器可能有多个(如:每个控制/状态寄存器对应一个…

RK3566 linux添加rgb13h

一、DTS根节点增加节点 在根节点/{}下增加flash_rgb13h节点,节点内容如下: flash_rgb13h: flash-rgb13h {status "okay";compatible "led,rgb13h";label "gpio-flash";pinctrl-names "default";pinctrl-0 …

如何在JoySSL上申请免费的SSL证书

1,前往 JoySSL 的官方网站注册页面,创建一个账号并登录您的 JoySSL 账户。 扫码注册账号申请免费证书https://www.joyssl.com/certificate/select/free.html?nid52,找到并选择你需要的 SSL 证书相关的功能或选项。 3,提供您的域…

三周过PMP经验分享,用最少的时间拿3A!

今天分享一个大神的PMP备考经验,大佬就是大佬,三周过PMP还拿了3A,正在备考PMP的小伙伴,咱们共勉! 第一周、阅读教材 之前是第六版教材,花了很长时间阅读和梳理框架。现在是第七版教材,内容少了…

快速发布服务到生产环境(手动操作)

背景介绍 虽然现在大部分项目都是用Jenkins搭建环境,自动化部署。但仍然存在一些小客户,只会单独上线一些关键服务,此时就需要手动去服务器里面部署了。此处用一个外业服务做例子,下面开始介绍。 进入服务器 一般需要申请服务器权…

使用BaGet 实现NuGet包私有化部署

本文主要介绍使用IIS部署 1.下载Baget,github下载,本文下载版本v0.4.0-preview2 2.解压,参考使用说明 3.安装环境,.NET Core Runtime,此处说明,.net7安装包是集成了 ASP.NET Core IIS Module的&#xff…

数据结构之时间复杂度空间复杂度的计算

数据结构:计算机如何存储数据的问题。DS关心的是如何高效的进行数据的读写。 算法:在特定的数据集上(不关心怎么进行具体数据的读写),如何利用数据完成特定的功能。算法本质上就是一系列运算的先后集合。 那么&#…

电工-国产二极管型号、三极管型号的命名方式

根据半导体器件型号命名方法(GB249-74)规定,国产半导体由5共部分组成,二极管、三极管的型号命名方式也有5个部分,第一部分是标明晶体管数目(二极管或是三极管)。第二部分是三极管的材质标识&…

【ZLM】花屏现象记录

目录 事后小结 现象 tcpdump看下包的情况 移了两个摄像头到10.60.100.196 事后小结 花屏的现象,主要看链路时延的稳定性。 如果 ping -s 2000 ip , > 2ms已经带宽 2000*8*2/0.002s16Mbps,说明带宽不够,应该接近100Mbps左右。你可…

C#中使用Newtonsoft.Charp实现Json对象序列化与反序列化

场景 C#中使用Newtonsoft.Json实现对Json字符串的解析: C#中使用Newtonsoft.Json实现对Json字符串的解析_霸道流氓气质的博客-CSDN博客 上面讲的对JSON字符串进行解析,实际就是JSON对象的反序列化。 在与第三方进行交互时常需要封装对象,…

linux 防火墙iptables

iptables 是 Linux 中比较底层的网络服务,它控制了 Linux 系统中的网络操作,CentOS 中的 firewalld 和 Ubuntu 中的 ufw 都是在 iptables 之上构建的,只为了简化 iptables 的操作。同时,iptables 不仅仅是防火墙这么简单&#xff…

C/C++代码静态检测工具PC-Lint常见错误总结

目录 1、PC-Lint 概述 2、PC-lint 常见错误列举 3、PC-Lint报告的语法错误 4、总结 VC常用功能开发汇总(专栏文章列表,欢迎订阅,持续更新...)https://blog.csdn.net/chenlycly/article/details/124272585C软件异常排查从入门到…

看文章-做笔记

看文章-做笔记 小蓝本

想要防止视频被盗?用它给视频加水印

随着社交网络的普及,越来越多的人喜欢在网上分享自己制作的视频,但是,共享的视频可能会被其他人传播和滥用。因此,保护自己制作的视频非常重要。 那怎样才能够防止别人盗用自己制作的视频呢?一种简单易行的保护方法是…

printContent 点击打印多页时第一页之前出现白页

项目场景: 提示:这里简述项目相关背景: printContent 点击打印多页时第一页之前出现白页 问题描述 提示:这里描述项目中遇到的问题: printContent 点击打印多页时第一页之前出现白页 原因分析: 提示&am…

【中阳期货】分析市场数据和制定交易策略代码

当谈到期货与市场分析时,编写代码来分析市场数据和制定交易策略是一种常见的做法。在这篇文章中,我将向您展示如何使用Python编写代码来获取市场数据、进行基本分析以及制定简单的交易策略。我们将使用一些常见的Python库,如Pandas、Matplotl…

前有CAP理论,后有BASE理论,分布式系统理论基石

🧑‍💻作者名称:DaenCode 🎤作者简介:CSDN实力新星,后端开发两年经验,曾担任甲方技术代表,业余独自创办智源恩创网络科技工作室。会点点Java相关技术栈、帆软报表、低代码平台快速开…

物联网如何改变智能家居技术?

物联网(IoT)已经在智能家居技术方面产生了深远的影响,其通过将各种设备、传感器和家居设备连接到互联网,实现了智能家居技术的创新和改进。 物联网(IoT)已经在智能家居技术方面产生了深远的影响,其通过将各种设备、传…

【手动实现nn.Linear 】

线性变换参数可视化图 class LinearLayer(nn.Module):def __init__(self, input_dim, output_dim):super(LinearLayer, self).__init__()self.weights nn.Parameter(torch.Tensor(output_dim, input_dim))self.bias nn.Parameter(torch.Tensor(output_dim))# 初始化权重和偏…

codesys自由编码器

1用于位置处理。 2它有个变量: SMC_FreeEncoder.diEncoderPosition 【DINT】 SMC_FreeEncoder_1.diEncoderPosition : GVL.电位器1; SMC_FreeEncoder.diEncoderPosition:hsi_cnt.diCurCountValue; //编码器位置 默认一圈是360.00 给它赋值&#x…