pytorch+CRNN实现

news2024/9/24 11:27:51

最近接触了一个仪表盘识别的项目,简单调研以后发现可以用CRNN来做。但是手边缺少仪表盘数据集,就先用ICDAR2013试了一下。
在这里插入图片描述
结果遇到了一系列坑。为了不使读者和自己在以后的日子继续遭罪。我把正确的代码发到下面了。
1)超参数请不要调整!!!!CRNN前期训练极其离谱,需要良好的调参,loss才会慢慢下降。
在这里插入图片描述
我给出了一个训练曲线,可以看到确实贼几把怪,七拐八拐的。

2)千万不要用百度开源的那个ctc!!!

网络代码:

#crnn.py
import torch.nn as nn
import torch.nn.functional as F

class BidirectionalLSTM(nn.Module):
    # Inputs hidden units Out
    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()

        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)

        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)

        return output


class CRNN(nn.Module):
    #                   32    1   37     256
    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        super(CRNN, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
        convRelu(6, True)  # 512x1x16

        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        # conv features
        #print('---forward propagation---')
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2) # b *512 * width
        conv = conv.permute(2, 0, 1)  # [w, b, c]
        output = F.log_softmax(self.rnn(conv), dim=2)
        return output

训练:

#train.py
import os
import torch
import cv2
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

import crnn
import time
import re
import matplotlib.pyplot as plt
dic={" ":0,"a":1,"b":2,"c":3,"d":4,"e":5,"f":6,"g":7,"h":8,"i":9,"j":10,"k":11,"l":12,"m":13,"n":14,"o":15,"p":16,"q":17,"r":18,"s":19,"t":20,"u":21,"v":22,"w":23,"x":24,"y":25,"z":26,
     "A":27,"B":28,"C":29,"D":30,"E":31,"F":32,"G":33,"H":34,"I":35,"J":36,"K":37,"L":38,"M":39,"N":40,"O":41,"P":42,"Q":43,"R":44,"S":45,"T":46,"U":47,"V":48,"W":49,"X":50,"Y":51,"Z":52}

STR=" abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
n_class=53
label_sources=r"E:\machine_learning\instrument\icdar_2013\Challenge2_Test_Task1_GT"
image_sources=r"E:\machine_learning\instrument\icdar_2013\Challenge2_Test_Task12_Images"
use_gpu = True
learning_rate = 0.0001
max_epoch = 100
batch_size = 20
# 调整图像大小和归一化操作
class resizeAndNormalize():
    def __init__(self, size, interpolation=cv2.INTER_LINEAR):
        # 注意对于opencv,size的格式是(w,h)
        self.size = size
        self.interpolation = interpolation
        # ToTensor属于类  """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
        self.toTensor = transforms.ToTensor()

    def __call__(self, image):
        # (x,y) 对于opencv来说,图像宽对应x轴,高对应y轴
        image = cv2.resize(image, self.size, interpolation=self.interpolation)
        # 转为tensor的数据结构
        image = self.toTensor(image)
        # 对图像进行归一化操作
        #image = image.sub_(0.5).div_(0.5)
        return image

def load_data(label_folder,image_folder,label_suffix_name=".txt",image_suffix_name=".jpg"):
    image_file,label_file,num_file=[],[],[]
    for parent_folder, _, file_names in os.walk(label_folder):
        # 遍历当前子文件夹中的所有文件
        for file_name in file_names:
            # 只处理图片文件
            # if file_name.endswith(('jpg', 'jpeg', 'png', 'gif')):#提取jpg、jpeg等格式的文件到指定目录
            if file_name.endswith((label_suffix_name)):  # 提取json格式的文件到指定目录
                # 构造源文件路径和目标文件路径
                a,b=file_name.split("gt_")
                c,d=b.split(label_suffix_name)
                image_name=image_folder + "\\" + c + image_suffix_name
                if os.path.exists(image_name):
                    label_name = label_folder + "\\" + file_name
                    txt=open(label_name,'rb')
                    txtl=txt.readlines()
                    for line in range(len(txtl)):
                        image_file.append(image_name)
                        label_file.append(label_name)
                        num_file.append(line)
    return image_file,label_file,num_file

