微调模型来完成热狗识别的图像分类任务

news2024/11/24 16:50:59

我们来实践一个具体的例子:热狗识别。将基于一个小数据集对在ImageNet数据集上训练好的ResNet模型进行微调。该小数据集含有数千张热狗或者其他事物的图像。我们将使用微调得到的模型来识别一张图像中是否包含热狗。

首先,导入实验所需的工具包。

import tensorflow as tf
import numpy as np

获取数据集

我们首先将数据集放在路径hotdog/data之下:

1678083514572_81.png

每个类别文件夹里面是图像文件。

上一节中我们介绍了ImageDataGenerator进行图像增强,我们可以通过以下方法读取图像文件,该方法以文件夹路径为参数,生成经过图像增强后的结果,并产生batch数据:

flow_from_directory(self, directory,
                            target_size=(256, 256), color_mode='rgb',
                            classes=None, class_mode='categorical',
                            batch_size=32, shuffle=True, seed=None,
                            save_to_dir=None)

主要参数:

▪ directory: 目标文件夹路径,对于每一个类对应一个子文件夹,该子文件夹中任何JPG、PNG、BNP、PPM的图片都可以读取。

▪ target_size: 默认为(256, 256),图像将被resize成该尺寸。

▪ batch_size: batch数据的大小,默认32。

▪ shuffle: 是否打乱数据,默认为True。

我们创建两个tf.keras.preprocessing.image.ImageDataGenerator实例来分别读取训练数据集和测试数据集中的所有图像文件。将训练集图片全部处理为高和宽均为224像素的输入。此外,我们对RGB(红、绿、蓝)三个颜色通道的数值做标准化。

# 获取数据集
import pathlib
train_dir = 'transferdata/train'
test_dir = 'transferdata/test'
# 获取训练集数据
train_dir = pathlib.Path(train_dir)
train_count = len(list(train_dir.glob('*/*.jpg')))
# 获取测试集数据
test_dir = pathlib.Path(test_dir)
test_count = len(list(test_dir.glob('*/*.jpg')))
# 创建imageDataGenerator进行图像处理
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
# 设置参数
BATCH_SIZE = 32
IMG_HEIGHT = 224
IMG_WIDTH = 224
# 获取训练数据
train_data_gen = image_generator.flow_from_directory(directory=str(train_dir),
                                                    batch_size=BATCH_SIZE,
                                                    target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                    shuffle=True)
# 获取测试数据
test_data_gen = image_generator.flow_from_directory(directory=str(test_dir),
                                                    batch_size=BATCH_SIZE,
                                                    target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                    shuffle=True)

下面我们随机取1个batch的图片然后绘制出来。

import matplotlib.pyplot as plt
# 显示图像
def show_batch(image_batch, label_batch):
    plt.figure(figsize=(10,10))
    for n in range(15):
        ax = plt.subplot(5,5,n+1)
        plt.imshow(image_batch[n])
        plt.axis('off')
# 随机选择一个batch的图像        
image_batch, label_batch = next(train_data_gen)
# 图像显示
show_batch(image_batch, label_batch)

1678083965743_82.png

模型构建与训练

我们使用在ImageNet数据集上预训练的ResNet-50作为源模型。这里指定weights='imagenet’来自动下载并加载预训练的模型参数。在第一次使用时需要联网下载模型参数。

Keras应用程序(keras.applications)是具有预先训练权值的固定架构,该类封装了很多重量级的网络架构,如下图所示:

模型构建与训练

实现时实例化模型架构:

tf.keras.applications.ResNet50(
    include_top=True, weights='imagenet', input_tensor=None, input_shape=None,
    pooling=None, classes=1000, **kwargs
)

主要参数:

▪ include_top: 是否包括顶层的全连接层。

▪ weights: None 代表随机初始化, ‘imagenet’ 代表加载在 ImageNet 上预训练的权值。

▪ input_shape: 可选,输入尺寸元组,仅当 include_top=False 时有效,否则输入形状必须是 (224, 224, 3)(channels_last 格式)或 (3, 224, 224)(channels_first 格式)。它必须为 3 个输入通道,且宽高必须不小于 32,比如 (200, 200, 3) 是一个合法的输入尺寸。

在该案例中我们使用resNet50预训练模型构建模型:

# 加载预训练模型
ResNet50 = tf.keras.applications.ResNet50(weights='imagenet', input_shape=(224,224,3))
# 设置所有层不可训练
for layer in ResNet50.layers:
    layer.trainable = False
# 设置模型
net = tf.keras.models.Sequential()
# 预训练模型
net.add(ResNet50)
# 展开
net.add(tf.keras.layers.Flatten())
# 二分类的全连接层
net.add(tf.keras.layers.Dense(2, activation='softmax'))

