卷积神经网络(CNN)对验证码图片识别案例

news2024/9/25 11:13:39

数据集

数据集下载

链接:https://pan.baidu.com/s/1ypNNQkR1_ZK-_KO92x6Phw?pwd=6753 
提取码:6753

图片1  -->NZPP   一个样本对应四个目标值

NZPP  ---【13,25,15,15】

使用one-hot编码转换

第一个位置:[0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0]

第二个位置:[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1]

第三个位置:[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,]

第四个位置:[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,]

如何衡量损失?

       softmax交叉熵  ---适合一个样本对应一个目标值

       sigmoid交叉熵  ---每个类别独立且不互相排斥 适合多个目标值情况

准确率如何计算?

        核心:对比真实值和预测值最大值所在位置是否一致 

                 本案例要比对4个26列 4个所有一致才能说明识别正确

                  y_predict[None,4,26]

                 tf.argmax(y_predict,axis = 2)  将返回[[True],[True],[True],[True]]等结果

                  判断全是True才能返回True  ---->tf.reduce_all()

流程分析

6000张图片

1)读取图片数据

              filename --->标签值

2)解析CSV文件,将标签值处理成数字形式  例 NZPP  ---[13,25,15,15]

3)将filename和标签值联系起来

4)构建卷积神经网络 --->y_predict

5)构造损失函数  sigmoid交叉熵损失

6)优化损失

7)计算准确率

8)开启会话、开启线程

代码实现

import tensorflow as tf
import glob
import pandas as pd
import numpy as np
tf.compat.v1.disable_eager_execution()

# 1) 读取图片数据filename  --标签值
def read_pic():
    """
    读取图片数据
    :return:
    """
    #1、构建文件名队列
    file_names = glob.glob("./tmp/GenPics/*.jpg")
    # print("file_names:\n",file_names)
    file_queue = tf.compat.v1.train.string_input_producer(file_names)
    # 2、读取与解码
    reader = tf.compat.v1.WholeFileReader()
    filename,image = reader.read(file_queue)
    # 解码
    decoded = tf.image.decode_jpeg(image)
    # 更新图像,将图片形状确定下来
    decoded.set_shape([20,80,3])
    # 修改图像类型
    image_cast = tf.cast(decoded,tf.float32)
    # 批处理
    filename_batch,image_batch = tf.compat.v1.train.batch([filename,image_cast],batch_size=100,num_threads=1,capacity=200)
    return filename_batch,image_batch
# 2) 解析CSV使得成为 将标签值NZPP->[13, 25, 15, 15]
def parse_csv():
    """
    解析CSV文件,建立文件名和标签值对应表格
    :return:
    """
    csv_data = pd.read_csv("./tmp/GenPics/labels.csv", names=["file_num", "chars"], index_col="file_num")

    labels = []
    for label in csv_data["chars"]:
        tmp = []
        for letter in label:
            tmp.append(ord(letter) - ord("A"))
        labels.append(tmp)
    csv_data["labels"] = labels
    print(csv_data)
    return csv_data
# 3)将filename和标签值联系起来
def filename2label(filenames, csv_data):
    """
    将filename和标签值联系起来
    :param filenames:
    :param csv_data:
    :return:
    """
    labels = []
    # 将b'文件名中的数字提取出来并索引相应的标签值
    for filename in filenames:
        digit_str = "".join(list(filter(str.isdigit, str(filename))))
        label = csv_data.loc[int(digit_str), "labels"]
        labels.append(label)

     #print("labels:\n", labels)

    return np.array(labels)
# 4)构建卷积神经网络
def create_weights(shape):

    return tf.Variable(initial_value=tf.compat.v1.random_normal(shape=shape, stddev=0.01))

