基于卷积神经网络的交通标志识别(pytorch,opencv,yolov5)

news2024/11/20 0:31:56

文章目录

  • 数据集介绍:
  • resnet18模型代码
  • 加载数据集(Dataset与Dataloader)
  • 模型训练
  • 训练准确率及损失函数:
  • resnet18交通标志分类源码
  • yolov5检测与识别(交通标志)

本文共包含两部分,
第一部分是用resnet18对交通标志分类,仅仅只是交通标志分类
文末附有yolov5和resnet18结合的源码,yolov5复制检测交通标志位置,然后使用resnet18对交通标志进行分类。

数据集介绍:

本文使用的数据集共有6000多张,共包含58个类别。部分数据集如下:

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

resnet18模型代码

使用pytorch自带的resnet18模型,代码如下:

from torchvision import models
import torch.nn as nn

#加载resnet18模型
net=models.resnet18(weights=None)
#因为分类个数为58,所以需要修改模型最后一层全连接层
net.fc=nn.Linear(in_features=512, out_features=58, bias=True)
# print(net)

加载数据集(Dataset与Dataloader)

from torch.utils.data import Dataset,DataLoader
import numpy as np
import cv2
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from PIL import Image
import os
from torchvision import transforms
import torch
import random

a=[]
class Mydata(Dataset):
    def __init__(self,lines,train=True):
        super(Mydata, self).__init__()
        self.lines=lines
        random.shuffle(self.lines)
        self.train=train


    def __len__(self):
        return len(self.lines)
    def __getitem__(self, index):

        txts=self.lines[index].strip().split(';')
        src_path='pic/'+txts[0]
        w=int(txts[1])
        h=int(txts[2])
        x1=int(txts[3])
        y1=int(txts[4])
        x2=int(txts[5])
        y2=int(txts[6])

        new_x1=random.randint(0,x1)
        new_y1=random.randint(0,y1)
        new_x2=random.randint(x2,w-1)
        new_y2=random.randint(y2,h-1)

        lab=int(txts[7])
        # if lab in a:
        #     pass
        # else:a.append(lab)
        #
        # a.sort()
        # print(len(a))
        # print(a)
        img = Image.open(src_path)
        img=np.array(img)[...,:3]
        img=img[new_y1:new_y2,new_x1:new_x2]


        #数据增强
        if self.train:
            img=self.get_random_data(img)
        else:
            img = cv2.resize(img, (128, 128))
        # cv2.imshow('img',img[...,::-1])
        # cv2.waitKey(0)

        #归一化
        img=(img/255.0).astype('float32')
        img=np.transpose(img,(2,0,1))

        img=torch.from_numpy(img)
        return img,lab
    def get_random_data(self,img):
        seq = iaa.Sequential([
            # iaa.Flipud(0.5),  # flip up and down (vertical)
            # iaa.Fliplr(0.5),  # flip left and right (horizontal)
            iaa.Multiply((0.8, 1.2)),  # change brightness, doesn't affect BBs(bounding boxes)
            iaa.GaussianBlur(sigma=(0, 1.0)),  # 标准差为0到3之间的值

            iaa.Crop(percent=(0, 0.2)),
            iaa.Affine(
                translate_px={"x": (0,15), "y": (0,15)},  # 平移
                scale=(0.8, 1.2),  # 尺度变换
                rotate=(-20, 20),
                mode='constant',
                cval=(125)
                ),
            iaa.Resize(128)
        ])

        img= seq(image=img)
        return img
if __name__ == '__main__':
    lines=open('data.txt','r').readlines()
    my=Mydata(lines=lines,train=True)
    myloader=DataLoader(dataset=my,batch_size=3,shuffle=False)

    for i,j in myloader:
        print(i.shape,j.shape)

模型训练

经过60个epoch训练后,模型准确率基本上达到百分百

from mymodel import net
from myDataset import Mydata
import random
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import torch
from tqdm import tqdm
import matplotlib.pylab as plt

batch_size=32
Epoch=60
lr=0.001

