TensorFlow图像多标签分类实例

news2024/9/22 11:28:53

接下来,我们将从零开始讲解一个基于TensorFlow的图像多标签分类实例,这里以图片验证码为例进行讲解。

在我们访问某个网站的时候,经常会遇到图片验证码。图片验证码的主要目的是区分爬虫程序和人类,并将爬虫程序阻挡在外。

下面的程序就是模拟人类识别验证码,从而使网站无法区分是爬虫程序还是人类在网站登录。

10.4.1  使用TFRecord生成训练数据

以图10.5所示的图片验证码为例,将这幅验证码图片标记为label=[3,8,8,7]。我们知道分类网络一般一次只能识别出一个目标,那么如何识别这个多标签的序列数据呢?

通过下面的TFRecord结构可以构建多标签训练数据集,从而实现多标签数据识别。

图10.5  图片验证码

以下为构造TFRecord多标签训练数据集的代码:

import tensorflow as tf
# 定义对整型特征的处理    
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
# 定义对字节特征的处理
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
# 定义对浮点型特征的处理
def _floats_feature(value):
    return tf.train_Feature(float_list=tf.train.floatList(value=[value]))
# 对数据进行转换    
def convert_to_record(name, image, label, map):
    filename = os.path.join(params.TRAINING_RECORDS_DATA_DIR,
        name + '.' + params.DATA_EXT)
    writer = tf.python_io.TFRecordWriter(filename)
    image_raw = image.tostring()
    map_raw = map.tostring()
    label_raw = label.tostring()
    example = tf.train.Example(feature=tf.train.Feature(feature={
        'image_raw': _bytes_feature(image_raw),
        'map_raw': _bytes_feature(map_raw),
        '1abel_raw': _bytes_feature(label_raw)
    }))
    writer.write(example.SerializeToString())
    writer.close()

通过上面的代码,我们构建了一条支持多标签的TFRecord记录,多幅验证码图片可以构建一个验证码的多标签数据集,用于后续的多标签分类训练。

10.4.2  构建多标签分类网络

通过前一步操作,我们得到了用于多标签分类的验证码数据集,现在需要构建多标签分类网络。

我们选择VGG网络作为特征提取网络骨架。通常越复杂的网络,对噪声的鲁棒性就越强。验证码中的噪声主要来自形变、粘连以及人工添加,VGG网络对这些噪声具有好的鲁棒性,代码如下:

import tensorflow as tf
tf.enable_eager_execution ()
def model_vgg(x, training = False):
# 第一组第一个卷积使用64个卷积核,核大小为3
conv1_1 = tf.layers.conv2d(inputs=x, filters=64,name="conv1_1",
    kernel_size=3, activation=tf.nn.relu, padding="same")
# 第一组第二个卷积使用64个卷积核,核大小为3
convl_2 = tf.layers.conv2d(inputs=conv1_1,filters=64, name="conv1_2",
    kernel_size=3, activation=tf.nn.relu,padding="same")
# 第一个pool操作核大小为2,步长为2
pooll = tf.layers.max_pooling2d(inputs=conv1_2, pool_size=[2, 2],
    strides=2, name= 'pool1')
# 第二组第一个卷积使用128个卷积核,核大小为3
conv2_1 = tf.layers.conv2d(inputs=pool1, filters=128, name="conv2_1",
    kernel_size=3, activation=tf.nn.relu, padding="same")
# 第二组第二个卷积使用64个卷积核,核大小为3
conv2_2 = tf.layers.conv2d(inputs=conv2_1, filters=128,name="conv2_2",
    kernel_size=3, activation=tf.nn.relu, padding="same")
# 第二个pool操作核大小为2,步长为2
pool2 = tf.layers.max_pooling2d(inputs=conv2_2, pool_size=[2, 2],
    strides=2, name="pool1")
# 第三组第一个卷积使用128个卷积核,核大小为3
conv3_1 = tf.layers.conv2d(inputs=pool2, filters=128, name="conv3_1", 
    kernel_size=3, activation=tf.nn.relu, padding="same")
