TensorFlow-slim包进行图像数据集分类---具体流程

news2024/10/5 21:24:29

TensorFlow中slim包的具体用法

  • 1、训练脚本文件(该文件包含数据下载打包、模型训练,模型评估流程)
  • 3、模型训练
    • 1、数据集相关模块:
    • 2、设置网络模型模块
    • 3、数据预处理模块
    • 4、定义损失loss
    • 5、定义优化器模块

本次使用的TensorFlow版本是1.13.0
地址:https://github.com/tensorflow/models/tree/r1.13.0
到tensorflow-models的GitHub下载research下面的slim这个包到本地
在这里插入图片描述

TensorFlow中slim包的目录结构:

-- slim
    |-- BUILD
    |-- README.md
    |-- WORKSPACE
    |-- __init__.py
    |-- datasets
    |   |-- __init__.py
    |   |-- __pycache__
    |   |   |-- __init__.cpython-37.pyc
    |   |   |-- dataset_utils.cpython-37.pyc
    |   |   |-- download_and_convert_cifar10.cpython-37.pyc
    |   |   |-- download_and_convert_flowers.cpython-37.pyc
    |   |   `-- download_and_convert_mnist.cpython-37.pyc
    |   |-- build_imagenet_data.py
    |   |-- cifar10.py
    |   |-- dataset_factory.py
    |   |-- dataset_utils.py
    |   |-- download_and_convert_cifar10.py
    |   |-- download_and_convert_flowers.py
    |   |-- download_and_convert_imagenet.sh
    |   |-- download_and_convert_mnist.py
    |   |-- download_imagenet.sh
    |   |-- flowers.py
    |   |-- imagenet.py
    |   |-- imagenet_2012_validation_synset_labels.txt
    |   |-- imagenet_lsvrc_2015_synsets.txt
    |   |-- imagenet_metadata.txt
    |   |-- mnist.py
    |   |-- preprocess_imagenet_validation_data.py
    |   `-- process_bounding_boxes.py
    |-- deployment
    |   |-- __init__.py
    |   |-- model_deploy.py
    |   `-- model_deploy_test.py
    |-- download_and_convert_data.py    # 下载相应的数据集,并将数据打包成TF-record的格式
    |-- eval_image_classifier.py        # 测试模型分类效果
    |-- export_inference_graph.py
    |-- export_inference_graph_test.py
    |-- nets
    |   |-- __init__.py
    |   |-- alexnet.py
    |   |-- alexnet_test.py
    |   |-- cifarnet.py
    |   |-- cyclegan.py
    |   |-- cyclegan_test.py
    |   |-- dcgan.py
    |   |-- dcgan_test.py
    |   |-- i3d.py
    |   |-- i3d_test.py
    |   |-- i3d_utils.py
    |   |-- inception.py
    |   |-- inception_resnet_v2.py
    |   |-- inception_resnet_v2_test.py
    |   |-- inception_utils.py
    |   |-- inception_v1.py
    |   |-- inception_v1_test.py
    |   |-- inception_v2.py
    |   |-- inception_v2_test.py
    |   |-- inception_v3.py
    |   |-- inception_v3_test.py
    |   |-- inception_v4.py
    |   |-- inception_v4_test.py
    |   |-- lenet.py
    |   |-- mobilenet
    |   |   |-- README.md
    |   |   |-- __init__.py
    |   |   |-- conv_blocks.py
    |   |   |-- madds_top1_accuracy.png
    |   |   |-- mnet_v1_vs_v2_pixel1_latency.png
    |   |   |-- mobilenet.py
    |   |   |-- mobilenet_example.ipynb
    |   |   |-- mobilenet_v2.py
    |   |   `-- mobilenet_v2_test.py
    |   |-- mobilenet_v1.md
    |   |-- mobilenet_v1.png
    |   |-- mobilenet_v1.py
    |   |-- mobilenet_v1_eval.py
    |   |-- mobilenet_v1_test.py
    |   |-- mobilenet_v1_train.py
    |   |-- nasnet
    |   |   |-- README.md
    |   |   |-- __init__.py
    |   |   |-- nasnet.py
    |   |   |-- nasnet_test.py
    |   |   |-- nasnet_utils.py
    |   |   |-- nasnet_utils_test.py
    |   |   |-- pnasnet.py
    |   |   `-- pnasnet_test.py
    |   |-- nets_factory.py
    |   |-- nets_factory_test.py
    |   |-- overfeat.py
    |   |-- overfeat_test.py
    |   |-- pix2pix.py
    |   |-- pix2pix_test.py
    |   |-- resnet_utils.py
    |   |-- resnet_v1.py
    |   |-- resnet_v1_test.py
    |   |-- resnet_v2.py
    |   |-- resnet_v2_test.py
    |   |-- s3dg.py
    |   |-- s3dg_test.py
    |   |-- vgg.py
    |   `-- vgg_test.py
    |-- preprocessing
    |   |-- __init__.py
    |   |-- cifarnet_preprocessing.py
    |   |-- inception_preprocessing.py
    |   |-- lenet_preprocessing.py
    |   |-- preprocessing_factory.py
    |   `-- vgg_preprocessing.py
    |-- scripts                     # gqr:存储的是相关的模型训练脚本                
    |   |-- export_mobilenet.sh
    |   |-- finetune_inception_resnet_v2_on_flowers.sh
    |   |-- finetune_inception_v1_on_flowers.sh
    |   |-- finetune_inception_v3_on_flowers.sh
    |   |-- finetune_resnet_v1_50_on_flowers.sh
    |   |-- train_cifarnet_on_cifar10.sh
    |   `-- train_lenet_on_mnist.sh
    |-- setup.py
    |-- slim_walkthrough.ipynb
    `-- train_image_classifier.py    # 训练模型的脚本

1、训练脚本文件(该文件包含数据下载打包、模型训练,模型评估流程)

scripts/finetune_resnet_v1_50_on_flowers.sh

#!/bin/bash
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# This script performs the following operations:
# 1. Downloads the Flowers dataset
# 2. Fine-tunes a ResNetV1-50 model on the Flowers training set.
# 3. Evaluates the model on the Flowers validation set.
#
# Usage:
# cd slim
# ./slim/scripts/finetune_resnet_v1_50_on_flowers.sh
set -e

# Where the pre-trained ResNetV1-50 checkpoint is saved to.
PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints   # gqr:预训练模型存放路径

# Where the training (fine-tuned) checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/flowers-models/resnet_v1_50

# Where the dataset is saved to.
DATASET_DIR=/tmp/flowers    # gqr:数据集存放路径

# Download the pre-trained checkpoint.
if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then
  mkdir ${PRETRAINED_CHECKPOINT_DIR}
fi
if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt ]; then
  wget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz
  tar -xvf resnet_v1_50_2016_08_28.tar.gz
  mv resnet_v1_50.ckpt ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt
  rm resnet_v1_50_2016_08_28.tar.gz
fi

# Download the dataset
python download_and_convert_data.py \
  --dataset_name=flowers \
  --dataset_dir=${DATASET_DIR}

# Fine-tune only the new layers for 3000 steps.
python train_image_classifier.py \
  --train_dir=${TRAIN_DIR} \
  --dataset_name=flowers \
  --dataset_split_name=train \
  --dataset_dir=${DATASET_DIR} \
  --model_name=resnet_v1_50 \
  --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt \
  --checkpoint_exclude_scopes=resnet_v1_50/logits \
  --trainable_scopes=resnet_v1_50/logits \
  --max_number_of_steps=3000 \
  --batch_size=32 \
  --learning_rate=0.01 \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=100 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

# Run evaluation.
python eval_image_classifier.py \
  --checkpoint_path=${TRAIN_DIR} \
  --eval_dir=${TRAIN_DIR} \
  --dataset_name=flowers \
  --dataset_split_name=validation \
  --dataset_dir=${DATASET_DIR} \
  --model_name=resnet_v1_50

# Fine-tune all the new layers for 1000 steps.
python train_image_classifier.py \
  --train_dir=${TRAIN_DIR}/all \
  --dataset_name=flowers \
  --dataset_split_name=train \
  --dataset_dir=${DATASET_DIR} \
  --checkpoint_path=${TRAIN_DIR} \
  --model_name=resnet_v1_50 \
  --max_number_of_steps=1000 \
  --batch_size=32 \
  --learning_rate=0.001 \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=100 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

# Run evaluation.
python eval_image_classifier.py \
  --checkpoint_path=${TRAIN_DIR}/all \
  --eval_dir=${TRAIN_DIR}/all \
  --dataset_name=flowers \
  --dataset_split_name=validation \
  --dataset_dir=${DATASET_DIR} \
  --model_name=resnet_v1_50

以上文件以下载并打包flowers数据集为例会调用slim/datasets下的****download_and_convert_flowers.py
在这里插入图片描述
代码43行:_NUM_VALIDATION = 350值的意思的测试数据集的数量,我们一般2,8分数据集,这里只用填写测试集的数据代码会自动吧总数据集分成2部分
代码48行:_NUM_SHARDS = 1这个的意思是生成几个tfrecord文件,这个数量是根据你数据量来划分
在这里插入图片描述
代码190行:dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir) 函数为下载数据集函数,如果本地已经存在数据集,可将将其注释掉
在这里插入图片描述
代码210行:_clean_up_temporary_files(dataset_dir) 函数为打包完毕后删除下载的数据集文件,如果需要下载的数据集可以将其注释掉

上述文件执行完毕后,会得到以下文件
在这里插入图片描述

3、模型训练

模型训练文件为
在这里插入图片描述
以下是该文件中各个模块相关内容

1、数据集相关模块:

在这里插入图片描述

2、设置网络模型模块

在这里插入图片描述

3、数据预处理模块

在这里插入图片描述

4、定义损失loss

在这里插入图片描述

5、定义优化器模块

在这里插入图片描述

运行训练指令:

python train_image_classifier.py \
  --train_dir=./data/flowers-models/resnet_v1_50\
  --dataset_name=flowers \
  --dataset_split_name=train \
  --dataset_dir=./data/flowers \
  --model_name=resnet_v1_50 \
  --checkpoint_path=./data/checkpoints/resnet_v1_50.ckpt \
  --checkpoint_exclude_scopes=resnet_v1_50/logits \
  --trainable_scopes=resnet_v1_50/logits \
  --max_number_of_steps=3000 \ 
  --batch_size=32 \
  --learning_rate=0.01 \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=100 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

–dataset_name=指定模板
–model_name=指定预训练模板
–dataset_dir=指定训练集目录
–checkpoint_exclude_scopes=指定忘记那几层的参数,不带进训练里面,记住提取特征的部分
–train_dir=训练参数存放地址
–trainable_scopes=设定只对那几层变量进行调整,其他层都不进行调整,不设定就会对所有层训练(所以是必须要给定的)
–learning_rate=学习率
–optimizer=优化器
–checkpoint_path:预训练模型存放地址
–max_number_of_steps=训练步数
–batch_size=一次训练所选取的样本数。 (Batch Size的大小影响模型的优化程度和速度。同时其直接影响到GPU内存的使用情况,假如你GPU内存不大,该数值最好设置小一点。)
–weight_decay=即模型中所有参数的二次正则化超参数(这个的加入就是为了防止过拟合加入正则项,weight_decay 是乘在正则项的前面,控制正则化项在损失函数中所占权重的)

注意:在模型训练前,需要下载预训练模型,
wget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz

解压后存放在相应目录

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/946724.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【FreeRTOS】【应用篇】消息队列【下篇】

前言 本篇文章主要对 FreeRTOS 中消息队列的概念和相关函数进行了详解消息队列【下篇】详细剖析了消息队列中发送、接收时队列消息控制块中各种指针的行为,以及几个发送消息和接收消息的函数的运作流程笔者有关于 【FreeRTOS】【应用篇】消息队列【上篇】——队列基…

文件属性与目录

目录 Linux 系统中的文件类型普通文件目录文件字符设备文件和块设备文件符号链接文件管道文件套接字文件总结 stat 函数struct stat 结构体st_mode 变量struct timespec 结构体练习 fstat 和lstat 函数fstat 函数lstat 函数 文件属主有效用户ID 和有效组IDchown 函数fchown 和l…

WebGL矩阵变换库

目录 矩阵变换库: Matrix4对象所支持的方法和属性如表所示: 方法属性规范: 虽然平移、旋转、缩放等变换操作都可以用一个44的矩阵表示,但是在写WebGL程序的时候,手动计算每个矩阵很耗费时间。为了简化编程&#xf…

74 # koa 的基本使用

koa 是对 http 的一个封装,实现了一个 node 框架,可以根据这个框架实现自己的 MVC 框架。 每个人用 koa 的方式都大不一样,无法做到约定性,所以才会有 egg 基于 koa 封装的约定性的框架。 安装 npm init -y npm install koa使用…

502 bad gateway什么意思502 bad gateway问题解决办法

502 bad gateway是一种常见互联网连接错误,大部分情况就是打不开页面,连接不上网络,访问服务器挂了等问题,下面来看看具体解决方法,希望能够帮助你解决问题。 502 bad gateway什么意思 简单说就是服务器没有收到回应&…

LINQ-查询表达式

文章速览 概述使用注意查询子句实例 概述 LINQ是一组技术的名称,这些技术建立在将查询功能直接集成到C#语言(以及Visual Basic和可能的任何其他.NET语言)的基础上。借助于LINQ,查询已是高级语言构造,就如同类、方法和…

Ubuntu 18.04上无法播放MP4格式视频解决办法

ubuntu18.04系统无法播放MP4格式视频,提示如下图所示: 解决办法: 1、首先,确保ubuntu系统已完全更新。可使用以下命令更新软件包列表:sudo apt update,然后使用以下命令升级所有已安装的软件包&#xff1a…

poi-tl设置图片(通过word模板替换关键字,然后转pdf文件并下载)

选中图片右击 选择设置图片格式 例如word模板 maven依赖 <!-- java 读取word文件里面的加颜色的字体 转pdf 使用 --><dependency><groupId> e-iceblue </groupId><artifactId>spire.doc.free</artifactId><version>3.9.0</ver…

数据的语言:学习数据可视化的实际应用

数据可视化应该学什么&#xff1f;这是一个在信息时代越来越重要的问题。随着数据不断增长和积累&#xff0c;从社交媒体到企业业务&#xff0c;从科学研究到医疗健康&#xff0c;我们都面临着海量的数据。然而&#xff0c;数据本身往往是冰冷、抽象的数字&#xff0c;对于大多…

03-基础例程3

基础例程3 01、外部中断 ESP32的外部中断有上升沿、下降沿、低电平、高电平触发模式。 实验目的 使用外部中断功能实现按键控制LED的亮灭 按键按下为0。【即下降沿】 * 接线说明&#xff1a;按键模块-->ESP32 IO* (K1-K4)-->(14,27,26,25)* * …

安全生产作业现场违规行为识别 opencv

安全生产作业现场违规行为识别算法通过pythonopencv网络模型算法框架设定了各种合规行为和违规行为的模型&#xff0c;安全生产作业现场违规行为识别算法检测到违规行为&#xff0c;将立即进行抓拍并发送告警信息给相关人员&#xff0c;以便及时采取相应的处置措施。OpenCV是一…

HT for Web (Hightopo) 使用心得(6)- 3D场景环境配置(天空球,雾化,辉光,景深)

在前一篇文章《Hightopo 使用心得&#xff08;5&#xff09;- 动画的实现》中&#xff0c;我们将一个直升机模型放到了3D场景中。同时&#xff0c;还利用动画实现了让该直升机围绕山体巡逻。在这篇文章中&#xff0c;我们将对上一篇的场景进行一些环境上的丰富与美化。让场景更…

02-基础例程2

基础例程2 01、直流电机 对于UL2003来说&#xff0c;可以看作是非门 输入为1&#xff0c;输出为0&#xff1b; 输入为0&#xff0c;输出为高组态[接一个上拉电阻即为1] 线路的连接 4路是用做步进电机的 对于直流电机&#xff0c;只需要一路。即异端接VCC&#xff0c;一段接…

unordered-------Hash

✅<1>主页&#xff1a;我的代码爱吃辣&#x1f4c3;<2>知识讲解&#xff1a;数据结构——哈希表☂️<3>开发环境&#xff1a;Visual Studio 2022&#x1f4ac;<4>前言&#xff1a;哈希是一种映射的思想&#xff0c;哈希表即使利用这种思想&#xff0c;…

基于野马算法优化的BP神经网络(预测应用) - 附代码

基于野马算法优化的BP神经网络&#xff08;预测应用&#xff09; - 附代码 文章目录 基于野马算法优化的BP神经网络&#xff08;预测应用&#xff09; - 附代码1.数据介绍2.野马优化BP神经网络2.1 BP神经网络参数设置2.2 野马算法应用 4.测试结果&#xff1a;5.Matlab代码 摘要…

【深入浅出设计模式--状态模式】

深入浅出设计模式--状态模式 一、背景二、问题三、解决方案四、 适用场景总结五、后记 一、背景 状态模式是一种行为设计模式&#xff0c;让你能在一个对象的内部状态变化时改变其行为&#xff0c;使其看上去就像改变了自身所属的类一样。其与有限状态机的概念紧密相关&#x…

RT-Thread在STM32硬件I2C的踩坑记录

RT-Thread在STM32硬件I2C的踩坑记录 0.前言一、软硬件I2C区别二、RT Thread中的I2C驱动三、尝试适配硬件I2C四、i2c-bit-ops操作函数替换五、Attention Please!六、总结 参考文章&#xff1a; 1.将硬件I2C巧妙地将“嫁接”到RTT原生的模拟I2C驱动框架 2.基于STM32F4平台的硬件I…

flutter 上传图片并裁剪

1.首先在pubspec.yaml文件中新增依赖pub.dev image_picker: ^0.8.75 image_cropper: ^4.0.1 2.在Android的AndroidManifest.xml文件里面添加权限 <activityandroid:name"com.yalantis.ucrop.UCropActivity"android:screenOrientation"portrait"andro…

Golang网络编程

互联网协议介绍引入 1. 物理层&#xff08;Physical Layer&#xff09;&#xff1a; - 功能&#xff1a;物理层负责定义物理介质传输数据的方式和规范&#xff0c;它传输的是原始数据比特流。 - 协议&#xff1a;Ethernet、Wi-Fi、USB、光纤等。 - 例子&#xff1a;将…

[javaWeb]Socket网络编程

网络编程&#xff1a;写一个应用程序,让这个程序可以使用网络通信。这里就需要调用传输层提供的 api。 Socket套接字 传输层提供协议&#xff0c;主要是两个: UDP和TCP 提供了两套不同的 api&#xff0c;这api也叫做socket api。 UDP和 TCP 特点对比&#xff1a; UDP: 无连…