def zl2lable(zl):
    label_list=[]
    for str in zl:
        label_list.append(dic[str])
    return label_list

class NewDataSet(Dataset):
    def __init__(self, label_source,image_source,train=True):
        super(NewDataSet, self).__init__()
        self.image_file,self.label_file,self.num_file= load_data(label_source,image_source)

    def __len__(self):
        return len(self.image_file)

    def __getitem__(self, index):
        txt = open(self.label_file[index], 'rb')
        img=cv2.imread(self.image_file[index],cv2.IMREAD_GRAYSCALE)
        wordL = txt.readlines()
        word=str(wordL[self.num_file[index]])
        pl = re.findall(r'\d+',word)
        zl = re.findall(r"[a-zA-Z]+", word)[1]  #1

        #img tensor
        x1, y1, x2, y2 = pl[:4]
        img= img[int(y1):int(y2),int(x1):int(x2), ]
        (height, width)=img.shape
        # 由于crnn网络输入图像的高为32,故需要resize原始图像的height
        size_height = 32
        # ratio = 32 / float(height)
        size_width =100
        transform = resizeAndNormalize((size_width, size_height))
        # 图像预处理
        imageTensor = transform(img)

        #label tensor
        l = zl2lable(zl)
        labelTensor = torch.IntTensor(l)
        return imageTensor,labelTensor




class CRNNDataSet(Dataset):
    def __init__(self, imageRoot, labelRoot):
        self.image_root = imageRoot
        self.image_dict = self.readfile(labelRoot)
        self.image_name = [fileName for fileName, _ in self.image_dict.items()]

    def __getitem__(self, index):
        image_path = os.path.join(self.image_root, self.image_name[index])
        keys = self.image_dict.get(self.image_name[index])
        label = [int(x) for x in keys]

        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        # if image is None:
        #     return None,None
        (height, width) = image.shape

        # 由于crnn网络输入图像的高为32,故需要resize原始图像的height
        size_height = 32
        ratio = 32 / float(height)
        size_width = int(ratio * width)
        transform = resizeAndNormalize((size_width, size_height))
        # 图像预处理
        image = transform(image)
        # 标签格式转换为IntTensor
        label = torch.IntTensor(label)
        return image, label

    def __len__(self):
        return len(self.image_name)

    def readfile(self, fileName):
        res = []
        with open(fileName, 'r') as f:
            lines = f.readlines()
            for line in lines:
                res.append(line.strip())
        dic = {}
        total = 0
        for line in res:
            part = line.split(' ')
            # 由于会存在训练过程中取图像的时候图像不存在导致异常,所以在初始化的时候就判断图像是否存在
            if not os.path.exists(os.path.join(self.image_root, part[0])):
                print(os.path.join(self.image_root, part[0]))
                total += 1
            else:
                dic[part[0]] = part[1:]
        print(total)

        return dic


trainData =NewDataSet(label_sources,image_sources)

trainLoader = DataLoader(dataset=trainData, batch_size=1, shuffle=True, num_workers=0)




# valData = CRNNDataSet(imageRoot="D:\BaiduNetdiskDownload\Synthetic_Chinese_String_Dataset\images\\",
#                       labelRoot="D:\BaiduNetdiskDownload\Synthetic_Chinese_String_Dataset\lables\data_t.txt")
#
# valLoader = DataLoader(dataset=valData, batch_size=1, shuffle=True, num_workers=1)


#
# def decode(preds):
#     pred = []
#     for i in range(len(preds)):
#         if preds[i] != 5989 and ((i == 5989) or (i != 5989 and preds[i] != preds[i - 1])):
#             pred.append(int(preds[i]))
#     return pred
#
#
def toSTR(l):
    str_l=[]
    if isinstance(l, int):
        l=[l]
    for i in range(len(l)):
        str_l.append(STR[l[i]])
    return str_l