# 第三组第二个卷积使用128个卷积核,核大小为3
conv3_2 = tf.layers.conv2d(inputs=conv3_1, filters=128, name="conv3_2", 
    kernel_size=3, activation=tf.nn.relu, padding="same")
# 第三组第三个卷积使用128个卷积核,核大小为3
conv3_3 = tf.layers.conv2d(inputs=conv3_2, filters=128, name="conv3_3", 
    kernel_size=3, activation=tf.nn.relu, padding=" same")
# 第三个pool 操作核大小为2,步长为2
pool3 = tf.layers.max_pooling2d(inputs=conv3_3, pool_size=[2, 2], 
    strides=2,name='pool3')
# 第四组第一个卷积使用256个卷积核,核大小为3
conv4_1 = tf.layers.conv2d(inputs-pool3, filters=256, name="conv4_1", 
    kernel_size=3, activation=tf.nn.relu, padding="same")
# 第四组第二个卷积使用128个卷积核,核大小为3
conv4_2 = tf.layers.conv2d(inputs=conv4_1, filters=128, name="conv4_2", 
    kernel_size=3, activation=tf.nn.relu, padding="same")
# 第四组第三个卷积使用128个卷积核,核大小为3
conv4_3 = tf.layers.conv2d(inputs=conv4_2, filters=128, name="cov4_3", 
    kernel_size=3, activation=tf.nn.relu, padding="same" )
# 第四个pool操作核大小为2,步长为2
pool4 = tf.layers.max.pooling2d(inputs=conv4_3, pool_size=[2,2], 
    strides=2, name='pool4')
# 第五组第一个卷积使用512个卷积核,核大小为3
conv5_1 = tf.layers.conv2d(inputs=pool4, filters=512, name="conv5_1", 
    kernel_size=3, activation=tf.nn.relu, padding=" same")
# 第五组第二个卷积使用512个卷积核,核大小为3
conv5_2 = t.layers.conv2d(inputs=conv5_1, filters=512, name="conv5_2", 
    kernel_size=3, activation=tf.nn.relu, padding="same")
# 第五组第三个卷积使用512个卷积核,核大小为3
conv5_3 = tf.layers.conv2d(inputs-conv5_2, filters=512, name="conv5_3", 
    kernel_size=3, activation=tf.nn.relu, padding="same"
    )
# 第五个pool操作核大小为2,步长为2
pool5 = tf.layers.max_pooling2d(inputs=conv5_3, pool_size=[2, 2], 
    strides=2, name='pool5')
flatten = tf.layers.flatten(inputs=poo15, name="flatten")

上面是VGG网络的单标签分类TensorFlow代码,但这里我们需要实现的是多标签分类,因此需要对VGG网络进行相应的改进,代码如下:

# 构建输出为4096的全连接层
fc6 = tf.layers.dense(inputs=flatten, units=4096,
activation=tf.nn.relu, name='fc6')
# 为了防止过拟合,引入dropout操作
drop1 = tf.layers.dropout(inputs=fc6,rate=0.5, training=training)
# 构建输出为4096的全连接层
fc7 = tf.layers.dense(inputs=drop1, units=4096,
activation=tf.nn.relu, name='fc7')
# 为了防止过报合,引入dropout操作
drop2 = tf.layers.dropout(inputs=fc7, rate=0.5, training=training)
# 为第一个标签构建分类器
fc8_1 = tf.layers.dense(inputs=drop2, units=10,
activation=tf.nn.sigmoid, name='fc8_1')
# 为第二个标签构建分类器
fc8_2 = tf.layers.dense(inputs=drop2, units=10,
activation=tf.nn.sigmoid, name='fc8_2')
# 为第三个标签构建分类器
fc8_3 = tf.layers.dense(inputs=drop2, units=10,
activation=tf.nn.sigmoid, name='fc8_3')
# 为第四个标签构建分类器
fc8_4 = tf.layers.dense(inputs=drop2,units=10,
activation=tf.nn.sigmoid, name='fc8_4')
# 将四个标签的结果进行拼接操作
fc8 = tf.concat([fc8_1,fc8_2,fc8_3,fc8_4], 0)