接下来我们使用之前定义好的ImageGenerator将训练集图片送入ResNet50进行训练。

# 模型编译:指定优化器,损失函数和评价指标net.compile(optimizer='adam',
            loss='categorical_crossentropy',
            metrics=['accuracy'])# 模型训练:指定数据,每一个epoch中只运行10个迭代,指定验证数据集history = net.fit(
                    train_data_gen,
                    steps_per_epoch=10,
                    epochs=3,
                    validation_data=test_data_gen,
                    validation_steps=10
                    )
Epoch 1/3
10/10 [==============================] - 28s 3s/step - loss: 0.6931 - accuracy: 0.5031 - val_loss: 0.6930 - val_accuracy: 0.5094
Epoch 2/3
10/10 [==============================] - 29s 3s/step - loss: 0.6932 - accuracy: 0.5094 - val_loss: 0.6935 - val_accuracy: 0.4812
Epoch 3/3
10/10 [==============================] - 31s 3s/step - loss: 0.6935 - accuracy: 0.4844 - val_loss: 0.6933 - val_accuracy: 0.4875

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/786449.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

“铸网-2023” | 持续保障江西省实网应急演练

​​日前,由江西省工业和信息化厅主办,江西省网络安全研究院承办,南京赛宁信息技术有限公司协办并提供全程技术支撑的“铸网-2023”江西省工业领域网络安全实网应急演练在江西南昌圆满收官。 一、应急演练43天 赛宁持续助推工业企业应急能力…

论文精度系列之详解图神经网络

