29- 迁移学习 (TensorFlow系列) (深度学习)

news2025/1/23 9:26:44

知识要点

  • 迁移学习: 使用别人预训练模型参数时,要注意别人的预处理方式。

  • 常见的迁移学习方式:

    • 载入权重后训练所有参数.
    • 载入权重后只训练最后几层参数.
    • 载入权重后在原网络基础上再添加一层全连接层仅训练最后一个全连接层.
  • 训练数据是 10_monkeys 数据: 10种猴子的图片集
  • 图片显示: plt.imshow(mokey)
  • 读取图片: mokey = plt.imread('./50.jpg')

导入resnet 模型:

  • resnet50 = keras.applications.ResNet50(include_top=False, pooling='avg')      # 导入模型
  • model = keras.models.Sequential()      # 开始建模
  • model.add(resnet50)     # 添加resnet 网络
  • model.add(keras.layers.Dense(num_classes=10, activation = 'softmax'))     # 添加全连接层
  • model.layers[0].trainable = False     # 除了最后一个全连接层, 其余部分参数不变
  • 模型配置:
model.compile(loss = 'categorical_crossentropy',
              optimizer = 'adam',
              metrics = ['acc'])
  • valid_datagen = keras.preprocessing.image.ImageDataGenerator(preprocessing_function = keras.applications.resnet50.preprocess_input)    # 数据初始化处理
  • 指定后面几层参数变化:
# 切片指定, 不可调整的层数
for layer in resnet50.layers[0:-5]:
    layer.trainable = False


一 迁移学习

1.1 简介

使用迁移学习的优势:

  • 能够快速的训练出一个理想的结果
  • 当数据集较小时也能训练出理想的效

注意:使用别人预训练模型参数时,要注意别人的预处理方式。

1.2 常见迁移方式

常见的迁移学习方式:

  • 载入权重后训练所有参数.
  • 载入权重后只训练最后几层参数.
  • 载入权重后在原网络基础上再添加一层全连接层仅训练最后一个全连接层.

二 代码实现

2.1 导包

from tensorflow import keras
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

cpu=tf.config.list_physical_devices("CPU")
tf.config.set_visible_devices(cpu)
print(tf.config.list_logical_devices())

2.2 迁移模型  (在迁移模型 后加一层)

resnet50 = keras.applications.ResNet50(include_top=False, pooling='avg')

num_classes =10
model = keras.models.Sequential()
model.add(resnet50)
model.add(keras.layers.Dense(num_classes, activation = 'softmax'))
model.summary()

2.3 配置模型 (除最后一层外, 其余参数全部冻结)

# 把除最后一层的参数外, 全部冻结
model.layers[0].trainable = False
model.compile(loss = 'categorical_crossentropy',
              optimizer = 'adam',
              metrics = ['acc'])

2.4 导入数据

train_dir = '../day 48 resnet/training/training/'
valid_dir = '../day 48 resnet/validation/validation/'
  • 原始数据处理
train_datagen = keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function = keras.applications.resnet50.preprocess_input,
    rotation_range = 40,
    width_shift_range = 0.2,
    height_shift_range = 0.2,
    shear_range = 0.2,
    zoom_range = 0.2,
    horizontal_flip = True,
    vertical_flip = True,
    fill_mode = 'nearest')

height = 224
width = 224
channels = 3
batch_size = 32
num_classes = 10

train_generator = train_datagen.flow_from_directory(train_dir,
                                                    target_size= (height, width),
                                                    batch_size = batch_size,
                                                    shuffle= True,
                                                    seed = 7,
                                                    class_mode= 'categorical')

valid_datagen = keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function = keras.applications.resnet50.preprocess_input)

valid_generator = valid_datagen.flow_from_directory(valid_dir,
                                                    target_size= (height, width),
                                                    batch_size= batch_size,
                                                    shuffle= True,
                                                    seed = 7,
                                                    class_mode= 'categorical')
print(train_generator.samples)   # 1098
print(valid_generator.samples)   # 272

2.5 模型训练