这里的fc6和fc7全连接层是对网络的卷积特征进行进一步的处理,在经过fc7层后,我们需要生成多标签的预测结果。由于一幅验证码图片中存在4个标签,因此需要构建4个子分类网络。这里假设图片验证码中只包含10 个数字,因此每个网络输出的预测类别就是10类,最后生成4个预测类别为10的子网络。如果每次训练时传入64幅验证码图片进行预测,那么通过4个子网络后,分别生成(64,10)、(64,10)、(64,10)、(64,10) 4个张量。如果使用Softmax分类器的话,就需要想办法将这4个张量进行组合,于是使用tf.concat函数进行张量拼接操作。

以下是TensorFlow中tf.concat函数的传参示例:

tf.concat (
values,
axis,
name='concat'
)

通过fc8=tf.concat([fc8_1,fc8_2,fc8_3,fc8_4], 0)的操作,可以将前面的4个(64.10)张量变换成(256.10)这样的单个张量,生成单个张量后就能进行后面的Softmax分类操作了。

10.4.3  多标签训练模型

模型训练的第一个步骤就是读取数据,读取方式有两种:一种是直接读取图片进行操作,另一种是转换为二进制文件格式后再进行操作。前者实现起来简单,但速度较慢;后者实现起来复杂,但读取速度快。这里我们以后者二进制的文件格式介绍如何实现多标签数据的读取操作,下面是相关代码。

首先读取TFRecord文件内容:

tfr = TFrecorder()
def input_fn_maker(path, data_info_path, shuffle=False, batch_size = 1,
epoch = 1, padding = None) :    
def input_fn():
    filenames = tfr.get_filenames(path=path, shuffle=shuffle)
    dataset=tfr.get_dataset(paths=filenames,
        data_info=data_info_path, shuffle = shuffle,
        batch_size = batch_size, epoch = epoch, padding = padding)
    iterator = dataset.make_one_shot_iterator ()
    return iterator.get_next()
return input_fn
# 原始图片信息
padding_info = ({'image':[30, 100,3,], 'label':[]})
# 测试集
test_input_fn = input_fn_maker('captcha_data/test/',
'captcha_tfrecord/data_info.csv',
batch_size = 512, padding = padding_info)
# 训练集
train_input_fn = input_fn_maker('captcha_data/train/',
'captcha_tfrecord/data_info.csv',
shuffle=True, batch_size = 128,padding = padding_info)
# 验证集
train_eval_fn = input_fn_maker('captcha_data/train/',
'captcha_tfrecord/data_info.csv',
batch_size = 512,adding = padding_info)

然后是模型训练部分:

def model_fn(features, net, mode):
features['image'] = tf.reshape(features['image'], [-1, 30, 100, 3])
# 获取基于net网络的模型预测结果
predictions = net(features['image'])
# 判断是预测模式还是训练模式
if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode=mode,
        predictions=predictions)
# 因为是多标签的Softmax,所以需要提前对标签的维度进行处理
lables = tf.reshape(features['label'], features['label'].shape[0]*4,))
# 初始化softmaxloss
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels,
    logits=logits)
# 训练模式下的模型结果获取
if mode ==tf.estimator.ModeKeys.TRAIN:
    # 声明模型使用的优化器类型
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
        train_op = optimizer.minimize(
            loss=loss,global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode=mode,
        loss=loss, train_op=train_op)
# 生成评价指标
eval_metric_ops = {"accuracy": tf.metrics.accuracy(
    labels=features['label'],predictions=predictions["classes"]) }
return tf.estimator.EstimatorSpec(mode=mode, loss=loss,
    eval_metric_ops= eval_metric_ops)

