【模型训练】-图形验证码识别

news2024/10/7 20:24:57

针对网站中的图形验证码图片,进行反向的内容识别,支持数字和字母,不区分大小写。

​​​​​​​​​​​​​​数据集地址

数据格式如下:

1、依赖导入

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

import numpy as np
import pickle as pkl
import matplotlib.pyplot as plt

2、数据集创建

class Dataset(Dataset):
    def __init__(self, img_dir):
        path_list = os.listdir(img_dir)
        # 获取文件夹绝对路径
        abspath = os.path.abspath(img_dir)
        self.img_list = [os.path.join(abspath, path) for path in path_list]
        self.transform = transforms.Compose([
            # 灰度化,配合 卷积网络初始通过 1
            # transforms.Grayscale(), 
            transforms.ToTensor(),
        ])
    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        path = self.img_list[idx]
        label = os.path.basename(path).split('.')[0].lower().strip()
        img = Image.open(path).convert('RGB')
        img_tensor = self.transform(img)
        return img_tensor, label

3、创建crnn卷积循环神经网络

stride 步长 

padding 完成卷积后是否填充空白

MaxPool2d :减少数据空间大小,池化窗口的大小,通常设置为2×2。减少参数数量和计算量,同时也能提高模型的鲁棒性。

BatchNorm(512):对输入数据进行归一化处理,使得每个通道的数据均值为0,方差为1,提高模型的泛化能力

dropout:随机丢弃神经元的输出来减少模型的复杂度和过拟合的风险

nn.GRU:PyTorch中的一个函数,用于创建一个双向的GRU(门控循环单元)层。

参数解释如下:

  • 255:输入的特征维度。输入数据的特征维度为255
  • 255:隐藏状态的维度。隐藏状态的维度为255
  • bidirectional=True:表示是否使用双向GRU。如果设置为True,则使用双向GRU;如果设置为False,则使用单向GRU。
  • batch_first=True:表示输入数据的维度顺序。如果设置为True,则输入数据的维度顺序为(batch_size, sequence_length, feature_dim);如果设置为False,则输入数据的维度顺序为(sequence_length, batch_size, feature_dim)。