论文地址:A Gentle Introduction to Graph Neural Networks 翻译:图表就在我们身边;现实世界的对象通常根据它们与其他事物的连接来定义。一组对象以及它们之间的连接自然地表示为图形。十多年来,研究人员已经开发了对图数据进行操作的神经网络(称为图神…

二叉树题目:从根到叶的二进制数之和

文章目录 题目标题和出处难度题目描述要求示例数据范围 解法一思路和算法代码复杂度分析 解法二思路和算法代码复杂度分析 题目 标题和出处 标题:从根到叶的二进制数之和 出处:1022. 从根到叶的二进制数之和 难度 3 级 题目描述 要求 给你二叉树…

vue权限按钮的实现

鉴权函数 由于下面几种方式都需要用到鉴权函数,所以将其放置在组件外面,供组件或其他文件调用。 // src/utils/hasPermission.jsimport { usePermissionStore } from /stores import array from lodash/array export const hasPermission (value, def…

火车头采集文章批量伪原创【php源码】

火车头采集是一款基于Python语言开发的网络爬虫工具&#xff0c;用于快速高效地从互联网上采集数据并存储到本地或远程数据库。它简单易用且功能强大&#xff0c;在各行各业广泛应用。 火车头采集器AI伪原创PHP源码&#xff1a; <?php header("Content-type: text/h…

C# WPF项目创建(基于VS 2019介绍)

1.打开VS&#xff0c;选择《创建新项目》 2.选择《WPF应用程序》&#xff0c;如图 3. 项目名称根据需求自行命名&#xff0c;这边以“WpfApp1”来命名&#xff0c;位置自行选择&#xff0c;这边选择了"E:\test"&#xff0c;如图所示&#xff0c;“将解决方案和项目…

【广州华锐互动】VR模拟灭火逃生体验系统

VR模拟灭火逃生体验系统由广州华锐互动开发&#xff0c;是一种基于虚拟现实技术的应急演练与培训系统&#xff0c;可以真实模拟消防逃生场景&#xff0c;让体验者在沉浸式的虚拟环境中&#xff0c;根据正确的消防逃生方法提示&#xff0c;进行自救演练。这种科学普及方法是更加…

datafree KD CVPR2023 学习笔记(Abstract)

这个 是摘要的前部分 摘要 简单的提及 无数据的蒸馏 是怎么样的 和普通的蒸馏有一个本质的区别&#xff1a;没有训练数据 很火啊 最近 在对抗性的蒸馏框架中呢 存在一个问题 由于多个生成器生成的 非平稳分布的pseudo-samples 导致了学生网络的性能不好 提出一个解决方案的i…

Docker——数据管理与网络通信

Docker——数据管理与网络通信 一、Docker 的数据管理1&#xff0e;数据卷2&#xff0e;数据卷容器 二、端口映射三、容器互联&#xff08;使用centos镜像&#xff09;四、Docker 镜像的创建1&#xff0e;基于现有镜像创建1.1 首先启动一个镜像&#xff0c;在容器里做修改1.2 然…

如何搭建使用dubbo-Admin?

dubbo-Admin介绍 一款用于dubbo可视化界面操作的管理平台 dubbo-Admin特点 dubbo-Admin是dubbo的管理界面平台&#xff0c;且是一个前后端分离的项目&#xff0c;前端使用vue&#xff0c;后端使用springboot。 软件下载 dubbo-admin-0.5.0.zip 软件使用

Java文件的相对路径规则

前言 最近做项目&#xff0c;又涉及到Linux Java文件的相对路径&#xff0c;但是相对路径在不同的服务器或者docker上居然不一样&#xff0c;这个就很难受&#xff0c;只能用绝对路径解决&#xff0c;因为绝对路径是固定的路径&#xff0c;但是相对路径为什么会在不同的服务器…

将pdf转换成word怎么转换?掌握这几个转换方法就够了

将pdf转换成word怎么转换&#xff1f;将 PDF 转换成 Word 文档是一个常见的需求&#xff0c;我们需要编辑 PDF 文件中的文字内容&#xff0c;但是 PDF 文件并不方便直接编辑。下面介绍几个转换方法。 方法一&#xff1a;使用迅捷 PDF 转换器 这是一款专业的 PDF 转换工具&…

【2023】HashMap详细源码分析解读

前言 在弄清楚HashMap之前先介绍一下使用到的数据结构&#xff0c;在jdk1.8之后HashMap中为了优化效率加入了红黑树这种数据结构。 树 在计算机科学中&#xff0c;树&#xff08;英语&#xff1a;tree&#xff09;是一种抽象数据类型&#xff08;ADT&#xff09;或是实作这种…

2023年深圳杯数学建模 D题 基于机理的致伤工具推断

致伤工具的推断一直是法医工作中的热点和难点。由于作用位置、作用方式的不同&#xff0c;相同的致伤工具在人体组织上会形成不同的损伤形态&#xff0c;不同的致伤工具也可能形成相同的损伤形态。致伤工具品种繁多、形态各异&#xff0c;但大致可分为两类&#xff1a;锐器&…

结构体——位段

//结构体--位段 &#xff08;位 指二进制位 &#xff09; // 位段的声明与结构体是类似的 // 1&#xff0c;位段的成员必须是 int &#xff0c;unsigned int 或 signed int // 2&#xff0c;位段的成员名后边有一个冒号和一个数字。&#xff08;数字表示开辟需要的比特位个数&a…

【【51单片机的温度识别】】

51单片机的温度识别 这么热的天&#xff0c;再离我远一点 DS18B20温度读取 温度报警 我们这里对于单总线 也仿照I2C 总线的模样 以下是我们实现的onewire.c文件 #include <REGX52.H> //第一步从原理图触发先定义端口sbit OneWire_DQP3^7;//第一步初始化unsigned char…

商城-学习整理-基础-前端(四)

目录 一、技术栈介绍二、ES61、简介2、什么是ECMAScript3、ES6 新特性1、let 声明变量2、const 声明常量&#xff08;只读变量&#xff09;3、解构表达式1&#xff09;数组解构2&#xff09;对象解构 4、字符串扩展1&#xff09;、几个新的 API2&#xff09;、字符串模板 5、函…

跨洋消息队列CKafka

背景 跨洋消息队列CKafka&#xff0c;是腾讯出品开源消息队列组件&#xff0c;为了解决客户跨地域容灾、冷备的诉求跨地域秒级准实时数据同步而开源的腾讯云中间件。 整体架构图 数据同步主要流程 Connect 集群初始化 Connect Task&#xff0c;每个 Task 会新建多个 Worker C…

3D元宇宙游戏,或许能引爆新的文娱消费增长点

从去年开始&#xff0c;在互联网上&#xff0c;一个名为【神念无界-源起山海】的元宇宙游戏项目火了。除了可以在游戏内体验独战、团队式作战等3D古风经典游戏场景和玩法&#xff0c;还有钓鱼增加能量、情侣姻缘一线牵&#xff0c;结婚等多元化逼真效果与玩法&#xff0c;这令很…

第118天:免杀对抗-二开CS上线流量特征Shellcode生成机制反编译重打包(上)

知识点 #知识点&#xff1a; 1、CS-表面特征消除 2、CS-HTTP流量特征消除 3、CS-Shellcode特征消除#章节点&#xff1a; 编译代码面-ShellCode-混淆 编译代码面-编辑执行器-编写 编译代码面-分离加载器-编写 程序文件面-特征码定位-修改 程序文件面-加壳花指令-资源 代码加载面…