def toRES(l):
    new_l=[]
    new_str=' '
    for i in range(len(l)):
        if(l[i]==' '):
            new_str = ' '
            continue
        elif new_str!=l[i]:
            new_l.append(l[i])
            new_str=l[i]
    return new_l

def val(model=torch.load("pytorch-crnn.pth")):
    # 将模式切换为验证评估模式
    loss_func = torch.nn.CTCLoss(blank=0, reduction='mean')
    model.eval()

    test_n=10



    for i, (data, label) in enumerate(trainLoader):
        if(i>test_n):
            break;
        output = model(data.cuda())
        pred_label=output.max(2)[1]
        input_lengths = torch.IntTensor([output.size(0)] * int(output.size(1)))
        target_lengths = torch.IntTensor([label.size(1)] * int(label.size(0)))
        # forward(self, log_probs, targets, input_lengths, target_lengths)
        #log_probs = output.log_softmax(2).requires_grad_()
        targets = label.cuda()
        loss = loss_func(output.cpu(), targets.cpu(), input_lengths, target_lengths)

        pred_l=np.array(pred_label.cpu().squeeze()).tolist()
        label_l=np.array(targets.cpu().squeeze()).tolist()
        print(i,":",loss,"pred:",toRES(toSTR(pred_l)),"label_l",toSTR(label_l))




def train():
    model = crnn.CRNN(32, 1, n_class, 256)
    if torch.cuda.is_available() and use_gpu:
        model.cuda()

    loss_func = torch.nn.CTCLoss(blank=0,reduction='mean')
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,betas=(0.9, 0.999))

    lossTotal = 0.0
    k = 0
    printInterval = 100
    start_time = time.time()
    loss_list=[]
    total_list=[]
    for epoch in range(max_epoch):
        n=0
        data_list = []
        label_list = []
        label_len=[]
        for i, (data, label) in enumerate(trainLoader):
            #
            data_list.append(data)
            label_list.append(label)
            label_len.append(label.size(1))
            n=n+1
            if n%batch_size!=0:
                continue
            k=k+1
            data=torch.cat(data_list, dim=0)
            data_list.clear()

            label = torch.cat(label_list, dim=1).squeeze(0)
            label_list.clear()

            target_lengths=torch.tensor(np.array(label_len))
            label_len.clear()
            # 开启训练模式
            model.train()


            if torch.cuda.is_available and use_gpu:
                data = data.cuda()
                loss_func = loss_func.cuda()
                label = label.cuda()

            output = model(data)
            log_probs = output
            # example 建议使用这样,貌似直接把output送进去loss fun也没发现什么问题
            #log_probs = output.log_softmax(2).requires_grad_()
            targets = label.cuda()
            input_lengths = torch.IntTensor([output.size(0)] * int(output.size(1)))
            # forward(self, log_probs, targets, input_lengths, target_lengths)
            #targets =torch.zeros(targets.shape)
            loss = loss_func(log_probs.cpu(), targets, input_lengths, target_lengths)/batch_size
            lossTotal += float(loss)
            print("epoch:",epoch,"num:",i,"loss:",float(loss))
            loss_list.append(float(loss))
            if k % printInterval == 0:
                print("[%d/%d] [%d/%d] loss:%f" % (
                    epoch, max_epoch, i + 1, len(trainLoader), lossTotal / printInterval))
                total_list.append( lossTotal / printInterval)
                lossTotal = 0.0
                torch.save(model, 'pytorch-crnn.pth')

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    plt.figure()
    plt.plot(loss_list)
    plt.savefig("loss.jpg")

    plt.clf()
    plt.figure()
    plt.plot(total_list)
    plt.savefig("total.jpg")
    end_time = time.time()
    print("takes {}s".format((end_time - start_time)))
    return model

if __name__ == '__main__':
    train()