class CRNN(nn.Module):
    def __init__(self, vocab_size, dropout=0.5):
        super(CRNN, self).__init__()
        
        self.dropout = nn.Dropout(dropout)

        self.convlayer = nn.Sequential(
            # 如果预处理采用Grayscale 则 channel=1
            nn.Conv2d(3, 32, (3,3), stride=1, padding=1),
            # 激活函数,x小于0,y=0
            nn.ReLU(),
            nn.MaxPool2d((2,2), 2),

            nn.Conv2d(32, 64, (3,3), stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d((2,2), 2),

            nn.Conv2d(64, 128, (3,3), stride=1, padding=1),
            nn.ReLU(),

            nn.Conv2d(128, 256, (3,3), stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d((1,2), 2),

            nn.Conv2d(256, 512, (3,3), stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),

            nn.Conv2d(512, 512, (3,3), stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d((1,2), 2),

            nn.Conv2d(512, 512, (2,2), stride=1, padding=0),
            self.dropout
        )

        self.mapSeq = nn.Sequential(
            nn.Linear(1024, 256),
            self.dropout
        )

        self.lstm_0 = nn.GRU(256, 256, bidirectional=True)
        self.lstm_1 = nn.GRU(512, 256, bidirectional=True)

        self.out = nn.Sequential(
            nn.Linear(512, vocab_size),
        )

    def forward(self, x):
        x = self.convlayer(x)
        x = x.permute(0, 3, 1, 2)
        x = x.view(x.size(0), x.size(1), -1)
        
        x = self.mapSeq(x)

        x, _ = self.lstm_0(x)
        x, _ = self.lstm_1(x)

        x = self.out(x)

        return x.permute(1, 0, 2)

4、创建模型


class OCR:
    def __init__(self):
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

        self.crnn = CRNN(VOCAB_SIZE).to(self.device)
        print('Model loaded to ', self.device)

        self.critertion = nn.CTCLoss(blank=0)

        self.char2idx, self.idx2char = self.char_idx()

    def char_idx(self):
        char2idx = {}
        idx2char = {}

        characters = CHARS.lower() + '-'
        for i, char in enumerate(characters):
            char2idx[char] = i + 1
            idx2char[i+1] = char
        return char2idx, idx2char
    
    def encode(self, labels):
        length_per_label = [len(label) for label in labels] 
        joined_label = ''.join(labels)

        joined_encoding = []
        for char in joined_label:
            joined_encoding.append(self.char2idx[char])

        return (torch.IntTensor(joined_encoding), torch.IntTensor(length_per_label)) 

    def decode(self, logits):
        tokens = logits.softmax(2).argmax(2).squeeze(1)

        tokens = ''.join([self.idx2char[token]
                          if token !=0 else '-'
                          for token in tokens.numpy()])
        tokens = tokens.split('-')

        text = [char 
                for batch_token in tokens
                for idx, char in enumerate(batch_token)
                if char != batch_token[idx-1] or len(batch_token) == 1]    
        
        text = ''.join(text)  

        return text

    def calculate_loss(self, logits, labels):
        encoded_labels, labels_len = self.encode(labels)

        logits_lens = torch.full(
            size=(logits.size(1),),
            fill_value = logits.size(0),
            dtype = torch.int32
        ).to(self.device)

        return self.critertion(
            logits.log_softmax(2), encoded_labels,
            logits_lens, labels_len
        )
    
    def train_step(self, optimizer, images, labels):
        logits = self.predict(images)

        optimizer.zero_grad()
        loss = self.calculate_loss(logits, labels)
        loss.backward()
        optimizer.step()

        return logits, loss
    
    def val_step(self, images, labels):
        logits = self.predict(images)
        loss = self.calculate_loss(logits, labels)

        return logits, loss
    
    def predict(self, img):
        return self.crnn(img.to(self.device))
    
    def train(self, num_epochs, optimizer, train_loader, val_loader, print_every = 2):
        train_losses, valid_losses = [],[]

        for epoch in range(num_epochs):
            tot_train_loss = 0
            self.crnn.train()

            for i, (images, labels) in enumerate(train_loader):
                logits, train_loss = self.train_step(optimizer, images, labels)
                tot_train_loss += train_loss.item()

            with torch.no_grad():
                tot_val_loss = 0
                self.crnn.eval()

                for i, (images, labels) in enumerate(val_loader):
                    logits, val_loss = self.val_step(images, labels)

                    tot_val_loss += val_loss.item()
                
                train_loss = tot_train_loss / len(train_loader.dataset)
                valid_loss = tot_val_loss / len(val_loader.dataset)

                train_losses.append(train_loss)
                valid_losses.append(valid_loss)
            if epoch % print_every == 0:
                print('Epoch [{:5d}/{:5d}] | train loss {:6.4f} | val loss {:6.4f}'.format(
                    epoch + 1, num_epochs, train_loss, val_loss
                ))                
        return train_losses, valid_losses

5、开启训练


TRAIN_DIR = '../data/train'
VAL_DIR = '../data/val'

# batch_size lr 参数值训练,得到的结果较合适
BATCH_SIZE = 8
N_WORKERS = 0
EPOCHS = 20

CHARS ='abcdefghijklmnopqrstuvwxyz0123456789'
VOCAB_SIZE = len(CHARS) + 1

lr = 0.02
# 权重衰减
weight_decay = 1e-5
# 下降幅度
momentum = 0.7


train_dataset = Dataset(TRAIN_DIR)
val_dataset = Dataset(VAL_DIR)

train_loader = DataLoader(
    train_dataset, batch_size = BATCH_SIZE,
    num_workers = N_WORKERS, shuffle=True
)

val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE,
    num_workers=N_WORKERS, shuffle=False
)

ocr = OCR()

optimizer = optim.SGD(
    ocr.crnn.parameters(), lr =lr, nesterov=True,
    weight_decay=weight_decay, momentum=momentum
) 

train_losses, val_losses = ocr.train(EPOCHS, optimizer, train_loader, val_loader, print_every=1)

6、随机采样,验证模型


sample_result = []

for i in range(10):
    idx = np.random.randint(len(val_dataset))
    img, label = val_dataset.__getitem__(idx)
    logits = ocr.predict(img.unsqueeze(0))
    pred_text = ocr.decode(logits.cpu())

    sample_result.append((img, label, pred_text))

fig = plt.figure(figsize=(17,5))    
for i in range(10):
    ax = fig.add_subplot(2, 5, i+1, xticks=[], yticks=[])

    img, label, pred_text = sample_result[i]
    title = f'Truth: {label} | Pred: {pred_text}'

    ax.imshow(img.permute(1,2, 0))
    ax.set_title(title)

plt.show()

7、输出统计图

plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Valid Loss')
plt.title('Loss stats')
plt.legend()
plt.show()

8、外部数据验证

trans = transforms.Compose([
    # 取决于与处理中是否也做相同处理
    transforms.Grayscale(),
    # 原始数据集图片尺寸
    transforms.Resize([50, 200]),
    transforms.ToTensor(),
])
    

image = Image.open('../data/123.png').convert('RGB')
tensor_img = trans(image)
result = ocr.predict(tensor_img.unsqueeze(0))
text = ocr.decode(result.cpu())
print(text)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1494296.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

GPT vs Gemini vs Claude 测试大比拼 到底谁是最强王者?

Anthropic发布的通用大语言模型Claude,在各项能力方面号称是全方面超越GPT,实测究竟如何呢?这次测试顺便把前段时间发布的Gemini拉上一起做对比!主要是以一些有趣幽默的脑筋急转弯为题目,来看看不同大模型对此的反馈。…

闰年导致的哪些 Bug

每次闰年对程序员们都是一个挑战,平时运行好好的系统,在 02-29 这一天,好像就会有各种毛病。 虽然,提前一天,领导们都会提前给下面打招呼。但是,不可避免的,今天公司因为闰年还是有一些小故障。…

SpringBoot中集成LiteFlow(轻量、快速、稳定可编排的组件式规则引擎)实现复杂业务解耦、动态编排、高可扩展

场景 在业务开发中,经常遇到一些串行或者并行的业务流程问题,而业务之间不必存在相关性。 使用策略和模板模式的结合可以解决这个问题,但是使用编码的方式会使得文件太多, 在业务的部分环节可以这样操作,在项目角度就无法一眼洞…

Keil软件无法烧录程序的解决方案

1.由于单片机程序有些情况下出错,导致烧录进去单片机运行异常,无法烧录程序,但是Keil软件可以识别到SW Device器件,点击烧录程序提示no target connected连接。 解决方案: (1).点击魔术棒->debug->Settings,选择…

网络编程day6

1.思维导图 2.数据库操作的增、删、改完成。 #include<myhead.h> //定义新增员工信息函数 int do_add(sqlite3 *ppDb) {int numb;char name;double salary;printf("请输入要插入的信息&#xff1a;");scanf("%d%s%d\n",&numb,name,&salary)…

7大必备应用推荐,为你的 Nextcloud 实例增添更多效率功能

适用于 Linux 的开源云存储软件有很多&#xff0c;ownCloud、Seafile 和 Pydio 只是其中的几个。 不过&#xff0c;如果您非常重视安全问题&#xff0c;并希望完全掌管您的数据&#xff0c;可以选择​Nextcloud并将其安装到您的服务器上。​ Nextcloud 是一个基于 PHP 的开源安…

Pytest中实现自动生成测试用例脚本代码!

前言 在Python的测试框架中&#xff0c;我们通常会针对某个系统进行测试用例的维护&#xff0c;在对庞大系统进行用例维护时&#xff0c;往往会发现很多测试用例是差不多的&#xff0c;甚至大多数代码是一样的。 故为了提高我们测试用例维护的效率&#xff0c;在本文中&#…

Java常用笔试题,面试java对未来的规划

最重要的话 2021年&#xff0c;真希望行业能春暖花开。 去年由于疫情的影响&#xff0c;无数行业都受到了影响&#xff0c;互联网寒冬下&#xff0c;许多程序员被裁&#xff0c;大环境格外困难。 我被公司裁掉后&#xff0c;便着急地开始找工作&#xff0c;一次次地碰壁&#…

爬虫学习笔记-requests爬取王者荣耀皮肤图片

1.导入所需的包 import requests from lxml import etree import os from time import sleep 2.定义请求头 headers {User-Agent:Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36} 3.发送请求 # hero…

数据结构->链表分类与oj(题),带你提升代码好感

✅作者简介&#xff1a;大家好&#xff0c;我是橘橙黄又青&#xff0c;一个想要与大家共同进步的男人&#x1f609;&#x1f609; &#x1f34e;个人主页&#xff1a;橘橙黄又青-CSDN博客 1.&#x1f34e;链表的分类 前面我们学过顺序表&#xff0c;顺序表问题&#xff1a; …

基于springboot实现的幼儿园管理系统

一、系统架构 前端&#xff1a;html | layui | jquery | css 后端&#xff1a;springboot | mybatis 环境&#xff1a;jdk1.8 | mysql | maven 二、代码及数据库 三、功能介绍 01. 登录页 02. 系统管理-用户管理 03. 系统管理-页面管理 04. 系统管理-角色管…

喜报|3DCAT成为国内首批适配Vision Pro内容开发者

近日&#xff0c;苹果在上海总部举办了国内首场 Apple Vision Pro 开发者实验室活动&#xff0c;3DCAT作为国内领先的实时渲染云平台参与了此次活动&#xff0c;成为国内首批适配 Vision Pro 的内容开发者之一。 Vision Pro是苹果于2023年6月发布的首个空间计算设备&#xff0…

【C++STL详解 —— string类】

【CSTL详解 —— string类】 CSTL详解 —— sring类一、string的定义方式二、string的插入三、string的拼接四、string的删除五、string的查找六、string的比较七、string的替换八、string的交换九、string的大小和容量十、string中元素的访问十一、string中运算符的使用十二、…

鸿蒙NEXT开发实战:【视频文件裁剪】

使用OpenHarmony系统提供的ffmpeg三方库的能力在系统中实现了音视频文件裁剪的功能&#xff0c;并通过NAPI提供给上层应用调用。 基础信息 视频文件裁剪 简介 在OpenHarmony系统整个框架中有很多子系统&#xff0c;其中多媒体子系统是OpenHarmony比较重要的一个子系统&#…

Java+SpringBoot+Vue+MySQL:农业管理新篇章

✍✍计算机毕业编程指导师 ⭐⭐个人介绍&#xff1a;自己非常喜欢研究技术问题&#xff01;专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目&#xff1a;有源码或者技术上的问题欢迎在评论区一起讨论交流&#xff01; ⚡⚡ Java、…

软件测试实战,Web项目网页bug定位详细分析总结(详全)

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 1、前置条件 1&a…

9、Linux-安装JDK、Tomcat和MySql

目录 一、安装JDK 1、传输JDK文件&#xff08;.tar.gz&#xff09; 2、解压 3、备份环境变量 4、配置环境变量 5、重新加载环境变量 6、验证&#xff08;java -version&#xff09; 二、安装Tomcat 1、传输文件&#xff0c;解压到/usr/local 2、进入Tomcat的bin目录 …

数据库-ER图教程

一.什么是E-R图 E-R图全称&#xff1a;“Entity-Relationship Approach”&#xff0c;是一种“实体-联系”方法。 E-R图的优点&#xff1a; 1.自然地描述现实世界。 2.图形结构简单。 3.设计者和用户易理解。 4.是数据库设计的中间步骤&#xff0c;易于向数据模型转换。 …

44、网络编程/数据库相关操作练习20240306

一、代码实现数据库的创建&#xff08;员工信息表&#xff09;&#xff0c;并存储员工信息&#xff08;工号、姓名、薪资&#xff09;&#xff0c;能实现增加人员信息、删除人员信息、修改人员薪资操作。 代码&#xff1a; #include<myhead.h>int do_update(sqlite3 *p…

作业1-32 P1059 [NOIP2006 普及组] 明明的随机数

题目 思路 根据题意&#xff0c;需要将读入的数据排序&#xff0c;去重。 参考代码 #include<bits/stdc.h> using namespace std; int n,a[5000],k;int main() {while(cin>>n){//读入数据for(int i0;i<n;i)cin>>a[i];sort(a,an);//排序int b[5000];in…