多标签的模型训练流程与普通单标签的模型训练流程非常相似,唯一的区别就是需要将多标签的标签值拼接成一个张量,以满足Softmax分类操作的维度要求。

本文节选自《Python深度学习原理、算法与案例》。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1134891.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

计算机网络——计算机网络体系结构(4/4)-计算机网络体系结构中的专用术语(实体、协议、服务,三次握手‘三报文握手’、数据包术语)

目录 分类一——实体 实体 对等实体 分类二——协议 协议 协议的三要素 分类三——服务 服务 服务访问点 数据包术语 计算机网络体系结构中的专用术语 本篇所讲的专用术语来源于OSI的七层协议体系结构,但也适用于TCP/IP的四层体系结构和五层协议原理体系…

Qt中的单例模式

QT单例类管理信号和槽函数 Chapter1 Qt中的单例模式一、什么是单例模式?二、Qt中单例模式的实现2.1、静态成员变量2.2、静态局部变量2.3、Q_GLOBAL_STATIC 宏实例2 三、使用场景四、注意事项 Chapter2 QT单例类管理信号和槽函数一、创建单例类二、主界面添加组件三、…

原始航片匀色调色方法

使用PhotoRC 2.0软件,对原始航片进行批量匀色,可以自动处理和人机交互,保留exif信息。 软件下载链接: https://pan.baidu.com/s/1Jj4cMpq8xzYvSa1hhozH-g?pwdndfm 提取码:ndfm

Spring Boot中使用JSR-303实现请求参数校验

JSR-303是Java中的一个规范,用于实现请求参数校验。它定义了一组注解,可以应用于JavaBean的字段上,用于验证输入参数的合法性。下面是一些常用的JSR-303注解及其介绍: NotNull:用于验证字段值不能为null。 NotEmpty&a…

RT-Thread 5. ENV添加自定义模块

代码 /* file: hello.c */ #include <stdio.h> #include <finsh.h> #include <rtthread.h> int hello_world(void) {rt_kprintf("Hello, world!\n");return 0; } MSH_CMD_EXPORT(hello_world, Hello world!)/* file: hello.h */ #ifndef _HELLO_H…

03 vi编辑器

vi编辑器的三种模式: 不同的模式下机键动作解释的意义是不一样的 编辑模式 插入模式 末行模式 文件的打开和关闭保存 移动光标

专业135总分400+西安交通大学信息与通信工程学院909/815考研经验分享

今年初试发挥不错&#xff0c;400&#xff0c;专业课135&#xff0c;将近一年复习一路走来&#xff0c;感慨很多&#xff0c;希望以下经历可以给后来的同学提供一些参考。 初试备考经验 公共课&#xff1a;三门公共课&#xff0c;政治&#xff0c;英语&#xff0c;数学。在备考…

AST反混淆实战|找出某里滑块226混淆代码隐藏的字符串

关注它&#xff0c;不迷路。 本文章中所有内容仅供学习交流&#xff0c;不可用于任何商业用途和非法用途&#xff0c;否则后果自负&#xff0c;如有侵权&#xff0c;请联系作者立即删除&#xff01; 1. 常见的字符串 在还原控制流之后&#xff0c;接下来的动作就是还原字…

高效翻译工具GPT插件的使用教程

大家好,我是herosunly。985院校硕士毕业,现担任算法研究员一职,热衷于机器学习算法研究与应用。曾获得阿里云天池比赛第一名,CCF比赛第二名,科大讯飞比赛第三名。拥有多项发明专利。对机器学习和深度学习拥有自己独到的见解。曾经辅导过若干个非计算机专业的学生进入到算法…

【语义分割】语义分割概念及算法介绍

文章目录 一、基本概念二、研究现状2.1 传统算法2.2 深度学习方法 三、数据集及评价指标3.1 常用数据集3.2 常用指标 四、经典模型参考资料 一、基本概念 语义分割是计算机视觉中很重要的一个方向。不同于目标检测和识别&#xff0c;语义分割实现了图像像素级的分类。它能够将…

