神经网络实战--使用迁移学习完成猫狗分类

news2024/12/26 23:12:40

在这里插入图片描述

前言: Hello大家好,我是Dream。 今天来学习一下如何使用基于tensorflow和keras的迁移学习完成猫狗分类,欢迎大家一起前来探讨学习~

本文目录:

  • 一、加载数据集
    • 1.调用库函数
    • 2.加载数据集
    • 3.数据集管理
  • 二、猫狗数据集介绍
    • 1.猫狗数据集介绍:
    • 2.图片展示
  • 三、MobileNetV2网络介绍
    • 1.加载tensorflow提供的预训练模型
    • 2.轻量级网络——MobileNetV2
    • 3.MobileNetV2的网络模块
  • 四、搭建迁移学习
    • 1.训练
    • 2.训练结果可视化
    • 3.输出训练的准确率
    • 4.用cnn工具可视化一批数据的预测结果
    • 5.数据输出
    • 6.用cnn工具可视化一个数据样本的各层输出
    • 7.输出结果图像
  • 五、源码获取

说明:在此试验下,我们使用的是使用tf2.x版本,在jupyter环境下完成
在本文中,我们将主要完成以下任务:

  1. 实现基于tensorflow和keras的迁移学习

  2. 加载tensorflow提供的数据集(不得使用cifar10)

  3. 需要使用markdown单元格对数据集进行说明

  4. 加载tensorflow提供的预训练模型(不得使用vgg16)

  5. 需要使用markdown单元格对原始模型进行说明

  6. 网络末端连接任意结构的输出端网络

  7. 用图表显示准确率和损失函数

  8. 用cnn工具可视化一批数据的预测结果

  9. 用cnn工具可视化一个数据样本的各层输出

一、加载数据集

1.调用库函数

import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
import cnn_utils
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.layers import GlobalAveragePooling2D,Dense,Input,Dropout

2.加载数据集

数据集加载,数据是通过这个网站下载的猫狗数据集:http://aimaksen.bslience.cn/cats_and_dogs_filtered.zip,实验中为了训练方便,我们取了一个较小的数据集。

path_to_zip = tf.keras.utils.get_file(
    'data.zip',
    origin='http://aimaksen.bslience.cn/cats_and_dogs_filtered.zip',
    extract=True,
)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')

train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')

BATCH_SIZE = 32
IMG_SIZE = (160, 160)

3.数据集管理

使用image_dataset_from_director进行数据集管理,使用ImageDataGenerator训练过程中会出现错误,不知道是什么原因,就使用了原始的image_dataset_from_director方法进行数据集管理。

train_dataset = image_dataset_from_directory(train_dir,
                                             shuffle=True,
                                             batch_size=BATCH_SIZE,
                                             image_size=IMG_SIZE)

validation_dataset = image_dataset_from_directory(validation_dir,
                                                  shuffle=True,
                                                  batch_size=BATCH_SIZE,
                                                  image_size=IMG_SIZE)

二、猫狗数据集介绍

1.猫狗数据集介绍:

猫狗数据集包括25000张训练图片,12500张测试图片,包括猫和狗两种图片。在此次实验中为了训练方便,我们取了一个较小的数据集。 数据解压之后会有两个文件夹,一个是 “train”,一个是 “test”,顾名思义一个是用来训练的,另一个是作为检验正确性的数据。
在这里插入图片描述
在train文件夹里边是一些已经命名好的图像,有猫也有狗。而在test文件夹中是只有编号名的图像。
在这里插入图片描述

2.图片展示

下面是数据集中的图片展示:

class_names = ['cats', 'dogs']

plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

三、MobileNetV2网络介绍

1.加载tensorflow提供的预训练模型