lines=open('data.txt','r').readlines()
random.shuffle(lines)
val_lines=random.sample(lines,int(len(lines)*0.1))
train_lines=list(set(lines)-set(val_lines))


train_data=Mydata(lines=train_lines)
val_data=Mydata(lines=val_lines,train=False)
train_loader=DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True)
val_loader=DataLoader(dataset=val_data,batch_size=batch_size,shuffle=False)

num_train   = len(train_lines)
epoch_step  = num_train // batch_size
BCE_loss     = nn.CrossEntropyLoss()
optimizer  = optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999))
lr_scheduler  = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)
#获取学习率函数
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']
#计算准确率函数
def metric_func(pred,lab):
    _,index=torch.max(pred,dim=-1)
    acc=torch.where(index==lab,1.,0.).mean()
    return acc
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net=net.to(device)
#设置损失函数
loss_fun     = nn.CrossEntropyLoss()

if __name__ == '__main__':

    T_acc=[]
    V_acc=[]
    T_loss=[]
    V_loss=[]

    # 设置迭代次数200次
    epoch_step = num_train // batch_size
    for epoch in range(1, Epoch + 1):
        net.train()

        total_loss = 0
        loss_sum = 0.0
        train_acc_sum=0.0

        with tqdm(total=epoch_step, desc=f'Epoch {epoch}/{Epoch}', postfix=dict, mininterval=0.3) as pbar:
            for step, (features, labels) in enumerate(train_loader, 1):
                features = features.to(device)
                labels = labels.to(device)
                batch_size = labels.size()[0]

                optimizer.zero_grad()
                predictions = net(features)
                loss = loss_fun(predictions, labels)
                loss.backward()
                optimizer.step()
                total_loss += loss
                train_acc = metric_func(predictions, labels)
                train_acc_sum+=train_acc
                pbar.set_postfix(**{'loss': total_loss.item() / (step),
                                    "acc":train_acc_sum.item()/(step),
                                    'lr': get_lr(optimizer)})
                pbar.update(1)
        T_acc.append(train_acc_sum.item()/(step))
        T_loss.append(total_loss.item() / (step))
        # 验证
        net.eval()
        val_acc_sum = 0
        val_loss_sum=0
        for val_step, (features, labels) in enumerate(val_loader, 1):
            with torch.no_grad():
                features = features.to(device)
                labels = labels.to(device)
                predictions = net(features)

                val_metric = metric_func(predictions, labels)
                loss=loss_fun(predictions,labels)
            val_acc_sum += val_metric.item()
            val_loss_sum+=loss.item()
        print('val_acc=%.4f' % (val_acc_sum / val_step))
        V_acc.append(round(val_acc_sum / val_step,2))
        V_loss.append(val_loss_sum/val_step)

        # 保存模型
        if (epoch) % 2 == 0:
            torch.save(net.state_dict(), 'logs/Epoch%d-Loss%.4f_.pth' % (
                epoch, total_loss / (epoch_step + 1)))

        lr_scheduler.step()

    plt.figure()
    plt.plot(T_acc,'r')
    plt.plot(V_acc,'b')
    plt.title('Training and validation Acc')
    plt.xlabel("Epochs")
    plt.ylabel("Acc")
    plt.legend(["Train_acc", "Val_acc"])
    # plt.show()
    plt.savefig("ACC.png")

    plt.figure()
    plt.plot(T_loss, 'r')
    plt.plot(V_loss, 'b')
    plt.title('Training and validation loss')
    plt.xlabel("Epochs")
    plt.ylabel("loss")
    plt.legend(["Train_loss", "Val_loss"])

    plt.savefig("LOSS.png")
    plt.show()


训练准确率及损失函数:

准确率:

在这里插入图片描述
损失函数:
在这里插入图片描述

resnet18交通标志分类源码

(包含训练,预测代码,准确率,损失函结果图像,数据集等):
下载地址:

yolov5检测与识别(交通标志)

