pytorch二维码识别

news2024/12/28 2:45:44

二维码图片的生成

利用captcha可以生成二维码图片

# -*- coding: UTF-8 -*-
from captcha.image import ImageCaptcha  # pip install captcha
from PIL import Image
import random
import time
import os
# 验证码中的字符
# string.digits + string.ascii_uppercase
NUMBER = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
ALPHABET = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']

ALL_CHAR_SET = NUMBER + ALPHABET  #验证码所有的字符和数字
ALL_CHAR_SET_LEN = len(ALL_CHAR_SET)
MAX_CAPTCHA = 12 #每个验证码字符个数

# 图像大小
IMAGE_HEIGHT = 60
IMAGE_WIDTH = 160

train_path = 'dataset' + os.path.sep + 'train'
test_path = 'dataset' + os.path.sep + 'test'
predict_path = 'dataset' + os.path.sep + 'predict'

def random_captcha():
    #生成随机字符串
    captcha_text = []
    for i in range(MAX_CAPTCHA):
        c = random.choice(ALL_CHAR_SET)
        captcha_text.append(c)
    return ''.join(captcha_text)

# 生成字符对应的验证码
def gen_captcha_text_and_image():
    image = ImageCaptcha()
    captcha_text = random_captcha()
    #通过随机字符串生成二维码图片
    captcha_image = Image.open(image.generate(captcha_text))
    return captcha_text, captcha_image

if __name__ == '__main__':
    count = 300  #生成二维码的个数
    path = train_path    #通过改变此处目录,以生成 训练、测试和预测用的验证码集
    if not os.path.exists(path):
        os.makedirs(path)
    for i in range(count):
        now = str(int(time.time()))
        text, image = gen_captcha_text_and_image()
        filename = text+'_'+str(i)+'.png'
        #保存二维码图片
        image.save(path  + os.path.sep +  filename)
        print('saved %d : %s' % (i+1,filename))

dataset

# -*- coding: UTF-8 -*-
import os
import random

from torch.utils.data import DataLoader,Dataset
import torchvision.transforms as transforms
from PIL import Image
import one_hot_encoding as ohe
import captcha_setting

class mydataset(Dataset):

    def __init__(self, folder, transform=None):
        self.train_image_file_paths = [os.path.join(folder, image_file) for image_file in os.listdir(folder)]
        self.transform = transform

    def __len__(self):
        return len(self.train_image_file_paths)

    def __getitem__(self, idx):
        image_root = self.train_image_file_paths[idx]
        image_name = image_root.split(os.path.sep)[-1]
        image = Image.open(image_root)
        if self.transform is not None:
            image = self.transform(image)
        label = ohe.encode(image_name.split('_')[0]) # 为了方便,在生成图片的时候,图片文件的命名格式 "4个数字或者数字_时间戳.PNG", 4个字母或者即是图片的验证码的值,字母大写,同时对该值做 one-hot 处理
        return image, label

