TensorFlow2实战-系列教程6:猫狗识别3------迁移学习

news2024/12/28 3:49:59

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

猫狗识别1
数据增强
猫狗识别2------数据增强
猫狗识别3------迁移学习

1、迁移学习

  • 用已经训练好模型的权重参数当做自己任务的模型权重初始化
  • 一般全连接层需要自己训练,可以选择是否训练已经训练好的特征提取层

一般情况下根据自己的任务,选择对那些网络进行微调和重新训练:
如果预训练模型的任务和自己任务非常接近,那可能只需要把最后的全连接层重新训练即可
如果自己任务的数据量比较小,那么应该选择重新训练少数层
如果自己任务的数据量比较大,可以适当多选择几层进行训练

2、猫狗识别

import os
import warnings
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers
from tensorflow.keras import Model
base_dir = './data/cats_and_dogs'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')

train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')

validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')

前面的内容和TensorFlow2实战-系列教程3:猫狗识别1完全一样

3、加载预训练模型

from tf.keras.applications.resnet import ResNet50
from tensorflow.keras.applications.resnet import ResNet101
from tensorflow.keras.applications.inception_v3 import InceptionV3

从keras中导入预训练模型,在TensorFlow的keras模块,有很多可以直接导入的预训练权重。

pre_trained_model = ResNet101(input_shape = (75, 75, 3),  
                                include_top = False, 
                                weights = 'imagenet')
  • 加载导入的模型
  • input_shape 为输入大小
  • include_top为False就是表示不要最后的全连接层
  • 这段代码执行后,会自动进行下载

downloading data from
https://storage.googleapis.com/tensorflow/kerasapplications/resnet/resnet101_weights_tf_dim_ordering_tf_kernels_notop.h5
171446536/171446536 [==============================] - 15s 0us/step

for layer in pre_trained_model.layers:
    layer.trainable = False

选择要进行重新训练的层

4、callback模块

在 TensorFlow 中,回调(Callbacks)是一个强大的工具,用于在训练的不同阶段(例如在每个时代的开始和结束、在每个批次的处理前后)自定义和控制模型的行为,相当于一个监视器:

4.1 callback示例

callbacks = [
# 如果连续两个epoch还没降低就停止:
  tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),
# 可以动态改变学习率:
  tf.keras.callbacks.LearningRateScheduler
# 保存模型:
  tf.keras.callbacks.ModelCheckpoint
# 自定义方法:
  tf.keras.callbacks.Callback
]

上面是一个模板,继续我们的猫狗识别的迁移学习项目:

4.2 定义callback

class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs={}):
        if(logs.get('acc')>0.95):
            print("\nReached 95% accuracy so cancelling training!")
            self.model.stop_training = True
  1. 定义一个类,继承Callback
  2. 定义一个函数,传入epoch值和日志
  3. 从当前epoch的日志中取出准确率,如果准确率大于95%
  4. 打印信息
  5. 停止训练
from tensorflow.keras.optimizers import Adam
x = layers.Flatten()(pre_trained_model.output)
x = layers.Dense(1024, activation='relu')(x)
x = layers.Dropout(0.2)(x)                  
x = layers.Dense(1, activation='sigmoid')(x)           
model = Model(pre_trained_model.input, x) 
model.compile(optimizer = Adam(lr=0.001), 
              loss = 'binary_crossentropy', 
              metrics = ['acc'])
  1. 导入优化器
  2. 将预训练模型的输出展平为一维
  3. 定义一个1024的全连接层
  4. 在这层加入dropout
  5. 输出全连接层
  6. 构建模型
  7. 指定优化器、损失函数、验证方法等配置训练器

5、模型训练

定义需要重新训练的层

train_datagen = ImageDataGenerator(rescale = 1./255.,
                                   rotation_range = 40,
                                   width_shift_range = 0.2,
                                   height_shift_range = 0.2,
                                   shear_range = 0.2,
                                   zoom_range = 0.2,
                                   horizontal_flip = True)

test_datagen = ImageDataGenerator( rescale = 1.0/255. )