前面是使用resnet18网络对交通标志分类,只是单单的分类,无法从一张完整的全局图像中检测交通标志位置。对此,首先使用yolov5从全局图像中检测交通标志的位置,只是检测没有分类,然后再使用前面训练好的resnet18模型对交通标志分类。其效果如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1691961.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux学习笔记:线程

Linux中的线程 什么是线程线程的使用原生线程库创建线程线程的id线程退出等待线程join分离线程取消一个线程线程的局部存储在c程序中使用线程使用c自己封装一个简易的线程库 线程互斥(多线程)导致共享数据出错的原因互斥锁关键函数pthread_mutex_t :创建一个锁pthread_mutex_in…

[答疑]不开发系统只做领域建模,可以画冗余关联吗

DDD领域驱动设计批评文集 做强化自测题获得“软件方法建模师”称号 《软件方法》各章合集 Seven 2024-5-22 11:54 您书里面说,可以计算的不是关联。我有个疑问,要是不考虑开发系统,只是做领域建模,这个图是否可以这样画&#x…

构建高效AI代理:单代理与多代理架构的策略与挑战

AI代理架构正成为实现复杂任务自动化的关键技术。这些代理不仅需要处理信息、做出决策,还要能够与外部环境进行交互。随着ChatGPT等生成式AI应用的兴起,研究者们开始探索下一代AI应用,其中AI代理的角色愈发重要。单代理架构以其简洁高效的特点…

数据结构——栈(详细分析)

目录 🍉引言 🍉栈的本质和特点 🍈栈的基本操作 🍈栈的特点 🍍后进先出 🍍操作受限 🍍动态调整 🍈栈的优缺点 🍍优点 🍍缺点 🍉栈的应用…

AI绘画ComfyUI 进阶教程 | 字节最强换脸插件PuLID 详解,还请收藏!

大家好,我是小强 这应当算作是小编分享的换脸工具系列中的又一力作,从最初的roop,到之后的ReActor,再到备受欢迎的InstantID,以及今日重点介绍的字节开源产品——PuLID。 提及PuLID,首要原因并非仅仅在于…

[xx点评完结]——白马点评完整代码+rabbitmq实现异步下单+资料,免费

项目所有功能已测,均可以跑通,Jmeter和RabbitMQ也都测了。 项目源码:dianpinghui: 仿黑马点评项目 资料: https://pan.baidu.com/s/1kTCn9PxgeIey90WgM4KRqA?pwdn66b 对佬有帮助可以给个star哈,感谢🌹🌹&#x1f3…

QQ个性网空间日志网站模板源码

QQ个性网空间日志网站模板源码自带后台登录设置,适用于博客、文章、资讯、其他类网站内容使用。模板自带eyoucms内核,原创设计、手工书写DIVCSS,完美兼容IE7、Firefox、Chrome、360浏览器等;主流浏览器;结构容易优化;多终端均可正常预览。由于…

c++|多态

c|多态 1 多态的概念2 多态的定义及其实现2.1 满足多态的条件2.2 虚函数2.3 虚函数的重写2.4 析构函数适合加virtural吗2.4 C11 override 和 final2.5 三个概念的对比 3 多态的原理4 抽象类4.1 概念4.2 纯虚函数 1 多态的概念 多态的概念:通俗来说,就是…

英伟达的GPU(3)

上节内容:英伟达的GPU(2) (qq.com) 书接上文,上文我们讲到CUDA编程体系和硬件的关系,也留了一个小问题CUDA core以外的矩阵计算能力是咋提供的 本节介绍一下Tensor Core 上节我们介绍了CUDA core,或者一般NPU,CPU执行…

Android中华为手机三态位置权限申请理解

博主前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住也分享一下给大家&#xff0c; &#x1f449;点击跳转到教程 前言&#xff1a; 使用的华为MATE 20,Android10的系统。 <!--精准定位权限&#xff0c;如&#xff1a;…

建议收藏 | 2023年生物学类SCI期刊影响因子最新预测,Molecular Plant遥遥领先

