TensorFlow2实战-系列教程6:迁移学习实战

news2025/1/12 6:00:49

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

1、迁移学习

  • 用已经训练好模型的权重参数当做自己任务的模型权重初始化
  • 一般全连接层需要自己训练,可以选择是否训练已经训练好的特征提取层

一般情况下根据自己的任务,选择对那些网络进行微调和重新训练:
如果预训练模型的任务和自己任务非常接近,那可能只需要把最后的全连接层重新训练即可
如果自己任务的数据量比较小,那么应该选择重新训练少数层
如果自己任务的数据量比较大,可以适当多选择几层进行训练

2、猫狗识别

import os
import warnings
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers
from tensorflow.keras import Model
base_dir = './data/cats_and_dogs'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')

train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')

validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')

前面的内容和TensorFlow2实战-系列教程3:猫狗识别1完全一样

3、加载预训练模型

from tf.keras.applications.resnet import ResNet50
from tensorflow.keras.applications.resnet import ResNet101
from tensorflow.keras.applications.inception_v3 import InceptionV3

从keras中导入预训练模型,在TensorFlow的keras模块,有很多可以直接导入的预训练权重。

pre_trained_model = ResNet101(input_shape = (75, 75, 3),  
                                include_top = False, 
                                weights = 'imagenet')
  • 加载导入的模型
  • input_shape 为输入大小
  • include_top为False就是表示不要最后的全连接层
  • 这段代码执行后,会自动进行下载

downloading data from
https://storage.googleapis.com/tensorflow/kerasapplications/resnet/resnet101_weights_tf_dim_ordering_tf_kernels_notop.h5
171446536/171446536 [==============================] - 15s 0us/step

for layer in pre_trained_model.layers:
    layer.trainable = False

选择要进行重新训练的层

4、callback模块

在 TensorFlow 中,回调(Callbacks)是一个强大的工具,用于在训练的不同阶段(例如在每个时代的开始和结束、在每个批次的处理前后)自定义和控制模型的行为,相当于一个监视器:

4.1 callback示例

callbacks = [
# 如果连续两个epoch还没降低就停止:
  tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),
# 可以动态改变学习率:
  tf.keras.callbacks.LearningRateScheduler
# 保存模型:
  tf.keras.callbacks.ModelCheckpoint
# 自定义方法:
  tf.keras.callbacks.Callback
]

上面是一个模板,继续我们的猫狗识别的迁移学习项目:

4.2 定义callback

class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs={}):
        if(logs.get('acc')>0.95):
            print("\nReached 95% accuracy so cancelling training!")
            self.model.stop_training = True
  1. 定义一个类,继承Callback
  2. 定义一个函数,传入epoch值和日志
  3. 从当前epoch的日志中取出准确率,如果准确率大于95%
  4. 打印信息
  5. 停止训练
from tensorflow.keras.optimizers import Adam
x = layers.Flatten()(pre_trained_model.output)
x = layers.Dense(1024, activation='relu')(x)
x = layers.Dropout(0.2)(x)                  
x = layers.Dense(1, activation='sigmoid')(x)           
model = Model(pre_trained_model.input, x) 
model.compile(optimizer = Adam(lr=0.001), 
              loss = 'binary_crossentropy', 
              metrics = ['acc'])
  1. 导入优化器
  2. 将预训练模型的输出展平为一维
  3. 定义一个1024的全连接层
  4. 在这层加入dropout
  5. 输出全连接层
  6. 构建模型
  7. 指定优化器、损失函数、验证方法等配置训练器

5、模型训练

定义需要重新训练的层

train_datagen = ImageDataGenerator(rescale = 1./255.,
                                   rotation_range = 40,
                                   width_shift_range = 0.2,
                                   height_shift_range = 0.2,
                                   shear_range = 0.2,
                                   zoom_range = 0.2,
                                   horizontal_flip = True)

test_datagen = ImageDataGenerator( rescale = 1.0/255. )

train_generator = train_datagen.flow_from_directory(train_dir,
                                                    batch_size = 20,
                                                    class_mode = 'binary', 
                                                    target_size = (75, 75))     

validation_generator =  test_datagen.flow_from_directory( validation_dir,
                                                          batch_size  = 20,
                                                          class_mode  = 'binary', 
                                                          target_size = (75, 75))

