TensorFlow中slim包的具体用法

news2024/11/18 15:35:25

TensorFlow中slim包的具体用法

  • 1、训练脚本文件(该文件包含数据下载打包、模型训练,模型评估流程)
  • 3、模型训练
    • 1、数据集相关模块:
    • 2、设置网络模型模块
    • 3、数据预处理模块
    • 4、定义损失loss
    • 5、定义优化器模块

本次使用的TensorFlow版本是1.13.0
地址:https://github.com/tensorflow/models/tree/r1.13.0
到tensorflow-models的GitHub下载research下面的slim这个包到本地
在这里插入图片描述

TensorFlow中slim包的目录结构:

-- slim
    |-- BUILD
    |-- README.md
    |-- WORKSPACE
    |-- __init__.py
    |-- datasets
    |   |-- __init__.py
    |   |-- __pycache__
    |   |   |-- __init__.cpython-37.pyc
    |   |   |-- dataset_utils.cpython-37.pyc
    |   |   |-- download_and_convert_cifar10.cpython-37.pyc
    |   |   |-- download_and_convert_flowers.cpython-37.pyc
    |   |   `-- download_and_convert_mnist.cpython-37.pyc
    |   |-- build_imagenet_data.py
    |   |-- cifar10.py
    |   |-- dataset_factory.py
    |   |-- dataset_utils.py
    |   |-- download_and_convert_cifar10.py
    |   |-- download_and_convert_flowers.py
    |   |-- download_and_convert_imagenet.sh
    |   |-- download_and_convert_mnist.py
    |   |-- download_imagenet.sh
    |   |-- flowers.py
    |   |-- imagenet.py
    |   |-- imagenet_2012_validation_synset_labels.txt
    |   |-- imagenet_lsvrc_2015_synsets.txt
    |   |-- imagenet_metadata.txt
    |   |-- mnist.py
    |   |-- preprocess_imagenet_validation_data.py
    |   `-- process_bounding_boxes.py
    |-- deployment
    |   |-- __init__.py
    |   |-- model_deploy.py
    |   `-- model_deploy_test.py
    |-- download_and_convert_data.py    # 下载相应的数据集,并将数据打包成TF-record的格式
    |-- eval_image_classifier.py        # 测试模型分类效果
    |-- export_inference_graph.py
    |-- export_inference_graph_test.py
    |-- nets
    |   |-- __init__.py
    |   |-- alexnet.py
    |   |-- alexnet_test.py
    |   |-- cifarnet.py
    |   |-- cyclegan.py
    |   |-- cyclegan_test.py
    |   |-- dcgan.py
    |   |-- dcgan_test.py
    |   |-- i3d.py
    |   |-- i3d_test.py
    |   |-- i3d_utils.py
    |   |-- inception.py
    |   |-- inception_resnet_v2.py
    |   |-- inception_resnet_v2_test.py
    |   |-- inception_utils.py
    |   |-- inception_v1.py
    |   |-- inception_v1_test.py
    |   |-- inception_v2.py
    |   |-- inception_v2_test.py
    |   |-- inception_v3.py
    |   |-- inception_v3_test.py
    |   |-- inception_v4.py
    |   |-- inception_v4_test.py
    |   |-- lenet.py
    |   |-- mobilenet
    |   |   |-- README.md
    |   |   |-- __init__.py
    |   |   |-- conv_blocks.py
    |   |   |-- madds_top1_accuracy.png
    |   |   |-- mnet_v1_vs_v2_pixel1_latency.png
    |   |   |-- mobilenet.py
    |   |   |-- mobilenet_example.ipynb
    |   |   |-- mobilenet_v2.py
    |   |   `-- mobilenet_v2_test.py
    |   |-- mobilenet_v1.md
    |   |-- mobilenet_v1.png
    |   |-- mobilenet_v1.py
    |   |-- mobilenet_v1_eval.py
    |   |-- mobilenet_v1_test.py
    |   |-- mobilenet_v1_train.py
    |   |-- nasnet
    |   |   |-- README.md
    |   |   |-- __init__.py
    |   |   |-- nasnet.py
    |   |   |-- nasnet_test.py
    |   |   |-- nasnet_utils.py
    |   |   |-- nasnet_utils_test.py
    |   |   |-- pnasnet.py
    |   |   `-- pnasnet_test.py
    |   |-- nets_factory.py
    |   |-- nets_factory_test.py
    |   |-- overfeat.py
    |   |-- overfeat_test.py
    |   |-- pix2pix.py
    |   |-- pix2pix_test.py
    |   |-- resnet_utils.py
    |   |-- resnet_v1.py
    |   |-- resnet_v1_test.py
    |   |-- resnet_v2.py
    |   |-- resnet_v2_test.py
    |   |-- s3dg.py
    |   |-- s3dg_test.py
    |   |-- vgg.py
    |   `-- vgg_test.py
    |-- preprocessing
    |   |-- __init__.py
    |   |-- cifarnet_preprocessing.py
    |   |-- inception_preprocessing.py
    |   |-- lenet_preprocessing.py
    |   |-- preprocessing_factory.py
    |   `-- vgg_preprocessing.py
    |-- scripts                     # gqr:存储的是相关的模型训练脚本                
    |   |-- export_mobilenet.sh
    |   |-- finetune_inception_resnet_v2_on_flowers.sh
    |   |-- finetune_inception_v1_on_flowers.sh
    |   |-- finetune_inception_v3_on_flowers.sh
    |   |-- finetune_resnet_v1_50_on_flowers.sh
    |   |-- train_cifarnet_on_cifar10.sh
    |   `-- train_lenet_on_mnist.sh
    |-- setup.py
    |-- slim_walkthrough.ipynb
    `-- train_image_classifier.py    # 训练模型的脚本

1、训练脚本文件(该文件包含数据下载打包、模型训练,模型评估流程)

scripts/finetune_resnet_v1_50_on_flowers.sh

#!/bin/bash
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# This script performs the following operations:
# 1. Downloads the Flowers dataset
# 2. Fine-tunes a ResNetV1-50 model on the Flowers training set.
# 3. Evaluates the model on the Flowers validation set.
#
# Usage:
# cd slim
# ./slim/scripts/finetune_resnet_v1_50_on_flowers.sh
set -e

# Where the pre-trained ResNetV1-50 checkpoint is saved to.
PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints   # gqr:预训练模型存放路径

# Where the training (fine-tuned) checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/flowers-models/resnet_v1_50

# Where the dataset is saved to.
DATASET_DIR=/tmp/flowers    # gqr:数据集存放路径

# Download the pre-trained checkpoint.
if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then
  mkdir ${PRETRAINED_CHECKPOINT_DIR}
fi
if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt ]; then
  wget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz
  tar -xvf resnet_v1_50_2016_08_28.tar.gz
  mv resnet_v1_50.ckpt ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt
  rm resnet_v1_50_2016_08_28.tar.gz
fi

# Download the dataset
python download_and_convert_data.py \
  --dataset_name=flowers \
  --dataset_dir=${DATASET_DIR}

# Fine-tune only the new layers for 3000 steps.
python train_image_classifier.py \
  --train_dir=${TRAIN_DIR} \
  --dataset_name=flowers \
  --dataset_split_name=train \
  --dataset_dir=${DATASET_DIR} \
  --model_name=resnet_v1_50 \
  --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt \
  --checkpoint_exclude_scopes=resnet_v1_50/logits \
  --trainable_scopes=resnet_v1_50/logits \
  --max_number_of_steps=3000 \
  --batch_size=32 \
  --learning_rate=0.01 \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=100 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

# Run evaluation.
python eval_image_classifier.py \
  --checkpoint_path=${TRAIN_DIR} \
  --eval_dir=${TRAIN_DIR} \
  --dataset_name=flowers \
  --dataset_split_name=validation \
  --dataset_dir=${DATASET_DIR} \
  --model_name=resnet_v1_50

# Fine-tune all the new layers for 1000 steps.
python train_image_classifier.py \
  --train_dir=${TRAIN_DIR}/all \
  --dataset_name=flowers \
  --dataset_split_name=train \
  --dataset_dir=${DATASET_DIR} \
  --checkpoint_path=${TRAIN_DIR} \
  --model_name=resnet_v1_50 \
  --max_number_of_steps=1000 \
  --batch_size=32 \
  --learning_rate=0.001 \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=100 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

# Run evaluation.
python eval_image_classifier.py \
  --checkpoint_path=${TRAIN_DIR}/all \
  --eval_dir=${TRAIN_DIR}/all \
  --dataset_name=flowers \
  --dataset_split_name=validation \
  --dataset_dir=${DATASET_DIR} \
  --model_name=resnet_v1_50

以上文件以下载并打包flowers数据集为例会调用slim/datasets下的****download_and_convert_flowers.py
在这里插入图片描述
代码43行:_NUM_VALIDATION = 350值的意思的测试数据集的数量,我们一般2,8分数据集,这里只用填写测试集的数据代码会自动吧总数据集分成2部分
代码48行:_NUM_SHARDS = 1这个的意思是生成几个tfrecord文件,这个数量是根据你数据量来划分
在这里插入图片描述
代码190行:dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir) 函数为下载数据集函数,如果本地已经存在数据集,可将将其注释掉
在这里插入图片描述
代码210行:_clean_up_temporary_files(dataset_dir) 函数为打包完毕后删除下载的数据集文件,如果需要下载的数据集可以将其注释掉

上述文件执行完毕后,会得到以下文件
在这里插入图片描述

3、模型训练

模型训练文件为
在这里插入图片描述
以下是该文件中各个模块相关内容

1、数据集相关模块:

在这里插入图片描述

2、设置网络模型模块

在这里插入图片描述

3、数据预处理模块

在这里插入图片描述

4、定义损失loss

在这里插入图片描述

5、定义优化器模块

在这里插入图片描述

运行训练指令:

python train_image_classifier.py \
  --train_dir=./data/flowers-models/resnet_v1_50\
  --dataset_name=flowers \
  --dataset_split_name=train \
  --dataset_dir=./data/flowers \
  --model_name=resnet_v1_50 \
  --checkpoint_path=./data/checkpoints/resnet_v1_50.ckpt \
  --checkpoint_exclude_scopes=resnet_v1_50/logits \
  --trainable_scopes=resnet_v1_50/logits \
  --max_number_of_steps=3000 \ 
  --batch_size=32 \
  --learning_rate=0.01 \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=100 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

–dataset_name=指定模板
–model_name=指定预训练模板
–dataset_dir=指定训练集目录
–checkpoint_exclude_scopes=指定忘记那几层的参数,不带进训练里面,记住提取特征的部分
–train_dir=训练参数存放地址
–trainable_scopes=设定只对那几层变量进行调整,其他层都不进行调整,不设定就会对所有层训练(所以是必须要给定的)
–learning_rate=学习率
–optimizer=优化器
–max_number_of_steps=训练步数
–batch_size=一次训练所选取的样本数。 (Batch Size的大小影响模型的优化程度和速度。同时其直接影响到GPU内存的使用情况,假如你GPU内存不大,该数值最好设置小一点。)
–weight_decay=即模型中所有参数的二次正则化超参数(这个的加入就是为了防止过拟合加入正则项,weight_decay 是乘在正则项的前面,控制正则化项在损失函数中所占权重的)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/932729.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

电商项目part07 订单系统的设计与海量数据处理

订单重复下单问题(幂等) 用户在点击“提交订单”的按钮时,不小心点了两下,那么浏览器就会向服务端连续发送两条创建订单的请求。这样肯定是不行的 解决办法是,让订单服务具备幂等性。什么是幂等性?幂等操作的特点是&a…

Vue2向Vue3过度Vue3组合式API

目录 1. Vue2 选项式 API vs Vue3 组合式API2. Vue3的优势3 使用create-vue搭建Vue3项目1. 认识create-vue2. 使用create-vue创建项目 4 熟悉项目和关键文件5 组合式API - setup选项1. setup选项的写法和执行时机2. setup中写代码的特点3. <script setup>语法糖 6 组合式…

Cpp学习——编译链接

目录 ​编辑 一&#xff0c;两种环境 二&#xff0c;编译环境下四个部分的 1.预处理 2.编译 3.汇编 4.链接 三&#xff0c;执行环境 一&#xff0c;两种环境 在程序运行时会有两种环境。第一种便是编译环境&#xff0c;第二种则是执行环境。如下图&#xff1a; 在程序运…

wxpython:wx.html2 是好用的 WebView 组件

wxpython : wx.html2 是好用的 WebView 组件。 pip install wxpython4.2 wxPython-4.2.0-cp37-cp37m-win_amd64.whl (18.0 MB) Successfully installed wxpython-4.2.0 cd \Python37\Scripts wxdemo.exe 取得 wxPython-demo-4.2.0.tar.gz wxdocs.exe 取得 wxPython-docs-4.…

Android平台RTMP|RTSP直播播放器功能进阶探讨

我们需要怎样的直播播放器&#xff1f; 很多开发者在跟我聊天的时候&#xff0c;经常问我&#xff0c;为什么一个RTMP或RTSP播放器&#xff0c;你们需要设计那么多的接口&#xff0c;真的有必要吗&#xff1f;带着这样的疑惑&#xff0c;我们今天聊聊Android平台RTMP、RTSP播放…

STM32f103入门(1) 配置点亮Led灯

1 安装keil5 MDK 双击 MDK524a.EXE安装成功后管理员模式打开CID复制到破解软件 选择ARM生成代码复制到New License ID CodeAdd LIC破解完毕 2安装stm32芯片 可找资料自行安装 如下 3 创建工程 Project->new project 本篇芯片为stm32f103保存到自定义文件夹下在根目录下…

JavaScript立即执行函数(自执行函数)的3种写法

一、什么是立即执行函数 顾名思义&#xff0c;声明一个函数并马上调用这个函数就叫做立即执行函数&#xff1b;也可以说立即执行函数是一种语法&#xff0c;让你的函数在定义以后立即执行&#xff1b;立即执行函数又叫做自执行函数。 二、立即执行函数的写法 立即执行函数的…

快速排序笔记

一、quick_sort方法中如果 il,jr 会死循环的分析 1、示例代码 void quick_sort(int a[],int l,int r){if(l>r) return;int il,jr; //此处设置会导致死循环int x num[(lr)>>1];while(i<j){while(a[i] <x); //死循环的地方while(a[--j] >x);if(i<j) swap(a…

RabbitMQ---订阅模型-Direct

1、 订阅模型-Direct • 有选择性的接收消息 • 在订阅模式中&#xff0c;生产者发布消息&#xff0c;所有消费者都可以获取所有消息。 • 在路由模式中&#xff0c;我们将添加一个功能 - 我们将只能订阅一部分消息。 例如&#xff0c;我们只能将重要的错误消息引导到日志文件…

网站常见安全漏洞 | 青训营

Powered by:NEFU AB-IN 文章目录 网站常见安全漏洞 | 青训营 网站基本组成及漏洞定义服务端漏洞**SQL注入****命令执行****越权漏洞****SSRF****文件上传漏洞** 客户端漏洞**开放重定向****XSS****CSRF****点击劫持****CORS跨域配置错误****WebSocket** 网站常见安全漏洞 | 青训…

软件架构业务及技术复杂度分析总结

目录 一、综述分析 二、业务复杂性分析 &#xff08;一&#xff09;领域建模 &#xff08;二&#xff09;领域分层 &#xff08;三&#xff09;服务粒度 &#xff08;四&#xff09;流程编排 三、技术复杂性分析 &#xff08;一&#xff09;高可用 底层逻辑 CAP原则 …

Mac OS 13.4.1 搜狗输入法导致的卡顿问题

一、Mac OS 系统版本 搜狗输入法已经更新到最新 二、解决方案 解决方案一 在我的电脑上面需要关闭 VSCode 和 Chrmoe 以后&#xff0c;搜狗输入法回复正常。 解决方案二 强制重启一下搜狗输入法。 可以用 unix 定时任务去隔 2个小时自动 kill 掉一次进程 # kill 掉 mac …

EWM怎么取消pinking,SAP_EWM取消拣配报错处理方式

EWM是SAP的一个模块&#xff0c;代表扩展仓库管理&#xff08;Extended Warehouse Management&#xff09;&#xff0c;是SAP企业资源计划&#xff08;ERP&#xff09;的一部分。它提供了一个完整的、高级的仓库管理解决方案&#xff0c;支持企业在全球范围内的仓库管理、订单管…

QGIS学习1-入门学习

QGIS作为一个广受欢迎的开源GIS&#xff0c;很多GIS的学生都了解过。但是因为学校老师都是教的Arcgis&#xff0c;因此很少去充分的学习。QGIS和arcgis一样&#xff0c;有完整的官方帮助文档&#xff0c;我也是要根据官方的帮助文档进行学习等。 https://www.qgis.org/zh-Hans/…

一文看懂Cat.1、Cat.4、NB-IOT、4G之间的区别

01 什么是Cat.1&#xff1f; Cat.1的全称是LTEUE-Category1&#xff0c;其中UE指的是用户设备&#xff0c;它是LTE网络下用户终端设备的无线性能的分类。根据3GPP的定义&#xff0c;UE类别以1-15分为15个等级。Cat.1&#xff0c;可以称为“低配版”的 4G 终端&#xff0c;上行…

微信小程序发布迭代版本后如何提示用户强制更新新版本

在点击小程序发布的时候选择&#xff0c;升级选项 之前用户使用过的再打开小程序页面就会弹出升级弹窗modal

企业博客搭建:经营好企业博客,能让你的业务蹭蹭上涨!

企业博客本身作为企业产品知识的沉淀&#xff0c;搭建并且经营好企业博客不仅有利于企业文化建设&#xff0c;更可以利用博客来推动业务增长。 何谓企业博客营销&#xff1f;简单地说&#xff0c;就是利用HelpLook这种工具创建并开展网络营销活动&#xff0c;称之为博客营销。 …

Linux学习之nginx虚拟域名主机,lsof和netstat命令查看端口是否被监听

需要先参考我的博客《Linux学习之Ubuntu 20.04在https://openresty.org下载源码安装Openresty 1.19.3.1&#xff0c;使用systemd管理OpenResty服务》安装好Openresty。 虚拟域名可以使用让不同的域名访问到同一台主机。 cd /usr/local/openresty切换当前访问目录到/usr/local/o…

stm32之USART(总结)

串行通信 UART串口内部结构示意图 普中科技的详细介绍 中断知识补充 代码 #ifndef __USART_H #define __USART_H #include "stdio.h" #include "stm32f10x_usart.h" #define USART1_REC_LEN 200 //定义最大接收字节数 200extern u8 USART1_RX_BUF[US…

LeetCode--HOT100题(42)

目录 题目描述&#xff1a;108. 将有序数组转换为二叉搜索树&#xff08;简单&#xff09;题目接口解题思路代码 PS: 题目描述&#xff1a;108. 将有序数组转换为二叉搜索树&#xff08;简单&#xff09; 给你一个整数数组 nums &#xff0c;其中元素已经按 升序 排列&#xf…