公众号&#xff1a;生信漫谈&#xff0c;获取最新科研信息&#xff01; 建议收藏 | 2023年生物学类SCI期刊影响因子最新预测&#xff0c;Molecular Plant遥遥领先https://mp.weixin.qq.com/s/tFINUzZ1l4H9x1HWTq1kFg 2023年生物学类SCI期刊影响因子最新预测&#xff0c;Molecu…

【基于springboot+vue的房屋租赁系统】

介绍 本系统是基于springbootvue的房屋租赁系统&#xff0c;数据库为mysql&#xff0c;可用于日常学习和毕设&#xff0c;系统分为管理员、房东、用户&#xff0c;部分截图如下所示&#xff1a; 部分界面截图 用户 管理员 联系我 微信&#xff1a;Zzllh_

linux父进程fork出子进程后,子进程为何首先需要close文件描述符。

在linux c/c编程时&#xff0c;父进程fork出子进程后&#xff0c;子进程经常第一件事就是close掉所有的文件描述符&#xff1b;为何需要这样做&#xff0c;本文用一个例子进行简单说明。 考虑到一种情况&#xff0c;父进程创建了tcp服务端套接字&#xff0c;并且listen&#x…

6.中断管理

一、简介 中断是 CPU 的一种常见特性&#xff0c;中断一般由硬件产生&#xff0c;当中断发生后&#xff0c;会中断 CPU 当前正 在执行的程序而跳转到中断对应的服务程序种去执行&#xff0c;ARM Cortex-M 内核的 MCU 具有一个 用于中断管理的嵌套向量中断控制器&#xff08;NV…

Qt 在windows下显示中文

Qt在windows平台上显示中文&#xff0c;简直是一门玄学&#xff0c;经过测试&#xff0c;有如下发现&#xff1a; 1&#xff0c; 环境&#xff1a;Qt 5.15.2 vs2019 64位 win11系统 默认用Qt 创建的文件使用utf-8编码格式&#xff0c;此环境下 中文没有问题 ui->textE…

运用HTML、CSS设计Web网页——“西式甜品网”图例及代码

目录 一、效果展示图 二、设计分析 1.整体效果分析 2.头部header模块效果分析 3.导航及banner模块效果分析 4.分类classify模块效果分析 5.产品展示show模块效果分析 6.版权banquan模块效果分析 三、HTML、CSS代码分模块展示 1. 头部header模块代码 2.导航及bann…

一条命令安装Metasploit Framework

做安全渗透的人都或多或少的使用kali-Linux系统中msfconsole命令启动工具&#xff0c;然而也经常会有人遇到这样那样的问题无法启动 今天我们就用一条命令来重新安装这个工具 curl https://raw.githubusercontent.com/rapid7/metasploit-omnibus/master/config/templates/met…

Proteus仿真小技巧(隔空连线)

用了好几天Proteus了.总结一下使用的小技巧. 目录 一.隔空连线 1.打开添加网络标号 2.输入网络标号 二.常用元件 三.运行仿真 四.总结 一.隔空连线 引出一条线,并在末尾点一下. 1.打开添加网络标号 选择添加网络标号, 也可以先点击按钮,再去选择线(注意不要点端口) 2.…

PTT票据传递攻击

一. PTT票据传递攻击原理 1.PTT介绍 PTT(Pass The Ticket)&#xff0c;中文叫票据传递攻击&#xff0c;PTT 攻击只能用于kerberos认证中,NTLM认证中没有&#xff0c; PTT是通过票据进行认证的。 进行票据传递&#xff0c;不需要提权&#xff0c;域用户或者system用户就可以 …

2024上海初中生古诗文大会倒计时4个月:单选题真题解析(持续)

现在距离2024年初中生古诗文大会还有4个多月时间&#xff0c;我们继续来看10道选择题真题和详细解析&#xff0c;以下题目截取自我独家制作的在线真题集&#xff0c;都是来自于历届真题&#xff0c;去重、合并后&#xff0c;每道题都有参考答案和解析。 为帮助孩子自测和练习&…