transform = transforms.Compose([
    transforms.ColorJitter(),
    # transforms.Grayscale(),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def get_train_data_loader():

    dataset = mydataset(captcha_setting.TRAIN_DATASET_PATH, transform=transform)
    return DataLoader(dataset, batch_size=64, shuffle=True)

def get_test_data_loader():
    dataset = mydataset(captcha_setting.TEST_DATASET_PATH, transform=transform)
    return DataLoader(dataset, batch_size=1, shuffle=True)

def get_predict_data_loader():
    dataset = mydataset(captcha_setting.PREDICT_DATASET_PATH, transform=transform)
    return DataLoader(dataset, batch_size=1, shuffle=True)

if __name__=="__main__":
    from matplotlib import pyplot as plt
    dataset=mydataset('./dataset/train',transform=transform)
    indexes=random.sample(range(len(dataset)-1),16)
    image,label=dataset[0]

    for i,index in enumerate(indexes):
        image,label=dataset[index]
        image=transforms.ToPILImage()(image)
        plt.subplot(4,4,i+1)
        plt.title(ohe.decode(label))
        plt.xticks([])
        plt.yticks([])
        plt.imshow(image)
    plt.show()

 one-hot编码

# -*- coding: UTF-8 -*-
import numpy as np
import captcha_setting

def encode(text):
    vector = np.zeros(captcha_setting.ALL_CHAR_SET_LEN * captcha_setting.MAX_CAPTCHA, dtype=float)
    #每个字符都需要进行编码.
    #每个字符都需要在字典中查询得到,所有变量的维度是  *max_captcha
    def char2pos(c):
        if c =='_':
            k = 62
            return k
        k = ord(c)-48
        #将acii码转换为字符
        # hh=chr(k+48)
        if k > 9:
            k = ord(c) - 65 + 10
            if k > 35:
                k = ord(c) - 97 + 26 + 10
                if k > 61:
                    raise ValueError('error')
        return k
    for i, c in enumerate(text):
        idx = i * captcha_setting.ALL_CHAR_SET_LEN + char2pos(c)
        vector[idx] = 1.0
    return vector

def decode(vec):
    char_pos = vec.nonzero()[0]
    text=[]
    for i, c in enumerate(char_pos):
        char_at_pos = i #c/63
        char_idx = c % captcha_setting.ALL_CHAR_SET_LEN
        if char_idx < 10:
            char_code = char_idx + ord('0')
        elif char_idx <36:
            char_code = char_idx - 10 + ord('A')
        elif char_idx < 62:
            char_code = char_idx - 36 + ord('a')
        elif char_idx == 62:
            char_code = ord('_')
        else:
            raise ValueError('error')
        text.append(chr(char_code))
    return "".join(text)

if __name__ == '__main__':
    e = encode("9L7H")
    print(e)
    print(decode(e))

[0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
9L7H

模型

# -*- coding: UTF-8 -*-
import torch.nn as nn
import captcha_setting

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.Dropout(0.5),  # drop 50% of the neuron
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.Dropout(0.5),  # drop 50% of the neuron
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.layer3 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.Dropout(0.5),  # drop 50% of the neuron
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.fc = nn.Sequential(
            nn.Linear((captcha_setting.IMAGE_WIDTH//8)*(captcha_setting.IMAGE_HEIGHT//8)*64, 1024),
            nn.Dropout(0.5),  # drop 50% of the neuron
            nn.ReLU())
        self.rfc = nn.Sequential(
            nn.Linear(1024, captcha_setting.MAX_CAPTCHA*captcha_setting.ALL_CHAR_SET_LEN),
        )

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        out = self.rfc(out)
        return out
if __name__=='__main__':
    import torch
    model=CNN()
    input=torch.randn(1,1,60,160)
    gt_output=torch.randn(1,144)
    output=model(input)
    print('输出的向量大小',output.shape)

    criterion = nn.MultiLabelSoftMarginLoss()
    loss=criterion(gt_output,output)
    print('损失的大小',loss.data)

 输出的向量大小 torch.Size([1, 144])
损失的大小 tensor(0.7603)

模型的训练

# -*- coding: UTF-8 -*-
import torch
import torch.nn as nn
from torch.autograd import Variable
import my_dataset
from captcha_cnn_model import CNN

# Hyper Parameters
num_epochs = 30
batch_size = 100
learning_rate = 0.001

def main():
    cnn = CNN()
    cnn.train()
    print('init net')
    criterion = nn.MultiLabelSoftMarginLoss()
    optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)

    # Train the Model
    train_dataloader = my_dataset.get_train_data_loader()
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_dataloader):
            images = Variable(images)
            labels = Variable(labels.float())
            predict_labels = cnn(images)
            loss = criterion(predict_labels, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (i+1) % 10 == 0:
                print("epoch:", epoch, "step:", i, "loss:", loss.item())
            if (i+1) % 100 == 0:
                torch.save(cnn.state_dict(), "./model.pkl")   #current is model.pkl
                print("save model")
        print("epoch:", epoch, "step:", i, "loss:", loss.item())
    torch.save(cnn.state_dict(), "./model.pth")
    print("save last model")

if __name__ == '__main__':
    main()


 init net
epoch: 0 step: 4 loss: 0.38898029923439026
epoch: 1 step: 4 loss: 0.19714970886707306
epoch: 2 step: 4 loss: 0.15116369724273682
epoch: 3 step: 4 loss: 0.1429138034582138
epoch: 4 step: 4 loss: 0.1360236257314682
epoch: 5 step: 4 loss: 0.12835916876792908
epoch: 6 step: 4 loss: 0.1267365664243698
epoch: 7 step: 4 loss: 0.12457828223705292
epoch: 8 step: 4 loss: 0.12483084201812744
epoch: 9 step: 4 loss: 0.11893215030431747
epoch: 10 step: 4 loss: 0.11814623326063156
epoch: 11 step: 4 loss: 0.11591014266014099
epoch: 12 step: 4 loss: 0.11125991493463516
epoch: 13 step: 4 loss: 0.10649068653583527
epoch: 14 step: 4 loss: 0.10284445434808731
epoch: 15 step: 4 loss: 0.10144951194524765
epoch: 16 step: 4 loss: 0.0985511839389801
epoch: 17 step: 4 loss: 0.08964875340461731
epoch: 18 step: 4 loss: 0.08870525658130646
epoch: 19 step: 4 loss: 0.0839766412973404
epoch: 20 step: 4 loss: 0.0823589637875557
epoch: 21 step: 4 loss: 0.07506724447011948
epoch: 22 step: 4 loss: 0.06370603293180466
epoch: 23 step: 4 loss: 0.06234220042824745
epoch: 24 step: 4 loss: 0.06265763193368912
epoch: 25 step: 4 loss: 0.05445406585931778
epoch: 26 step: 4 loss: 0.05590423569083214
epoch: 27 step: 4 loss: 0.0482553206384182
epoch: 28 step: 4 loss: 0.04553262144327164
epoch: 29 step: 4 loss: 0.03754893317818642
save last model

模型的测试

# -*- coding: UTF-8 -*-
import numpy as np
import torch
from torch.autograd import Variable
import captcha_setting
import my_dataset
from captcha_cnn_model import CNN
import one_hot_encoding

def main():
    cnn = CNN()
    cnn.eval()
    cnn.load_state_dict(torch.load('model.pth'))
    print("load cnn net.")

    test_dataloader = my_dataset.get_test_data_loader()

    correct = 0
    total = 0
    for i, (images, labels) in enumerate(test_dataloader):
        image = images
        vimage = Variable(image)
        predict_label = cnn(vimage)

        c0 = captcha_setting.ALL_CHAR_SET[np.argmax(predict_label[0, 0:captcha_setting.ALL_CHAR_SET_LEN].data.numpy())]
        c1 = captcha_setting.ALL_CHAR_SET[np.argmax(predict_label[0, captcha_setting.ALL_CHAR_SET_LEN:2 * captcha_setting.ALL_CHAR_SET_LEN].data.numpy())]
        c2 = captcha_setting.ALL_CHAR_SET[np.argmax(predict_label[0, 2 * captcha_setting.ALL_CHAR_SET_LEN:3 * captcha_setting.ALL_CHAR_SET_LEN].data.numpy())]
        c3 = captcha_setting.ALL_CHAR_SET[np.argmax(predict_label[0, 3 * captcha_setting.ALL_CHAR_SET_LEN:4 * captcha_setting.ALL_CHAR_SET_LEN].data.numpy())]
        predict_label = '%s%s%s%s' % (c0, c1, c2, c3)
        true_label = one_hot_encoding.decode(labels.numpy()[0,:])
        total += labels.size(0)
        if(predict_label == true_label):
            correct += 1
        if(total%200==0):
            print('Test Accuracy of the model on the %d test images: %f %%' % (total, 100 * correct / total))
    print('Test Accuracy of the model on the %d test images: %f %%' % (total, 100 * correct / total))

if __name__ == '__main__':
    main()


 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/157556.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

整理了一周近万字讲解linux基础开发工具vim,gdb,gcc,yum等的使用

文章目录 前言一、yum的使用二、vim的使用三 . gcc/g的使用四 . gdb的使用总结前言 想用linux开发一些软件等必须要会的几种开发工具是必不可少的&#xff0c;在yum vim gcc gdb中指令繁杂的是vim和gdb这两个工具&#xff0c;至于yum和gcc的指令就比较简单了。 一、yum的使用…

【SpringMVC】拦截器

目录 一、概念 二、自定义拦截器的三个实现方法 三、自定义拦截器执行流程 四、使用 五、拦截器和过滤器 相关文章&#xff08;可以关注我的SpringMVC专栏&#xff09; SpingMVC专栏SpingMVC专栏一、概念在学习拦截器之前&#xff0c;我们得先了解一下它是个什么❓ SpringMVC…

SAP ABAP调用标准事务码

这里介绍常见的几种在开发中常用到的事务代码跳转功能。 1、最常用到的是“SET PARAMETER”语句赋值&#xff0c;然后再使用“CALL TRANSACTION”语句跳转屏幕。 比如采购订单、销售订单、交货单、采购发票、销售发票等事务代码&#xff0c;均可以利用给参数赋值来直接跳转&am…

零售及仓储数字化整理解决方案

价格管控 皮克价格管控方案可实现门店与企业信息管理平台的数据同步&#xff0c;强化零售企业对终端的控制。同时为企业销售决策提供支持&#xff0c;优化门店经营活动的效率和频率。陈列管理 皮克陈列管理方案通过电子价签产品使商品陈列得到固化。 同时实现了陈列可视化&am…

ArcGIS水文分析提取河网及流域

在进行某些研究或者一些论文插图显示的时候&#xff0c;有时我们会碰到在部分资料中找不到一些小的河流或者流域的数据的情况&#xff0c;这里讲述通过DEM数据生成河网及流域。 一、数据来源 四川省高程数据来源于中国科学院资源环境科学与数据中心&#xff08;中国科学院资源环…

Vue3学习之深度剖析CSS Modules和Scope

Css Modules 是通过对标签类名进行加装成一个独一无二的类名&#xff0c;比如.class 转换成.class_abc_123,类似于symbol&#xff0c;独一无二的键名 Css Scope 是通过为元素增加一个自定义属性&#xff0c;这个属性加上独一无二的编号&#xff0c;而实现作用域隔离。 原理 …

爬虫必备抓包工具——Fiddler【认识使用】

目录&#xff1a;1.fiddler &#xff08;抓包工具&#xff09;1.1 引入&#xff1a;HTTP/https代理&#xff08;正向代理&#xff09;1.2 拓展&#xff1a;反向代理&#xff1a;1.2 初识Fiddler①什么是抓包&#xff1f;抓包有什么用&#xff1f;②浅谈fiddler&#xff1a;③fi…

Unity_Skybox自定义插件可实现日夜更替Polyverse Skies | Low Poly

又又一个天空盒,不过这个做的还是比较完善的。。。不会出现买家秀和买家秀差别大问题 此Skybox插件特色提供: 可扩展,自定义很多的Skybox Shader预制体几个,虽然都是夜晚样式(缺白天)若干预设值</

对NIO和BIO的进一步理解

疑问 在之前的学习中&#xff0c;只提到BIO是阻塞IO&#xff0c;在建立连接和读写事件时会阻塞线程。NIO是非阻塞IO&#xff0c;基于事件注册&#xff0c;通过Selector进行切换Channel&#xff0c;不会阻塞线程。对于这种解释&#xff0c;还是带有一些疑问的。Selector进行Cha…

#define 实现快捷模板类实例在eigen::Maxtrix中的应用

欢迎关注更多精彩 关注我&#xff0c;学习常用算法与数据结构&#xff0c;一题多解&#xff0c;降维打击。 背景 在eigen库中&#xff0c;矩阵类原来的用法是 Matrix<Type, row, col>。 为了方便用户&#xff0c;库中还提供了用户常用的快捷类型&#xff0c;比如Matrix…

Java-String的API

一、length()package 做题; import java.lang.reflect.Array; import java.security.PublicKey; import java.util.Arrays; import java.util.Scanner;import javax.naming.StringRefAddr;public class Main {public static void main(String[] args) {Scanner sc new Scanne…

ZeroTierr的moon云服务器搭建和使用

搭建moon 问题是ZeroTier One本身的服务器都在国外访问速度很慢。可以通过搭建国内Moon服务加速解决连接慢的问题。 1、 进入云服务器在线安装zerotier curl -s https://install.zerotier.com/ | sudo bash查看安装zerotier版本 sudo zerotier-cli status安装完成后生成moon…

从0到1完成一个Vue后台管理项目(二十二、列表拖拽排序SortableJS)

往期 从0到1完成一个Vue后台管理项目&#xff08;一、创建项目&#xff09; 从0到1完成一个Vue后台管理项目&#xff08;二、使用element-ui&#xff09; 从0到1完成一个Vue后台管理项目&#xff08;三、使用SCSS/LESS&#xff0c;安装图标库&#xff09; 从0到1完成一个Vu…

Python WebGL 3D应用开发快速入门

在本文中&#xff0c;我们将学习如何在Python中使用three.js库&#xff0c;而无需编写任何一行 JavaScript。我们将使用PyWeb3D&#xff0c;这是一个额外的层&#xff0c;旨在与Brython的three.js轻松交互。 1、什么是PyWeb3D&#xff1f; 简单地说&#xff0c;PyWeb3D是一个…

现代JavaScript,你应该使用的10件事

javascripttip&#xff08;3 部分系列&#xff09;1现代 JavaScript&#xff0c;你应该使用的 10 件事&#xff0c;从今天开始2了解如何在 JavaScript 中使用循环3如何在 JavaScript 中学习足够多的 RegEx 才能变得危险您可能对 JavaScript 完全陌生&#xff0c;也可能多年来只…

文件学习笔记

删除线格式 ## 文件描述符 1.文件文件内容文件属性。 2.文件操作文件内容的操作文件属性的操作。 3.所谓的“打开”文件&#xff0c;是指将文件的属性或内容加载到内存中—这是由冯诺依曼决定。 4.所以文件不全打开&#xff0c;不打开的文件放在磁盘存储。 5.内存文件&#xff…

在线教育-谷粒学院学习笔记(三)

文章目录1 搭建前端项目环境2 前端页面框架介绍3 讲师管理前端开发4 后台系统登录功能改造到本地5 前端框架开发过程6 讲师列表前端实现7 讲师分页前端实现8 讲师条件查询前端实现9 讲师删除功能前端实现10 讲师添加前端实现11 讲师修改前端实现12 前端路由切换问题解决1 搭建前…

Linux工具学习之【gdb】

✨个人主页&#xff1a; Yohifo &#x1f389;所属专栏&#xff1a; Linux学习之旅 &#x1f38a;每篇一句&#xff1a; 图片来源 &#x1f383;操作环境&#xff1a; CentOS 7.6 阿里云远程服务器 Whatever is worth doing is worth doing well. 任何值得去做的事情&#xff0…

【HTML】一款可交互的响应式登陆注册表单,你确定不来看看嘛(附源码)

&#x1f482;作者简介&#xff1a; THUNDER王&#xff0c;一名热爱财税和SAP ABAP编程以及热爱分享的博主。目前于江西师范大学会计学专业大二本科在读&#xff0c;同时任汉硕云&#xff08;广东&#xff09;科技有限公司ABAP开发顾问。在学习工作中&#xff0c;我通常使用偏后…

2022年五一杯数学建模C题火灾报警系统问题求解全过程论文及程序

2022年五一杯数学建模 C题 火灾报警系统问题 原题再现&#xff1a; 二十世纪90年代以来&#xff0c;我国火灾探测报警产业化发展非常迅猛&#xff0c;从事火灾探测报警产品生产的企业已超过100家&#xff0c;年产值达几十亿元&#xff0c;已经成为我国高新技术产业的一个组成…