【【萌新的FPGA学习之Vivado下的仿真入门-2】】

萌新的FPGA学习之Vivado下的仿真入门-2 我们上一章大概了解了 我们所需要进行各项操作的基本框架 对于内部实现其实一知半解 我们先从基本的出发 但从FPGA 了解一下 vivado下的仿真入门 正好帮我把自己的riscV 波形拉一下 行为级仿真 step1: 进入仿真界面&#xff1a;SIMULAT…

凉鞋的 Unity 笔记 204. 语句

204. 语句 在上一篇&#xff0c;我们接触了三种常见的类型&#xff0c;如下所示&#xff1a; 这样我们算是对变量进行了一个入门年了。 其实我们除了变量&#xff0c;我们还接触了一个叫做语句的概念。 我们可以看下代码&#xff1a; using System.Collections; using Syst…

四川云汇优想教育咨询有限公司电商服务正规吗

随着抖音等短视频平台的火热&#xff0c;越来越多的消费者选择在平台上购物。四川云汇优想教育咨询有限公司也推出了抖音电商服务&#xff0c;但它的服务是否正规呢&#xff1f;本文将为您揭开真相。 首先&#xff0c;我们先来了解一下四川云汇优想教育咨询有限公司。这是一家致…

基于Java的足球赛会管理系统设计与实现(源码+lw+部署文档+讲解等)

文章目录 前言具体实现截图论文参考详细视频演示为什么选择我自己的网站自己的小程序&#xff08;小蔡coding&#xff09; 代码参考数据库参考源码获取 前言 &#x1f497;博主介绍&#xff1a;✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师、全栈领域优质创作者&am…

C#,数值计算——分类与推理Phylo_wpgma的计算方法与源程序

1 文本格式 using System; using System.Collections.Generic; namespace Legalsoft.Truffer { public class Phylo_wpgma : Phylagglom { public override void premin(double[,] d, int[] nextp) { } public override double dminfn(double…

java基础篇-环境变量

java基础 编程学习的关键点、重点1.环境变量设置待续 编程学习的关键点、重点 输入输出 Java语言、C语言、Python语言、甚至SQL语言&#xff0c;都需要实战、做大量输入输出等 1.环境变量设置 1.下载jdk安装 jdk官网下载直达链接&#xff1a;https://www.oracle.com/java/te…

2.9.C++项目:网络版五子棋对战之业务处理模块的设计

文章目录 一、意义二、功能三、管理&#xff08;一&#xff09;客户端请求&#xff08;二&#xff09;websocket 四、框架五、完整代码 一、意义 将所有的模块整合在一起&#xff0c;通过网络通信获取到客户端的请求&#xff0c;提供不同的业务处理。 服务器模块&#xff0c;是…

类加载机制和双亲委派机制

文章目录 &#x1f4d5;我是廖志伟&#xff0c;一名Java开发工程师、Java领域优质创作者、CSDN博客专家、51CTO专家博主、阿里云专家博主、清华大学出版社签约作者、产品软文创造者、技术文章评审老师、问卷调查设计师、个人社区创始人、开源项目贡献者。&#x1f30e;跑过十五…

winodos下使用VS2022编译eclipse-paho.mqtt.c并演示简单使用的 demo

本文演示C语言如何使用eclipse-paho.mqtt.c库&#xff0c;包含自行编译库的步骤或者下载编译好的文件。 1.下载paho.mqtt.c库源码&#xff08;zip 文件&#xff09; 到官网选择C版本的paho源码进行下载 Eclipse Paho | The Eclipse Foundation 或者到下述连接下载 Releases ec…

docker在java项目中打成tar包

docker在java项目中打成tar包 1、首先安装一个docker desktop 2、mvn install项目后&#xff0c;建立一个自己的dockerfile 这里我以我的代码举例&#xff0c;from 镜像&#xff0c;这里你也能打包好一个镜像的基础上&#xff0c;from打好的镜像&#xff0c;这里我们用openj…