train_generator = train_datagen.flow_from_directory(train_dir,
                                                    batch_size = 20,
                                                    class_mode = 'binary', 
                                                    target_size = (75, 75))     

validation_generator =  test_datagen.flow_from_directory( validation_dir,
                                                          batch_size  = 20,
                                                          class_mode  = 'binary', 
                                                          target_size = (75, 75))

前面的内容和TensorFlow2实战-系列教程3:猫狗识别1一样,制作数据

callbacks = myCallback()
history = model.fit_generator(
            train_generator,
            validation_data = validation_generator,
            steps_per_epoch = 100,
            epochs = 100,
            validation_steps = 50,
            verbose = 2,
            callbacks=[callbacks])

指定训练参数、数据、加入callback模块到模型中,执行训练,verbose = 2表示每次epoch记录一次日志

打印结果:

Epoch 99/100 100/100 - 76s - loss: 0.6138 - acc: 0.6655 - val_loss: 0.6570 - val_acc: 0.6900
Epoch 100/100 100/100 - 76s - loss: 0.5993 - acc: 0.6735 - val_loss: 0.7176 - val_acc: 0.6910

6、预测效果展示

import matplotlib.pyplot as plt
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(acc))

plt.plot(epochs, acc, 'b', label='Training accuracy')
plt.plot(epochs, val_acc, 'r', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend()

plt.figure()

plt.plot(epochs, loss, 'b', label='Training Loss')
plt.plot(epochs, val_loss, 'r', label='Validation Loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()

展示
在这里插入图片描述
猫狗识别1
数据增强
猫狗识别2------数据增强
猫狗识别3------迁移学习

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1422010.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java基础(面试用)

一、基本语法 1. 注释有哪几种形式? //单行注释:通常用于解释方法内某单行代码的作用。 //int i 0;//多行注释:通常用于解释一段代码的作用。 //int i 0; //int i 0;//文档注释:通常用于生成 Java 开发文档。 /* *int i 0; …

Springboot项目启动后浏览器不能直接访问接口,而postman可以访问?

在云服务器上部署springboot后端时,项目启动后浏览器不能直接访问接口,而postman可以访问。这是当时困扰了我大半天的小问题,在我打开防火墙和阿里云安全组之后还是没解决。然后在网上搜了很多很多资料,以为是浏览器访问权限或者是https什么证…

Oracle 的闪回技术是什么

什么是闪回 Oracle 数据库闪回技术是一组独特而丰富的数据恢复解决方案,能够有选择性地高效撤销一个错误的影响,从人为错误中恢复。闪回是一种数据恢复技术,它使得数据库可以回到过去的某个状态,可以满足用户的逻辑错误的快速恢复…

Walrus 实用教程|Walrus + Gitlab,打通CI/CD 自动化交付!

Walrus file 是 Walrus 0.5 版本推出的新功能,用户可以通过一个非常简洁的 YAML 描述应用或基础设施资源的部署配置,然后通过 Walrus CLI 执行 walrus apply或在 Walrus UI 上进行import,将 Walrus file 提交给 Walrus server,由 …

N65总账凭证管理凭证查询(sql)

--核算账簿 select code , name , pk_setofbook from org_setofbook where ( pk_setofbook in ( select pk_setofbook from org_accountingbook where 1 1 and ( pk_group N0001A11000000000037X ) and ( accountenablestate 2 ) ) ) order by code;--核算账簿 select code …

60、Flink CDC 入门介绍及Streaming ELT示例(同步Mysql数据库数据到Elasticsearch)-完整版

Flink 系列文章 一、Flink 专栏 Flink 专栏系统介绍某一知识点,并辅以具体的示例进行说明。 1、Flink 部署系列 本部分介绍Flink的部署、配置相关基础内容。 2、Flink基础系列 本部分介绍Flink 的基础部分,比如术语、架构、编程模型、编程指南、基本的…

Spring MVC-01

关于Spring MVC Spring MVC是基于Spring框架基础之上的,主要解决了后端服务器接收客户端提交的请求,并给予响应的相关问题。 MVC Model View Controller,它们分别是: - Model:数据模型,通常由业务逻辑层…

TypeScript(七) 函数

1. TypeScript 函数 1.1. 函数的定义 函数就是包裹在花括号中的代码块,前面使用关键字function。 语法: // An highlighted block function function_name() {// 执行代码 }实例: function test() { // 函数定义console.log("我就是…

Linux系统——文本三剑客

目录 一、grep 1.格式 2.选项 2.1 grep重定向 2.2grep -m 匹配到几次停止 2.3grep -i 忽略大小写 2.4grep -n 显示行号 2.5grep -c 统计匹配行数 2.6grep -A 后几行 2.7grep -C 前后三行 2.8grep -B 前三行 2.9grep -e 或 2.10grep -w 匹配整个单词 2.11grep -r…

聊聊百度造车

10月27日,极越-01上市,一个月后大幅降价,时至今日距离发布已经过去了两个月,官方迟迟不肯公布销量,实际情况大家也都心知肚明。 如今小米汽车技术发布会风头无两,而同一年宣布造车的极越却无人问津&#x…

ElementUI安装与使用指南

Element官网-安装指南 提醒一下:下面实例讲解是在Mac系统演示的; 一、开发环境配置 电脑需要先安装好node.js和vue2或者vue3 安装Node.js Node.js 中文网 安装node.js命令:brew install node node.js安装完后,输入&#xff1…

界面控件DevExtreme v23.2新版亮点 - 全新的Fluent主题

DevExtreme拥有高性能的HTML5 / JavaScript小部件集合,使您可以利用现代Web开发堆栈(包括React,Angular,ASP.NET Core,jQuery,Knockout等)构建交互式的Web应用程序。从Angular和Reac&#xff0c…

[力扣 Hot100]Day18 矩阵置零

题目描述 给定一个 m x n 的矩阵,如果一个元素为 0 ,则将其所在行和列的所有元素都设为 0 。请使用 原地 算法。 出处 思路 在原数组上直接操作势必会出现“冗余”的0,即原本[i,j]处不是0,例如由于i行的其他位置有0导致[i,j]…

【深度学习每日小知识】Precision 精度

Precision 精度 精度是机器学习 (ML) 中分类器或预测器准确性的度量。它被定义为分类器做出的真正正预测数与分类器所做的正预测总数之比。换句话说,真正正确的是正面预测的比例。 精度是机器学习中的一个关键参数,因为它量化了…

Adobe Photoshop 2024 v25.4.0 - 专业的图片设计软件

Adobe Photoshop 2024 v25.4.0更新了,从照片编辑和合成到数字绘画、动画和图形设计,任何您能想象到的内容都能通过PS2024轻松实现。 利用人工智能技术进行快速编辑。学习新技能并与社区分享您的工作。借助我们的最新版本,做令人惊叹的事情从未…

Java中支持父类转子类,不支持子类转父类吗?

不,我的意思是正好相反。在 Java 中: 子类转父类(向上转型):这是自动的且总是安全的。子类是父类的一个特化,因此子类的对象可以被视为是父类的一个实例。例如,如果 ExamineApproveNode 是 Base…

【Tomcat与网络3】Tomcat的整体架构

目录 1.演进1:将连接和处理服务分开 2演进2:Container的演进 3 再论Tomcat的容器结构 4 Tomcat处理请求的过程 5 请求的处理过程与Pipeline-Valve管道 在前面我们介绍了Servlet的基本原理,本文我们结合Tomcat来分析一下如何设计一个大型…

小程序真机调试websocket错误码1006

有心跳机制 模拟器上都没问题连接稳定 一到真机就30秒断连 怎么解决救救我T_T

Flink CEP实现10秒内连续登录失败用户分析

1、什么是CEP? Flink CEP即 Flink Complex Event Processing,是基于DataStream流式数据提供的一套复杂事件处理编程模型。你可以把他理解为基于无界流的一套正则匹配模型,即对于无界流中的各种数据(称为事件),提供一种组合匹配的…