前面的内容和TensorFlow2实战-系列教程3:猫狗识别1一样,制作数据

callbacks = myCallback()
history = model.fit_generator(
            train_generator,
            validation_data = validation_generator,
            steps_per_epoch = 100,
            epochs = 100,
            validation_steps = 50,
            verbose = 2,
            callbacks=[callbacks])

指定训练参数、数据、加入callback模块到模型中,执行训练,verbose = 2表示每次epoch记录一次日志

打印结果:

Epoch 99/100 100/100 - 76s - loss: 0.6138 - acc: 0.6655 - val_loss: 0.6570 - val_acc: 0.6900
Epoch 100/100 100/100 - 76s - loss: 0.5993 - acc: 0.6735 - val_loss: 0.7176 - val_acc: 0.6910

6、预测效果展示

import matplotlib.pyplot as plt
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(acc))

plt.plot(epochs, acc, 'b', label='Training accuracy')
plt.plot(epochs, val_acc, 'r', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend()

plt.figure()

plt.plot(epochs, loss, 'b', label='Training Loss')
plt.plot(epochs, val_loss, 'r', label='Validation Loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()

展示
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1417579.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

读AI3.0笔记08_自然语言

1. 人工智能研究的惯用的套路 1.1. 定义一个在细分领域中比较有用的任务 1.2. 收集一个大型数据集来测试机器在该任务上的性能 1.3. 对人类在该数据集上完成任务的能力进行一个有限的度量 1.4. 建立一场竞赛使得人工智能系统可以在该数据集上互相竞争 1.5. 直到最终达到或…

内存储器之只读存储器(ROM),随机存取存储器(RAM)和Cache详解

内存储器 计算机中的存储器分为内存和外存两大类。 内存的存取速度快而容量相对较小,它与CPU直接相连,用来存放等待CPU运行的程序和处理的数据;外存的速度较慢而容量相对很大,它与CPU并不直接连接,用于永久性地存放计…

网络安全04-sql注入靶场第一关

目录 一、环境准备 1.1我们进入第一关也如图: ​编辑 二、正式开始第一关讲述 2.1很明显它让我们在标签上输入一个ID,那我们就输入在链接后面加?id1 ​编辑 2.2链接后面加个单引号()查看返回的内容,127.0.0.1/sqli/less-1/?id1,id1 …

Unity 迭代器模式(实例详解)

文章目录 简介**实例1:遍历数组****实例2:自定义迭代器类****实例3:异步加载资源****实例4:游戏关卡序列****实例5:无限生成敌人** 简介 在Unity中,虽然不直接使用迭代器模式的原始定义(即设计…

PDF标准详解(一)——PDF文档结构

已经很久没有写博客记录自己学到的一些东西了。但是在过去一年的时间中自己确实又学到了一些东西。一直攒着没有系统化成一篇篇的文章,所以今年的博客打算也是以去年学到的一系列内容为主。通过之前Vim系列教程的启发,我发现还是写一些系列文章对自己的帮…

Springmvc-@RequestBody

SpringBoot-2.7.12 请求的body参数无法转换,服务端没有报错信息打印,而是响应的状态码是400 PostMapping("/static/user") public User userInfo(RequestBody(required false) User user){user.setAge(19);return user; }PostMapping("…

05 Redis之Benchmark+简单动态字符串SDS+集合的底层实现

3.8 Benchmark Redis安装完毕后会自动安装一个redis-benchmark测试工具,其是一个压力测试工具,用于测试 Redis 的性能。 src目录下可找到该工具 通过 redis-benchmark –help 命令可以查看到其用法 3.8.1 测试1 3.9 简单动态字符串SDS 无论是 Redis …

【面试】测试开发面试题

帝王之气,定是你和万里江山,我都护得周全 文章目录 前言1. 网络原理get与post的区别TCP/IP各层是如何传输数据的IP头部包含哪些内容TCP头部为什么有浮动网络层协议1. 路由协议2. 路由信息3. OSPF与RIP的区别Cookie与Session,Token的区别http与…

Redis学习——高级篇①

Redis学习——高级篇① Redis7高级之单线程和多线程(一) 一、Redis单线程VS多线程1.Redis的单线程部分1.1 Redis为什么是单线程?1.2 Redis所谓的“单线程”1.3 Redis演进变化1.3.1 Redis 3.x 单线程时代性能很快的原因1.3.2…

林浩然科学趣谈:妙解麦克斯韦方程的电磁奥秘

林浩然科学趣谈:妙解麦克斯韦方程的电磁奥秘 Lin Haoran’s Scientific Banter: Playful Insights into the Electromagnetic Mysteries of Maxwell’s Equations 在科学的璀璨星河中,林浩然如同一颗热爱探索的行星,以其独特的幽默和严谨的态…

latent-diffusion model环境配置--我转载的

latent-diffusion model环境配置,这可能是你能够找到的最细的博客了_latent diffusion model 训练 autoencoder-CSDN博客 前言 最近在研究diffusion模型,并对目前最火的stable-diffusion模型很感兴趣,又因为stable-diffusion是一种latent-di…

opencv#35 连通域分析

连通域分割原理 像素领域介绍: 4邻域是指中心的像素与它邻近的上下左右一共有4个像素,那么称这4个像素为中心像素的4邻域。 8邻域是以中心像素周围的8个像素分别是上下左右和对角线上的4个像素。 连通域的定义(分割)分为两种:以4邻域为相邻判定条件的连通域分割和…

C++笔记之RTTI、RAII、MVC、MVVM、SOLID在C++中的表现

C++笔记之RTTI、RAII、MVC、MVVM、SOLID在C++中的表现 —— 杭州 2024-01-28 code review! 文章目录 C++笔记之RTTI、RAII、MVC、MVVM、SOLID在C++中的表现1.RTTI、RAII、MVC、MVVM、SOLID简述2.RAII (Resource Acquisition Is Initialization)3.RTTI (Run-Time Type Informat…

steam幻兽帕鲁服务器配置费用报价,4核16G

幻兽帕鲁服务器价格多少钱?4核16G服务器Palworld官方推荐配置,阿里云4核16G服务器32元1个月、96元3个月,腾讯云换手帕服务器服务器4核16G14M带宽66元一个月、277元3个月,8核32G22M配置115元1个月、345元3个月,16核64G3…

某度Pan复活,突破限速,很强大!

软件简介: 软件【下载地址】获取方式见文末。注:推荐使用,更贴合此安装方法! 作为国内领先的云存储服务提供商之一,某度Pan为用户提供了一个便捷的文件存储和分享平台。然而,用户普遍反映某度Pan的下载速…

简盒工具箱iapp源码

一款工具箱兼做软件库。 新增远程更新功能 修复了部分失效功能 修复了偶尔会卡在启动页的情况 源码下载:https://download.csdn.net/download/m0_66047725/88776737 更多资源下载:关注我。

漏洞原理MySql注入 Windows中Sqlmap 工具的使用

漏洞原理MySql注入 SQLmap是一款开源的自动化SQL注入工具,用于检测和利用Web应用程序中的SQL注入漏洞。以下是SQLmap工具的使用总结: 安装和配置:首先需要下载并安装SQLmap工具。安装完成后,可以通过命令行界面或图形用户界面来使…

2024幻兽帕鲁服务器,阿里云配置

阿里云幻兽帕鲁服务器Palworld服务器推荐4核16G配置,可以选择通用型g7实例或通用算力型u1实例,ECS通用型g7实例4核16G配置价格是502.32元一个月,算力型u1实例4核16G是432.0元/月,经济型e实例是共享型云服务器,价格是32…

案例分享:长沙红胖子公司内部评估高清内窥镜功能列表流程产出成果鉴赏

若该文为原创文章,转载请注明出处 本文章博客地址:https://hpzwl.blog.csdn.net/article/details/135898723 红胖子(红模仿)的博文大全:开发技术集合(包含Qt实用技术、树莓派、三维、OpenCV、OpenGL、ffmpeg、OSG、单片机、软硬结…

多地多活与单元化架构

多地多活与单元化架构 背景 在业务发展到一定阶段之后,任何因故障而导致的服务中断都会带来巨大的损失。为了提高系统的伸缩能力与高可用能力,我们都不断的在努力消除系统单点瓶颈。如使用应用集群是为了解决服务层的单点问题,使用主从数据…