Pytorch入门实战 P10-使用pytorch实现车牌识别

news2024/11/15 15:39:23

目录

前言

一、MyDataset文件

二、完整代码:

三、结果展示:

四、添加accuracy值


  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

本周的学习内容是,使用pytorch实现车牌识别。

前言

        之前的案例里面,我们大多是使用的是datasets.ImageFolder函数,直接导入已经分类好的数据集形成Dataset,然后使用DataLoader加载Dataset,但是如果对无法分类的数据集,我们应该如何导入呢。

        这篇文章主要就是介绍通过自定义的一个MyDataset加载车牌数据集并完成车牌识别。

一、MyDataset文件

数据文件是这样的,没有进行分类的。

# 加载数据文件
class MyDataset(data.Dataset):
    def __init__(self, all_labels, data_paths_str, transform):
        self.img_labels = all_labels  # 获取标签信息
        self.img_dir = data_paths_str  # 图像目录路径
        self.transform = transform   # 目标转换函数

    def __len__(self):
        return len(self.img_labels)   # 返回数据集的长度,即标签的数量

    def __getitem__(self, index):
        image = Image.open(self.img_dir[index]).convert('RGB')   # 打开指定索引的图像文件,并将其转换为RGB模式
        label = self.img_labels[index]  # 获取图像对应的标签

        if self.transform:
            image = self.transform(image)   # 如果设置了转换函数,则对图像进行转换(如,裁剪、缩放、归一化等)

        return image, label  # 返回图像和标签

二、完整代码:

import pathlib

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch import nn
from torch.utils import data
from torchvision import transforms
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib as mpl
mpl.use('Agg')  # 在服务器上运行的时候,打开注释


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

data_dir = './data'
data_dir = pathlib.Path(data_dir)

data_paths = list(data_dir.glob('*'))
classNames = [str(path).split('/')[1].split('_')[1].split('.')[0] for path in data_paths]
# print(classNames)  # '沪G1CE81', '云G86LR6', '鄂U71R9F', '津G467JR'....

data_paths_str = [str(path) for path in data_paths]

# 数据可视化
plt.rcParams['axes.unicode_minus'] = False
plt.figure(figsize=(14,5))
plt.suptitle('data show', fontsize=15)
for i in range(18):
    plt.subplot(3, 6, i+1)

    # 显示图片
    images = plt.imread(data_paths_str[i])
    plt.imshow(images)
plt.show()


# 3、标签数字化
char_enum = ["京","沪","津","渝","冀","晋","蒙","辽","吉","黑","苏","浙","皖","闽","赣","鲁","豫","鄂","湘","粤","桂","琼","川","贵","云","藏","陕","甘","青","宁","新","军","使"]

number = [str(i) for i in range(0, 10)]  # 0-9 的数字
alphabet = [chr(i) for i in range(65, 91)]  # A到Z的字母
char_set = char_enum + number + alphabet
char_set_len = len(char_set)
label_name_len = len(classNames[0])

# 将字符串数字化
def text2vec(text):
    vector = np.zeros([label_name_len, char_set_len])
    for i, c in enumerate(text):
        idx = char_set.index(c)
        vector[i][idx] = 1.0
    return vector


all_labels = [text2vec(i) for i in classNames]


# 加载数据文件
class MyDataset(data.Dataset):
    def __init__(self, all_labels, data_paths_str, transform):
        self.img_labels = all_labels  # 获取标签信息
        self.img_dir = data_paths_str  # 图像目录路径
        self.transform = transform   # 目标转换函数

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, index):
        image = Image.open(self.img_dir[index]).convert('RGB')
        label = self.img_labels[index]  # 获取图像对应的标签

        if self.transform:
            image = self.transform(image)

        return image, label  # 返回图像和标签

total_datadir = './data/'
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

total_data = MyDataset(all_labels, data_paths_str, train_transforms)
# 划分数据
train_size = int(0.8*len(total_data))
test_size = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data,[train_size, test_size])
print(train_size, test_size)  # 10940 2735

# 数据加载
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=True)

for X, y in test_loader:
    print('Shape of X [N,C,H,W]:',X.shape)   # ([16, 3, 224, 224])
    print('Shape of y:', y.shape, y.dtype)    # torch.Size([16, 7, 69]) torch.float64
    break


# 搭建网络模型
class Network_bn(nn.Module):
    def __init__(self):
        super(Network_bn, self).__init__()
        """
        nn.Conv2d()函数:
        第一个参数(in_channels)是输入的channel数量
        第二个参数(out_channels)是输出的channel数量
        第三个参数(kernel_size)是卷积核大小
        第四个参数(stride)是步长,默认为1
        第五个参数(padding)是填充大小,默认为0
        """
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(12)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=0)
        self.bn2 = nn.BatchNorm2d(12)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=0)
        self.bn4 = nn.BatchNorm2d(24)
        self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=0)
        self.bn5 = nn.BatchNorm2d(24)
        self.fc1 = nn.Linear(24 * 50 * 50, label_name_len * char_set_len)
        self.reshape = Reshape([label_name_len, char_set_len])

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool(x)
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.relu(self.bn5(self.conv5(x)))
        x = self.pool(x)
        x = x.view(-1, 24 * 50 * 50)
        x = self.fc1(x)

        # 最终reshape
        x = self.reshape(x)

        return x