# 使用迁移学习, 效果较差, 原始数据的处理方式不同
# 修改需处理方式继续执行, 效果较好
histroy = model.fit(train_generator,
                          steps_per_epoch= train_generator.samples // batch_size,
                          epochs = 10,
                          validation_data = valid_generator,
                          validation_steps= valid_generator.samples // batch_size)

 2.6 训练后面几层神经网络参数

resnet50 = keras.applications.ResNet50(include_top=False, pooling='avg', weights='imagenet')
# 切片指定, 不可调整的层数
for layer in resnet50.layers[0:-5]:
    layer.trainable = False
    
# 添加输出层
resnet50_new = keras.models.Sequential([resnet50, keras.layers.Dense(10, activation = 'softmax')])
resnet50_new.compile(loss = 'categorical_crossentropy',
                     optimizer = 'adam',
                     metrics = ['acc'])
resnet50_new.summary()

histroy = resnet50_new.fit(train_generator,
                          steps_per_epoch= train_generator.samples // batch_size,
                          epochs = 10,
                          validation_data = valid_generator,
                          validation_steps= valid_generator.samples // batch_size)

三 图片处理查看

3.1 图片显示

# 预测数据
mokey = plt.imread('./n5020.jpg')
plt.imshow(mokey)    # mokey.shape  (600, 336, 3)

for i in range(2):
    x, y = train_generator.next()
    print(type(x), type(y))    # <class 'numpy.ndarray'> <class 'numpy.ndarray'>
    print('***', x.shape, y.shape)  # *** (32, 224, 224, 3) (32, 10)

3.2 尺寸变换

# 主要是形状和尺寸不对
# 改变尺寸, 再改变形状reshape
from scipy import ndimage  # 专门处理图片
# 改变形状
# 224 = 367 * x  x = 224/367
# 224 = 550 * y  y = 224/550
zoom = (224/mokey.shape[0], 224/mokey.shape[1])
monkey_zoomed = ndimage.zoom(mokey, (224/mokey.shape[0], 224/mokey.shape[1], 1))
monkey_zoomed.shape   # (224, 224, 3)
monkey_1 = keras.applications.resnet50.preprocess_input(monkey_zoomed)
monkey_1.min()    # -123.68
monkey_1 = monkey_1.reshape(1, 224, 224, 3)
model.predict(monkey_1).argmax(axis = 1)  # array([5], dtype=int64)

3.3 resnet 图片处理方式

3.3.1 前景查看

mokey1 = mokey/127.5
plt.imshow(mokey1)

3.3.2 背景查看

mokey1 = mokey1 - 1
plt.imshow(mokey1)

mokey1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/389476.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

滚蛋吧,正则表达式!

大家好&#xff0c;我是良许。 不知道大家有没有被正则表达式支配过的恐惧&#xff1f;看着一行火星文一样的表达式&#xff0c;虽然每一个字符都认识&#xff0c;但放在一起直接就让人蒙圈了~ 你是不是也有这样的操作&#xff0c;比如你需要使用「电子邮箱正则表达式」&…

面试之String、StringBuffer、StringBuilder区别

String、StringBuffer、StringBuilder区别 (1)是否可变 string对象不可变&#xff1b; StringBuffer、StringBuilder继承自AbstractStringBuilder类&#xff0c;实现原理都基于可修改的char数组&#xff0c;默认大小为16 (2)线程安全性 string中的对象不可变&#xff0c;可…

Java中String类intern()详解

1、背景在开发过程中很多朋友&#xff0c;由于不会正确使用intern()&#xff0c;导致开发的程序&#xff0c;执行效率比较差。同时最近发现一道非常有意思的关于intern()的面试题&#xff0c;这道面试题还是有不小的难度&#xff0c;相信很多朋友看到以后也不知道怎么解答&…

c++类与对象整理(上)

目录 1.类的引入 2.类的定义 3.类的访问限定符及封装 1&#xff09;访问限定符 2&#xff09;封装 4.类的作用域 5.类的实例化 6.类的对象大小的计算 1&#xff09;类对象的存储方式 2&#xff09;内存对齐和大小计算 ​编辑 7.类成员函数的this指针 1&#xff09…

linux配置网络详解

linux配置网络详解 文章目录linux配置网络详解前置准备配置流程错误排查前置准备 确定是否有网&#xff0c;比如在家里&#xff0c;确定是否连上网线&#xff1f;确定这个网线的网关是什么&#xff1f;&#xff08;这个需要和给你办网的人确定&#xff09;&#xff0c;在公司的…

超详细JDK1.8所有版本下载地址

JDK1.8即为JDK8&#xff0c;JDK8是目前是最成熟最稳定的版本&#xff0c;本文将详细介绍JDK1.8历史版本的下载方式。 在此附上JDK1.8安装与配置教程 超详细JDK1.8安装与配置 一、JDK官网 首先打开oracle官网&#xff0c;官网首页地址为 JDK官网首页地址 点击Products 点击…

Kotlin实现简单的学生信息管理系统

文章目录一、实验内容二、实验步骤1、页面布局2、数据库3、登录活动4、增删改查三、运行演示四、实验总结五、源码下载一、实验内容 根据Android数据存储的内容&#xff0c;综合应用SharedPreferences和SQLite数据库实现一个用户信息管理系统&#xff0c;强化对SharedPreferen…

ks通过恶意低绩效来变相裁员(六)各方核心利益点分析

目录 公司利益点 管理层利益点 直接管理者利益点 一线干活的同学 一线嫡系同学 公司利益点 核心利益点&#xff1a;围绕财报营收&#xff0c;降本&#xff0c;拿到好看的财报数据&#xff0c;让资本市场继续看好自己 核心手段&#xff1a; 扩展新业务&#xff0c;挖掘已…

基于数据驱动的智能空调系统需求响应可控潜力评估研究(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

深入理解多线程

一、线程基本概念 1、概述 线程是允许应用程序并发的一种机制。线程共享进程内的所有资源。 线程是调度的基本单位。 每个线程都有自己的 errno。 所有 pthread 函数均以返回 0 表示成功&#xff0c;返回一个正值表示失败。 编译 pthread 程序需要添加链接库&#xff08;…

【Java】反射机制和代理机制

目录一、反射1. 反射概念2. 反射的应用场景3. 反射机制的优缺点4. 反射实战获取 Class 对象的四种方式二、代理机制1. 代理模式2. 静态代理3. 动态代理3.1 JDK动态代理机制1. 介绍2.JDK 动态代理类使用步骤3. 代码示例3.2 CGLIB 动态代理机制1.介绍2.CGLIB 动态代理类使用步骤3…

程序员压力大?用 PyQt 做一个美*女GIF设置桌面,每天都有好心情

嗨害大家好鸭&#xff01;我是小熊猫~ 要说程序员工作的最大压力不是来自于工作本身&#xff0c; 而是来自于需要不断学习才能更好地完成工作&#xff0c; 因为程序员工作中面对的编程语言是在不断更新的&#xff0c; 同时还要学习熟悉其他语言来提升竞争力… 好了&#xff0c…

使用Python通过拉马努金公式快速求π

使用Python通过拉马努金公式快速求π 一、前言 π是一个数学常数&#xff0c;定义为&#xff1a;圆的周长与直径的比值。 π是一个无理数&#xff0c;也是一个超越数&#xff0c;它的小数部分无限不循环。 π可以用来精确计算圆周长、圆面积、球体积等几何形状的关键值。 有关…

【电子学会】2022年12月图形化二级 -- 老鹰捉小鸡

老鹰捉小鸡 小鸡正在农场上玩耍&#xff0c;突然从远处飞来一只老鹰&#xff0c;小鸡要快速回到鸡舍中&#xff0c;躲避老鹰的抓捕。 1. 准备工作 &#xff08;1&#xff09;删除默认白色背景&#xff0c;添加背景Farm&#xff1b; &#xff08;2&#xff09;删除默认角色小…

进制间转换

md&#xff0c;离开学校好多年了&#xff0c;这些基础趁现在还记得记录一下&#xff0c;不然怕哪天还给老师就尴尬了&#xff0c;方便复习 基本概念 二进制&#xff1a;&#xff08;逢2进1&#xff09;由0和1组成。十六进制&#xff1a;&#xff08;逢16进1&#xff09;由0-9&a…

编码器SIQ-02FVS3驱动

一.简介 此编码器可以是功能非常强大&#xff0c;可以检测左右转动&#xff0c;和按键按下&#xff0c;所以说这一个编码器可以抵三个按键&#xff0c;而且体积非常小&#xff0c;使用起来比三个按键要高大尚&#xff0c;而且驱动也简单。唯一不足的点就是价格有点小贵6-8元才…

Faster RCNN 论文阅读

1.网络架构 VGG16网络 anchors:人工放上去的 RPN对anchors进行二分类&#xff0c;正样本&#xff0c;负样本 RoIP&#xff1a;前面的框框已经圈出目标&#xff0c;但还不知道具体属于哪个类&#xff0c;它就是干这个工作的 2.VGG网络 VGG网络可以任意替换其他的任意神经网络&am…

Spring核心模块——Aware接口

Aware接口前言基本内容例子结尾前言 Spring的依赖注入最大亮点是所有的Bean对Spring容器对存在都是没有意识到&#xff0c;Spring容器中的Bean的耦合度是很低的&#xff0c;我们可以将Spring容器很容易换成其他的容器。 但是实际开发的时候&#xff0c;我们经常要用到Spring容…

虚拟机安装Windows 10

虚拟机安装Windows 10 镜像下载 方法一&#xff1a;下载我制作好的镜像文件->百度网盘链接 提取码&#xff1a;Chen 方法二&#xff1a;自己做一个 进入微软官网链接 下载"MediaCreationTool20H2" 运行该工具 点击下一步选择路径&#xff0c;等他下载好就欧克了…

我就不信你还不懂HashSet/HashMap的底层原理

&#x1f4a5;注&#x1f4a5; &#x1f497;阅读本博客需备的前置知识如下&#x1f497; &#x1f31f;数据结构常识&#x1f31f;&#x1f449;1️⃣八种数据结构快速扫盲&#x1f31f;Java集合常识&#x1f31f;&#x1f449;2️⃣Java单列集合扫盲 ⭐️本博客知识点收录于…