def create_model(x):
    """
    构建卷积神经网络
    :param x:[None, 20, 80, 3]
    :return:
    """
    # 1)第一个卷积大层
    with tf.compat.v1.variable_scope("conv1"):

        # 卷积层
        # 定义filter和偏置
        conv1_weights = create_weights(shape=[5, 5, 3, 32])
        conv1_bias = create_weights(shape=[32])
        conv1_x = tf.nn.conv2d(input=x, filters=conv1_weights, strides=[1, 1, 1, 1], padding="SAME") + conv1_bias

        # 激活层
        relu1_x = tf.nn.relu(conv1_x)

        # 池化层
        pool1_x = tf.nn.max_pool(input=relu1_x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")

    # 2)第二个卷积大层
    with tf.compat.v1.variable_scope("conv2"):
        # [None, 20, 80, 3] --> [None, 10, 40, 32]
        # 卷积层
        # 定义filter和偏置
        conv2_weights = create_weights(shape=[5, 5, 32, 64])
        conv2_bias = create_weights(shape=[64])
        conv2_x = tf.nn.conv2d(input=pool1_x, filters=conv2_weights, strides=[1, 1, 1, 1], padding="SAME") + conv2_bias

        # 激活层
        relu2_x = tf.nn.relu(conv2_x)

        # 池化层
        pool2_x = tf.nn.max_pool(input=relu2_x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")

    # 3)全连接层
    with tf.compat.v1.variable_scope("full_connection"):
        # [None, 10, 40, 32] -> [None, 5, 20, 64]
        # [None, 5, 20, 64] -> [None, 5 * 20 * 64]
        # [None, 5 * 20 * 64] * [5 * 20 * 64, 4 * 26] = [None, 4 * 26]
        x_fc = tf.reshape(pool2_x, shape=[-1, 5 * 20 * 64])
        weights_fc = create_weights(shape=[5 * 20 * 64, 4 * 26])
        bias_fc = create_weights(shape=[104])
        y_predict = tf.matmul(x_fc, weights_fc) + bias_fc

    return y_predict



if __name__ == "__main__":
    filename,image = read_pic()
    csv_data = parse_csv()
    # 1、准备数据
    x = tf.compat.v1.placeholder(tf.float32, shape=[None, 20, 80, 3])
    y_true = tf.compat.v1.placeholder(tf.float32, shape=[None, 4 * 26])

    # 2、构建模型
    y_predict = create_model(x)

    # 3、构造损失函数
    loss_list = tf.nn.sigmoid_cross_entropy_with_logits(labels=y_true, logits=y_predict)
    loss = tf.reduce_mean(loss_list)

    # 4、优化损失
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

    # 5、计算准确率
    equal_list = tf.reduce_all(
        tf.equal(tf.argmax(tf.reshape(y_predict, shape=[-1, 4, 26]), axis=2),
                 tf.argmax(tf.reshape(y_true, shape=[-1, 4, 26]), axis=2)), axis=1)
    accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))

    # 初始化变量
    init = tf.compat.v1.global_variables_initializer()

    # 开启会话
    with tf.compat.v1.Session() as sess:
        # 初始化变量
        sess.run(init)
        coord = tf.compat.v1.train.Coordinator()
        threads = tf.compat.v1.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(500):
            filename_value, image_value = sess.run([filename, image])
            # print("filename_value:\n",filename_value)
            # print("image_value:\n",image_value)
            labels = filename2label(filename_value, csv_data)
            # 将标签值转换成one-hot
            labels_value = tf.reshape(tf.one_hot(labels, depth=26), [-1, 4 * 26]).eval()

            _, error, accuracy_value = sess.run([optimizer, loss, accuracy],
                                                feed_dict={x: image_value, y_true: labels_value})

            print("第%d次训练后损失为%f,准确率为%f" % (i + 1, error, accuracy_value))

        # 回收线程
        coord.request_stop()
        coord.join(threads)

数据读取解码

解析CSV文件,文件名与标签值对应起来,

训练结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1615787.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

最强开源大模型Meta LIama3抢先在线体验!

4月19日Facebook母公司Meta重磅推出了其迄今最强大的开源人工智能(AI)模型——Llama 3。模型分为两种规模:8B 和 70B 参数,每种规模都提供预训练基础版和指令调优版。最强开源大语言模型Meta LIama3可以在线体验啦! G…

心理学|变态心理学健康心理学——躯体疾病患者的一般心理特点

一、对客观世界和自身价值的态度发生改变 患者除了内部器官有器质或功能障碍外,他们的自我感觉和整个精神状态也会发生变化。使人改变对周围事物的感受和态度,也可以改变患者对自身存在价值的态度。这种主观态度的改变,可以使患者把自己置于人…

wps免登录绕路

打开注册表 regedit 新建字符串值--> false

第25天:安全开发-PHP应用文件管理包含写入删除下载上传遍历安全

第二十五天 一、PHP文件管理-下载&删除功能实现 1.文件上传: 无过滤机制黑名单过滤机制白名单过滤机制文件类型过滤机制 2.文件删除: unlink() 文件删除函数调用命令删除:system shell_exec exec等 3. 文件下载: 修改HT…

我独自升级崛起怎么下载 一文分享我独自升级崛起游戏下载教程

我独自升级崛起怎么下载 一文分享我独自升级崛起游戏下载教程 我独自升级:崛起是一款由韩国漫画改编而成的热门多人网络在线联机游戏,这款游戏是一款的角色扮演类型游戏,游戏有着独一无二的剧情模式。小伙伴们在游戏中可以体验到独特的成长系…

URL解析

目录 URIURLURL语法相对URLURL中的转义 现在与未来PURL 在 URL出现之前,人们如果想访问网络中的资源,就需要使用不同的 应用程序,如共享文件需要使用 FTP程序,想要发送邮件必须使用 邮件程序,想要看新闻那只能使用…

VSCode 配置 C/C++ 环境