class Reshape(nn.Module):
    def __init__(self,shape):
        super(Reshape, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(x.size(0), *self.shape)


model = Network_bn().to(device)
print(model)


# 优化器与损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0.0001)
loss_model = nn.CrossEntropyLoss()

def test(model, test_loader, loss_model):
    size = len(test_loader.dataset)
    num_batches = len(test_loader)

    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in test_loader:
            X, y = X.to(device), y.to(device)
            pred = model(X)

            test_loss += loss_model(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches

    print(f'Avg loss: {test_loss:>8f}\n')
    return correct, test_loss


def train(model,train_loader, loss_model, optimizer):
    model = model.to(device)
    model.train()

    for i, (images, labels) in enumerate(train_loader, 0):   # 0 是标起始位置的值
        images = Variable(images.to(device))
        labels = Variable(labels.to(device))

        optimizer.zero_grad()
        outputs = model(images)

        loss = loss_model(outputs, labels)
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print('[%5d] loss: %.3f' % (i, loss))


# 模型的训练
test_acc_list = []
test_loss_list = []
epochs = 30
for t in range(epochs):
    print(f"Epoch {t+1}\n-----------------------")
    train(model,train_loader, loss_model,optimizer)
    test_acc,test_loss = test(model, test_loader, loss_model)
    test_acc_list.append(test_acc)
    test_loss_list.append(test_loss)

print('Done!!!')


# 结果分析
x = [i for i in range(1,31)]
plt.plot(x, test_loss_list, label="Loss", alpha = 0.8)
plt.xlabel('Epoch')
plt.ylabel('Loss')

plt.legend()
plt.show()
plt.savefig("/data/jupyter/deep_demo/p10_car_number/resultImg.jpg")  # 保存图片在服务器的位置
plt.show()

三、结果展示:

总结:从刚开始损失为0.077 到,训练30轮后,损失到了0.026。

四、添加accuracy值

需求:对在上面的代码中,对loss进行了统计更新,请补充acc统计更新部分,即获取每一次测试的ACC值。

添加accuracy的运行过程:

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1678476.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【永洪BI】管理系统

管理系统模块包括系统设置、认证授权、日志管理、监控预警、资源部署、VooltDB管理、数据库管理、企业应用配置、系统检查、应用管理模块。 系统设置界面: 可以进行清除系统缓存、配置系统主题、配置系统邮箱、配置门户主页、配置权限管理系统、配置密码策略、配置…

端午佳节,品尝食家巷传统面点与黄米粽子礼盒

端午佳节,品尝食家巷传统面点与黄米粽子礼盒 在这个端午节来临之际,食家巷倾情推出一款别具特色的端午礼盒,将甘肃的传统面点与地方特色黄米粽子完美融合,为您带来一场美味与传统的邂逅。 这款礼盒以甘肃传统面点一窝丝、油饼和烤…

Python 渗透测试:子域名查询.

什么是 子域名查询. 子域名查询是指通过域名系统(DNS)查找某个域名下的子域名信息。子域名是域名层级结构中的一部分,位于主域名的下一级。子域名查询是网络安全评估和渗透测试中的一个重要步骤,可以帮助安全研究人员更好地了解目标系统的架构和潜在的安全隐患。但在进行子域名…

svn批量解锁

问题 svn对文件进行checkout之后,先进行lock,之后再去更改,最后进行Commit操作; 上述为我们通过svn管理代码的正常方式,但总会有其他现象发生; 如果我们非正常操作,批量锁所有的svn文件&#x…

Pencils Protocol 宣布再获合作伙伴 Galxe 的投资

近日,Scroll生态项目Penpad将品牌进一步升级为Pencils Protocol,全新升级后其不仅对LaunchPad平台进行了功能上的升级,同时其也进一步引入了Staking、Vault以及Shop等玩法,这也让Pencils Protocol的叙事方向不再仅限于LaunchPad&a…

刘邦的创业团队是沛县人,朱元璋的则是凤阳;要创业,一个县人才就够了

当人们回顾刘邦和朱元璋的创业经历时,总是会感慨他们起于微末,都创下了偌大王朝,成就无上荣誉。 尤其是我们查阅史书时,发现这二人的崛起班底都是各自的家乡人,例如刘邦的班底就是沛县人,朱元璋的班底是凤…

openlayers 热力图 天地图

openlayers 实现热力图 样式可调 在https://blog.csdn.net/qq_36287830/article/details/131844745?spm1001.2014.3001.5501基础上改进来的 最终样式 关键代码 如果你有数据可以不使用for循环 var blurInput document.getElementById("blur");var rediusInput do…

sa-token权限认证框架,最简洁,最实用讲解

查看源码,可知,sa sa-token框架 测试代码源码配置自动装配SaTokenConfigSaTokenConfigFactory SaManager工具类SaFoxUtilStpUtilSaResult StpLogic持久层定时任务 会话登录生成token创建account-session事件驱动模型写入tokenSaSessionSaCookieSaTokenDa…

如何使用AspectJ做切面,打印jar包中方法的执行日记

最近在工作中遇到一个redis缓存中的hash key莫名其妙被删除的问题,我们用了J2Cache,二级缓存用的是redis。hash key莫名其妙被删除又没有日志,就想到做一个切面在调用redis删除hash key的方法的时候,打印日志,并且把调…

揭秘SmartEDA魅力:为何众多学校青睐这款电路仿真软件?

在当今数字化、信息化的教育时代,电子电路仿真软件已成为电子学教学不可或缺的重要工具。其中,SmartEDA电路仿真软件以其强大的功能、用户友好的界面以及丰富的教育资源,赢得了众多学校的青睐。那么,究竟是什么原因让SmartEDA成为…

解决Win11下SVN状态图标显示不出来

我们正常SVN在Windows资源管理器都是有显示状态图标的, 如果不显示状态图标,可能你的注册表的配置被顶下去了,我们查看一下注册表 运行CMD > regedit 打开注册表编辑器 然后打开这个路径:计算机\HKEY_LOCAL_MACHINE\SOFTWARE…

【LeetCode刷题】27. 移除元素

1. 题目链接2. 题目描述3. 解题方法4. 代码 1. 题目链接 27. 移除元素 2. 题目描述 3. 解题方法 暴力法直接解决,用双层for循环,外层for循环找val,内层for循环做删除操作。双指针法,fast和slow。fast找不是val的值,…

在Ubuntu上的QT创建工程并打包项目

一、环境准备 参考UbuntuQT安装 二、创建项目,点击choose 设置项目名字路径等,点击下一步 默认,点击下一步 设置函数名字,保持默认,下一步 保持默认,点击下一步 继续,下一步 点击完成 三…

22 优化日志文件统计程序-按月份统计每个用户每天的访问次数

读取任务一中序列文件&#xff0c;统计每个用户每天的访问次数&#xff0c;最终将2021/1和2021/2的数据分别输出在两个文件中。 一、创建项目步骤&#xff1a; 1.创建项目 2.修改pom.xml文件 <packaging>jar</packaging> <dependencies><dependency>…

听劝!普通人千万别随意入门网络安全

一、什么是网络安全 网络安全是一种综合性的概念&#xff0c;涵盖了保护计算机系统、网络基础设施和数据免受未经授权的访问、攻击、损害或盗窃的一系列措施和技术。经常听到的 “红队”、“渗透测试” 等就是研究攻击技术&#xff0c;而“蓝队”、“安全运营”、“安全运维”…

Linux-页(page)和页表

本文在页表方面参考了这篇博客&#xff0c;特别鸣谢&#xff01; 【Linux】页表的深入分析 1. 页帧和页框 页帧&#xff08;page frame&#xff09;是内存的最小可分配单元&#xff0c;也开始称作页框&#xff0c;Linux下页帧的大小为4KB。 内核需要将他们用于所有的内存需求&a…

【Git教程】(十九)合并小型项目 — 概述及使用要求,执行过程及其实现,替代解决方案 ~

Git教程 合并小型项目 1️⃣ 概述2️⃣ 使用要求3️⃣ 执行过程及其实现 在项目的初始阶段&#xff0c;往往需要针对重要的设计决策和技术实现原型实验。当原型评估结束后&#xff0c;需要将那些成功的原型合并起来称为整个项目的初始版本。 在这样的情景中&#xff0c;各个原…

什么是ARP攻击,怎么做好主机安全,受到ARP攻击有哪些解决方案

在数字化日益深入的今天&#xff0c;网络安全问题愈发凸显其重要性。其中&#xff0c;ARP攻击作为一种常见的网络攻击方式之一&#xff0c;往往给企业和个人用户带来不小的困扰。ARP协议是TCP/IP协议族中的一个重要协议&#xff0c;负责把网络层(IP层)的IP地址解析为数据链路层…

Vmvare—windows中打不开摄像头

1、检查本地摄像头是否能正常打开 设备管理器—查看—显示隐藏设备—选中照相机—启动 USB2.0 HD UVC—打开相机查看 2、检查虚拟机的设置 虚拟机—虚拟机—可移动设备—USB2.0 HD UVC—勾选在状态栏中显示 虚拟机—打开windows主机—右小角选中圆圈图标—勾选连接主机 此时…