测试结果如下:
在这里插入图片描述
最后给一些参考文献:
https://www.cnblogs.com/azheng333/p/7449515.html
https://blog.csdn.net/wzw12315/article/details/106643182

另外给出数据集和我训练好的模型:
链接:https://pan.baidu.com/s/1-jTA22bLKv2ut_1EJ1WMKA?pwd=jvk8
提取码:jvk8

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/758422.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

实体店搭建多用户商城系统有什么好处

现在很多的线下店铺都开始慢慢的转型线上了,想线上线下相结合,但是最近很多的商家都在问什么样的B2B2C商城系统开发适合线下店铺呢?这个问题今天加速度jsudo小编给大家一起整理如下,相信商家看完后就知道如何选择一款合适的商城系统了。 一、…

Spring Batch之读数据—读XML文件(三十二)

一、XML格式文件解析 XML是一种通用的数据交换格式,它的平台无关性、语言无关性、系统无关性,给数据集成与交换带来了极大的方便。XML在Java领域的解析方式有两种,一种叫SAX,另一种叫DOM。SAX是基于事件流的解析,DOM是…

刷题总结1

暑假第二周练习题 - Virtual Judge (vjudge.net) 该题就是将含4的数字全部跳过,不难发现,这就导致每位数都要少一个树,这就和9进制十分像,我们只要将该数字转化为9进制,然后将该9进制树的每位大于等于4的树加一就行了&…

【CXL】CXL QoS Telemetry 介绍

🔥点击查看精选 CXL 系列文章🔥 🔥点击进入【芯片设计验证】社区,查看更多精彩内容🔥 📢 声明: 🥭 作者主页:【MangoPapa的CSDN主页】。⚠️ 本文首发于CSDN&#xff0c…

51单片机的智能交通控制系统【含仿真+程序+演示视频带原理讲解】

51单片机的智能交通控制系统【含仿真程序演示视频带原理讲解】 1、系统概述2、核心功能3、仿真运行及功能演示4、程序代码 1、系统概述 该系统由AT89C51单片机、LED灯组、数码管组成。通过Protues对十字路口红绿灯控制逻辑进行了仿真。 每个路口包含了左转、右转、直行三条车道…

rapid_latex_ocr: 更快更好用的公式图像转latex工具

Rapid Latex OCR rapid_latex_ocr是一个将公式图像转为latex格式的工具。仓库中的推理代码来自修改自LaTeX-OCR,模型已经全部转为ONNX格式,并对推理代码做了精简,推理速度更快,更容易部署。仓库只有基于ONNXRuntime或者OpenVINO推…

AI辅助瞄准系统开发与实战(一)

文章目录 前言系统窗体设计提示弹窗功能主体页面 windows窗体绘制矩形绘制自定义线程池完整代码 总结 前言 直接看效果,狗头: 之所以搞这个的话,当然主要一方面是因为确实有点意思在里面,此外在很久以前,也有很多的UP…

光伏并网逆变器低电压穿越MATLAB仿真模型

使用MATLAB 2017b搭建 光伏逆变器低电压穿越仿真模型,boost加NPC拓扑结构,基于MATLAB/Simulink建模仿真。具备中点平衡SVPWM控制,正负序分离控制,pll,可进行低电压穿越仿真。 控制结构完整,波形完美&…

Web入门-HTTP协议

目录 HTTP概述 HTTP特点 HTTP请求协议 请求数据的格式 响应数据的格式 响应的状态码 HTTP协议的解析 HTTP概述 HTTP:Hyper Text Transfer Protocol,超文本传输协议,规定浏览器和服务器之间数据传输的规则。(即请求数据和响应数据的格式)以上一篇…

动态规划之119杨辉三角 II(第7道)

题目:给定一个非负索引 rowIndex,返回「杨辉三角」的第 rowIndex 行。 在「杨辉三角」中,每个数是它左上方和右上方的数的和。 题目链接:119. 杨辉三角 II - 力扣(LeetCode) 示例: 解法&…