val_batches = tf.data.experimental.cardinality(validation_dataset)
test_dataset = validation_dataset.take(val_batches // 5)
validation_dataset = validation_dataset.skip(val_batches // 5)

2.轻量级网络——MobileNetV2

使用轻量级网络——MobileNetV2进行数据预处理 说明: MobileNetV2是基于倒置的残差结构,普通的残差结构是先经过 1x1 的卷积核把 feature map的通道数压下来,然后经过 3x3 的卷积核,最后再用 1x1 的卷积核将通道数扩张回去,即先压缩后扩张,而MobileNetV2的倒置残差结构是先扩张后压缩
在这里插入图片描述

3.MobileNetV2的网络模块

MobileNetV2的网络模块样子是这样的:
在这里插入图片描述
MobileNetV2是基于深度级可分离卷积构建的网络,它是将标准卷积拆分为了两个操作:深度卷积 和 逐点卷积,深度卷积和标准卷积不同,对于标准卷积其卷积核是用在所有的输入通道上,而深度卷积针对每个输入通道采用不同的卷积核,就是说一个卷积核对应一个输入通道,所以说深度卷积是depth级别的操作。而逐点卷积其实就是普通的卷积,只不过其采用1x1的卷积核。
MobileNetV2的模型如下图所示,其中t为Bottleneck内部升维的倍数,c为通道数,n为该bottleneck重复的次数,s为sride
在这里插入图片描述

其中,当stride=1时,才会使用elementwise 的sum将输入和输出特征连接(如下图左侧);stride=2时,无short cut连接输入和输出特征(下图右侧):
在这里插入图片描述

四、搭建迁移学习

1.训练

inital_input = tf.keras.applications.mobilenet_v2.preprocess_input
IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')
base_model.trainable = False
base_model.summary()

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

2.训练结果可视化

用图表显示准确率和损失函数

# 训练结果可视化,用图表显示准确率和损失函数
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range=range(initial_epochs)
plt.figure(figsize=(8,8))
plt.subplot(2,1,1)
plt.plot(epochs_range, acc, label="Training Accuracy")
plt.plot(epochs_range, val_acc,label="Validation Accuracy")
plt.legend()
plt.title("Training and Validation Accuracy")
 
plt.subplot(2,1,2)
plt.plot(epochs_range, loss, label="Training Loss")
plt.plot(epochs_range, val_loss,label="Validation Loss")
plt.legend()
plt.title("Training and Validation Loss")
plt.show()

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

3.输出训练的准确率

# 输出训练的准确率
test_loss, test_accuracy = model.evaluate(test_dataset)
print('test accuracy: {:.2f}'.format(test_accuracy))

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

4.用cnn工具可视化一批数据的预测结果

label_dict = {
    0: 'cat',
    1: 'dog'
}

test_image_batch, test_label_batch = test_dataset.as_numpy_iterator().next()
# 编码成uint8 以图片形式输出
test_image_batch = test_image_batch.astype('uint8')

cnn_utils.plot_predictions(model, test_image_batch, test_label_batch, label_dict, 32, 5, 5)

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

5.数据输出

# 数据输出,数字化特征图
test_image_batch, test_label_batch = train_dataset.as_numpy_iterator().next()

img_idx = 0
random_batch = np.random.permutation(np.arange(0,len(test_image_batch)))[:BATCH_SIZE]
image_activation = test_image_batch[random_batch[img_idx]:random_batch[img_idx]+1]

cnn_utils.get_activations(base_model, image_activation[0])

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

6.用cnn工具可视化一个数据样本的各层输出

cnn_utils.display_activations(cnn_utils.get_activations(base_model, image_activation[0]))

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述

7.输出结果图像

🌟🌟🌟 这里是输出的结果:✨✨✨
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

五、源码获取

关注此公众号:人生苦短我用Pythons,回复 神经网络源码获取源码,快点击我吧

🌲🌲🌲 好啦,这就是今天要分享给大家的全部内容了,我们下期再见!
❤️❤️❤️如果你喜欢的话,就不要吝惜你的一键三连了~
在这里插入图片描述
在这里插入图片描述

最后,有任何问题,欢迎关注下面的公众号,获取第一时间消息、作者联系方式及每周抽奖等多重好礼! ↓↓↓

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/338102.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Spring(十一)】万字带你深入学习面向切面编程AOP

文章目录前言AOP简介AOP入门案例AOP工作流程AOP切入点表达式AOP通知类型AOP通知获取数据总结前言 今天我们来学习AOP,在最初我们学习Spring时说过Spring的两大特征,一个是IOC,一个是AOP,我们现在要学习的就是这个AOP。 AOP简介 AOP:面向切面编程,一种编程范式&#…

计算机网络自顶向下 -- 流水线,滑动窗口协议

流水线协议 Rdt3.0在停等操作的过程中浪费了大量的时间: 从而在Rdt 3.0上引入了流水线机制:为了提高资源利用率 流水线协议: 允许发送方在收到ACK之前连续发送多个分组,更大的序列号范围,同时发送方和/或接收方需要更…

关于自动驾驶高精定位的几大问题

交流群 | 进“传感器群/滑板底盘群”请加微信号:xsh041388交流群 | 进“汽车基础软件群”请加微信号:Faye_chloe备注信息:群名称 真实姓名、公司、岗位作者 | 许良定位是高等级自动驾驶的基础,但在高速NOA和城区NOA等场景中&…

Linux账号与用户组

目录 用户标识符:UID与GID 用户账号 /etc/passwd文件结构 1、账号名称 2、密码 3、UID 4、GID 5、用户信息说明栏 6、家目录 7、shell /etc/shadow文件结构 1、账号名称 2、密码 3、最近修改密码的日期 4、密码不可被修改的天数(与第三字…

Git | 在IDEA中使用Git

目录 一、在IDEA中配置Git 1.1 配置Git 1.2 获取Git仓库 1.3 将本地项目推送到远程仓库 1.4 .gitignore文件的作用 二、本地仓库操作 2.1 将文件加入暂存区 2.2 将暂存区的文件提交到版本库 2.3 查看日志 三、远程仓库操作 3.1 查看和添加远程仓库 3.2 推送至远程仓…

fastcgi未授权访问漏洞(php-fpm fast-cgi未授权访问漏洞)

本文参考《Fastcgi协议分析 && PHP-FPM未授权访问漏洞 && Exp编写》进行该漏洞的复现以及分析。 1.前置基础 1.1 nginx中的fastcgi 先来看先前用过的一张图,其是nginx解析用户请求的过程。 图中的几个定义: CGI:CGI是一种…

1628_MIT 6.828 xv6_chapter0操作系统接口

全部学习汇总: GreyZhang/g_unix: some basic learning about unix operating system. (github.com) 这本书最初看名字以为是对早期unix的一个解读,但是看了开篇发现 不完全是,只是针对JOS教学OS系统来做的一些讲解。 Xv6是对UNIX v6的重新实…

【Java 面试合集】Java中修饰符有哪些,有什么应用场景

Java中修饰符有哪些,有什么应用场景 1. 概述 首先我们要知道Java的三大特性:封装,继承,多态。 而我们今天要分析的修饰符就跟封装有着密切的联系。因为权限修饰符可以控制变量以及方法的作用范围。 废话不多说,上图…

Python推导式

列表&#xff08;list&#xff09;推导式 [remove for source in xx_list]或者[remove for source in xx_list if condition] 实例&#xff1a; names[Bob,Mark,Mausk,Johndan,Wendy] new_names[name.upper() for name in names if len(name)<5] print(new_names)即迭代列…

PC端开发GUI

PC端开发GUI PC端环境搭建1、Python2、PycharmPC端环境搭建 1、Python 注意Python版本不能超过3.9,因为pyqt-tools只维护到python对应的该版本 1.1、查找是否安装python:win+R,输入cmd回车,输入python或python -V或python --version 1.2、若1.1没有,则下载安装下载链接…

天津菲图尼克科技携洁净及无菌防护服解决方案与您相约2023生物发酵展

BIO CHINA 生物发酵产业一年一度行业盛会&#xff0c;由中国生物发酵产业协会主办&#xff0c;上海信世展览服务有限公司承办&#xff0c;2023第10届国际生物发酵产品与技术装备展览会&#xff08;济南&#xff09;于2023年3月30-4月1日在山东国际会展中心&#xff08;济南市槐…

亿级高并发电商项目-- 实战篇 --万达商城项目 二(Zookeeper、Docker、Dubbo-Admin等搭建工作

&#x1f44f;作者简介&#xff1a;大家好&#xff0c;我是小童&#xff0c;Java开发工程师&#xff0c;CSDN博客博主&#xff0c;Java领域新星创作者 &#x1f4d5;系列专栏&#xff1a;前端、Java、Java中间件大全、微信小程序、微信支付、若依框架、Spring全家桶 &#x1f4…

第二章-进程(2)

进程一、进程的引入二、进程的状态及组成三、进程控制一、进程的引入 &#xff08;1&#xff09;程序的顺序执行: P1:axy P2:ba-5 P3:cb1 程序总是按照P1→P2→P3的顺序执行。 特点&#xff1a; 顺序性&#xff1a;处理机的操作严格按规定顺序执行。封闭性&#xff1a;程序执…

python(8):使用conda update更新conda后,anaconda所有环境崩溃----问题没有解决,不要轻易更新conda

文章目录0. 教训1. 问题:使用conda update更新conda后&#xff0c;anaconda所有环境崩溃1.1 问题描述1.2 我搜索到的全网最相关的问题----也没有解决3 尝试流程记录3.1 重新安装pip3.2 解决anaconda编译问题----没成功0. 教训 (1) 不要轻易使用conda update更新conda----我遇到…

[OpenMMLab]AI实战营第六节课

语义分割算法基础 任务&#xff1a;图像按照物体的类别分隔成不同区域&#xff0c;即将每个像素进行分类 应用&#xff1a;无人驾驶、医疗、人像、智能遥感 思路 基本思路&#xff1a;按照颜色区分 --> 逐像素分类&#xff08;滑动窗口&#xff0c;用CNN分类&#xff0c…

微搭低代码从入门到精通11-数据模型

学习微搭低代码&#xff0c;先学习基本操作&#xff0c;然后学习组件的基本使用。解决了前端的问题&#xff0c;我们就需要深入学习后端的功能。后端一般包括两部分&#xff0c;第一部分是常规的数据库的操作&#xff0c;包括增删改查。第二部分是业务逻辑的编写&#xff0c;在…

QT基础入门

学习视频&#xff1a;QT开发概述_哔哩哔哩_bilibili 1.QT开发概述 1.什么是QT QT是一个1991年由Qt Company开发的跨平台C图形用户界面应用程序开发框架。它既可以开发GUI程序&#xff0c;也可用于开发非GUI程序&#xff0c;比如控制台工具和服务器。Qt是面向对象的框架&#…

STC15单片机软串口的使用

STC15软串口的使用&#x1f4d6;在没有使用定时器资源的情况下&#xff0c;根据波特率位传输时间&#xff0c;利用STC-ISP工具自动计算出位延时函数。 ✨在官方所提供的库函数中位传输时间函数,仅适用于使用波特率为&#xff1a;9600的串口数据传输&#xff1a; void BitTime(…

Grafana 系列文章(十四):Helm 安装Loki

前言 写或者翻译这么多篇 Loki 相关的文章了, 发现还没写怎么安装 &#x1f613; 现在开始介绍如何使用 Helm 安装 Loki. 前提 有 Helm, 并且添加 Grafana 的官方源: helm repo add grafana https://grafana.github.io/helm-charts helm repo update &#x1f43e;Warning…

nacos的单机模式和集群模式

文章目录 目录 文章目录 前言 一、nacos数据库配置 二、单机模式 三、集群模式 四、使用nginx集群模式的负载均衡 总结 前言 一、nacos数据库配置 在数据库中创建nacos_config 编码格式utf8-mb4的数据库 把上面的数据库文件导入数据库 在 配置文件中添加如下 spring.datasour…