1 安装 VSCode 直接去官网(https://code.visualstudio.com/)下载并安装即可。 2 配置C/C编译环境 方案一 如果是在Windows,需要安装 MingW,可以去官网(https://sourceforge.net/projects/mingw-w64/)下载安装包。 注意安装路径不要出现中文。 打开 w…

JAVA面向对象(下)(四、内部类、枚举、包装类)

一、内部类(续) 1.1 内部类的分类 按照声明的位置划分,可以把内部类分为两大类: 成员内部类:方法外 局部内部类:方法内 public class 外部类名{【修饰符】 class 成员内部类名{ //方法外}【修饰符】 返…

力扣刷题 70.爬楼梯

题干 假设你正在爬楼梯。需要 n 阶你才能到达楼顶。 每次你可以爬 1 或 2 个台阶。你有多少种不同的方法可以爬到楼顶呢? 示例 1: 输入:n 2 输出:2 解释:有两种方法可以爬到楼顶。 1. 1 阶 1 阶 2. 2 阶 示例 2&…

HarmonyOS开发实例:【图片编辑应用】

介绍 本篇Codelab通过动态设置元素样式的方式,实现几种常见的图片操作,包括裁剪、旋转、缩放和镜像。效果如图所示: 相关概念 [image组件]:图片组件,用来渲染展示图片。[div组件]:基础容器组件&#xff0…

PLC_博图系列☞N=:在信号下降沿置位操作数

、 PLC_博图系列☞N:在信号下降沿置位操作数 文章目录 PLC_博图系列☞N:在信号下降沿置位操作数背景介绍N: 在信号下降沿置位操作数说明参数示例 关键字: PLC、 西门子、 博图、 Siemens 、 N 背景介绍 这是一篇关于PLC编程的…

Python网络数据抓取(3):Requests

引言 在这一部分,我们将探讨Python的requests库,并且利用这个库来进行网页数据抓取。那么,我们为何需要这个库,以及怎样利用它呢? requests库是广受大家欢迎的一个库,它是下载次数最多的。这个库使我们能够…

C语言学习/复习27----sizeof/strlen/数组/指针

一、数组笔试题目解析 1.一维数组 1.sizeof()操作符与int数组 注意事项1:sizeof()依据类型推断大小 注意事项2:注意区分是( )内是地址还是普通元素类型 注意事项3:()内是单独的数组名时计算整个数组的大小,…

海外服务器被恶意攻击怎么办

如果您的海外服务器遭受了恶意攻击,以下是一些应对措施和步骤,立即隔离服务器。如果您察觉到服务器受到恶意攻击,立即隔离服务器,将其与网络隔离,以防止攻机进一步扩散。通知服务器提供商,以便他们能够提供…

有了可视化工具,你定制设计得瑟瑟发抖了吧,其实你想多了。

目前市面上有N多可视化的工具,可以做成可视化大屏,甚至有很多B端系统也附带可视化页面,据此就有很多人开始怀疑我们这些做定制开发的,还有啥生存空间。 其实你真的多虑了,存在即合理,我们承认可视化工具的标…

小白必备:Python必须掌握的十大模块,建议收藏!

前言 Python 是一种高级、解释型和通用动态编程语言,侧重于代码的可读性。 它在许多组织中使用,因为它支持多种编程范例。 它还执行自动内存管理。 它是世界上最受欢迎的编程语言之一。 这是有很多原因的: 这很容易学习。它超级多才多艺。…

05集合-CollectionListSet

Collection体系的特点、使用场景总结 如果希望元素可以重复,又有索引,索引查询要快? 用ArrayList集合, 基于数组的。(用的最多) 如果希望元素可以重复,又有索引,增删首尾操作快? 用LinkedList集合, 基于链表的。 如果希望增…

【电机控制】滑模观测器PMSM无感控制波形图

【电机控制】滑模观测器PMSM无感控制波形图 文章目录 前言一、FOC控制1.三相电流2.Clark变换静止坐标系iαiβ3.park变换旋转坐标系idiq4.电流环PI控制输出UdUq5.UdUq 反park变换UαUβ 二、反电动势观测器BEMF1.静止坐标系iαiβ提取反电动势EaEb2.反电动势EaEb提取位置信息、…

【国信华源参加全国地质灾害防治新技术新方法新设备交流会】

4月17-18日,以“提升地质灾害防治能力 服务保障高质量发展”为主题,由中国地质灾害防治与生态修复协会主办、云南地质工程第二勘察院有限公司承办的“全国地质灾害防治新技术新方法新设备成果交流会”在云南昆明圆满召开。会议特邀中国工程院院士等知名…

实现游戏地图读取与射击运行

射击代码来源自2D 横向对抗射击游戏(by STF) - CodeBus 地图读取改装自 瓦片地图编辑器 解决边界检测,实现使用不同像素窗口也能移动不闪退-CSDN博客 // 程序:2D RPG 地图编辑器改游戏读取器 // 作者:民用级脑的研发…