高阶C语言|字符函数和字符串函数--函数的模拟实现

C语言中对字符和字符串的处理很是频繁,但是C语言本身是没有字符串类型的,字符串通常放在常量字符串中或者字符数组中。字符串常量适用于那些对它不做修改的字符串函数 字符函数和字符串函数 一、求字符串长度1.1strlen的使用1.2strlen函数的模拟实现 二…

基于linux下的高并发服务器开发(第一章)- 模拟实现 ls-l 命令

这一小节会用到上面两张图的红色框里面的变量 任务&#xff1a; 模拟实现 ls -l 指令 -rw-rw-r-- 1 nowcoder nowcoder 12 12月 3 15:48 a.txt #include <stdio.h> #include <sys/types.h> #include <sys/stat.h> #include <unistd.h> #include <p…

C++中菱形继承中继承不明确问题

C中菱形继承中继承不明确问题 class A { public:virtual void func1(){cout << "A::func1()" << endl;}int _a; };class B:virtual public A { public:virtual void func1(){cout << "B::func1()" << endl;}int _b; };class C:vi…

JavaScript混淆加密:Ty2y平台配置参数详解

Ty2y是国内一个JavaScript混淆加密平台&#xff0c;可以实现在线JS代码混淆加密。它有多达20多项的参数配置。如下图所示&#xff1a; 添加图片注释&#xff0c;不超过 140 字&#xff08;可选&#xff09; 本文将对这些配置实现的混淆加密的效果&#xff0c;进行详细说明&…

基于自注意和残差结构的跨模态情感识别融合网络

题目A cross-modal fusion network based on self-attention and residual structure for multimodal emotion recognition译题基于自注意和残差结构的跨模态情感识别融合网络时间2021年代码https://github.com/skeletonNN/CFN-SR A cross-modal fusion network based on self…

verilog实现数码管静态显示

文章目录 verilog实现数码管静态显示一、任务要求二、实验代码三、仿真代码四、仿真结果五、总结 verilog实现数码管静态显示 一、任务要求 六个数码管同时间隔0.5s显示0-f。要求&#xff1a;使用一个顶层模块&#xff0c;调用计时器模块和数码管静态显示模块。 二、实验代码…

DS-SLAM论文翻译

DS-SLAM:面向动态环境的语义可视化SLAM 摘要-同时定位与绘图(SLAM)被认为是智能移动机器人的一项基本能力。在过去的几十年里&#xff0c;许多印象深刻的SLAM系统已经开发出来&#xff0c;并在某些情况下取得了良好的性能。然而&#xff0c;一些问题仍然没有很好地解决&#x…

windows下mingw 编译boost-1.78.0

1.mingw环境设置 添加C:\cygwin64\bin 到环境变量&#xff0c;cmd运行检查是否安装成功 打开cmd&#xff0c;验证&#xff1a; 2.boost编译 创建文件夹 #后期可以删除&#xff0c;安装Boost.Buildmkdir D:\boost_build#后期可以删除&#xff0c;存放mkdir D:\boost_1_78_0\b…

SpringBoot使用Redis作为缓存器缓存数据的操作步骤以及避坑方案

1.非注解式实现 2.1使用之前要明确使用的业务场景 例如我们在登录时&#xff0c;可以让redis缓存验证码&#xff0c;又如在分类下显示菜品数据时&#xff0c;我们可以对分类和菜品进行缓存数据等等。 2.2导入Redis相关依赖 <dependency><groupId>org.springfra…

Leetcode每日一题(困难):834. 树中距离之和(2023.7.16 C++)

目录 834. 树中距离之和 题目描述&#xff1a; 实现代码与解析&#xff1a; DFS 原理思路&#xff1a; 834. 树中距离之和 题目描述&#xff1a; 给定一个无向、连通的树。树中有 n 个标记为 0...n-1 的节点以及 n-1 条边 。 给定整数 n 和数组 edges &#xff0c; edge…