计算机视觉入门 5)自定义卷积网络

news2025/1/9 15:59:49

系列文章目录

  1. 计算机视觉入门 1)卷积分类器
  2. 计算机视觉入门 2)卷积和ReLU
  3. 计算机视觉入门 3)最大池化
  4. 计算机视觉入门 4)滑动窗口
  5. 计算机视觉入门 5)自定义卷积网络
  6. 计算机视觉入门 6) 数据集增强(Data Augmentation)

提示:仅为个人学习笔记分享,若有错漏请各位老师同学指出,Thanks♪(・ω・)ノ


目录

  • 系列文章目录
  • 一、自定义卷积网络
    • 从简单到精细
    • 卷积块
  • 二、【代码示例】构建一个简单的卷积网络
    • 步骤1 - 加载数据
    • 步骤2 - 定义模型
    • 步骤3 - 训练模型


一、自定义卷积网络

从简单到精细

前几篇笔记介绍了卷积网络通过三个操作进行特征提取过滤检测压缩。一次特征提取只能从图像中提取相对简单的特征,例如简单的线条或对比度。这些特征对于解决大多数分类问题来说过于简单。相反,卷积网络会一遍又一遍地重复这个提取过程,使特征在网络内部深入传递时变得更加复杂和精细。
从图像中提取的特征,从简单到精细。

卷积块

它通过将图像通过一系列的卷积块来进行这个过程,从而实现这一点。
作为一系列块的提取过程。

这些卷积块是Conv2DMaxPool2D层的堆叠,如下所示:

一种提取块:卷积,ReLU,池化。

每个块代表一轮提取,通过组合这些块,卷积网络可以将产生的特征组合和重新组合。现代卷积网络的深层结构使得这种复杂的特征工程成为可能,从而大大提高它们在处理和解决任务上的性能。

二、【代码示例】构建一个简单的卷积网络

步骤1 - 加载数据

# 导入库
import os, warnings
import matplotlib.pyplot as plt
from matplotlib import gridspec

import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing import image_dataset_from_directory

# 设置随机种子以保证可复现性
def set_seed(seed=31415):
    np.random.seed(seed)
    tf.random.set_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['TF_DETERMINISTIC_OPS'] = '1'
set_seed()

# 设置Matplotlib默认值
plt.rc('figure', autolayout=True)
plt.rc('axes', labelweight='bold', labelsize='large',
       titleweight='bold', titlesize=18, titlepad=10)
plt.rc('image', cmap='magma')
warnings.filterwarnings("ignore") # 清理输出单元格

# 加载训练集和验证集
ds_train_ = image_dataset_from_directory(
    '../input/car-or-truck/train',
    labels='inferred',
    label_mode='binary',
    image_size=[128, 128],
    interpolation='nearest',
    batch_size=64,
    shuffle=True,
)
ds_valid_ = image_dataset_from_directory(
    '../input/car-or-truck/valid',
    labels='inferred',
    label_mode='binary',
    image_size=[128, 128],
    interpolation='nearest',
    batch_size=64,
    shuffle=False,
)

# 数据处理管道
def convert_to_float(image, label):
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    return image, label

AUTOTUNE = tf.data.experimental.AUTOTUNE
ds_train = (
    ds_train_
    .map(convert_to_float)
    .cache()
    .prefetch(buffer_size=AUTOTUNE)
)
ds_valid = (
    ds_valid_
    .map(convert_to_float)
    .cache()
    .prefetch(buffer_size=AUTOTUNE)
)

步骤2 - 定义模型

下面是我们将使用的模型的图示:

卷积模型的图示。

现在我们来定义模型。注意我们的模型由三个 Conv2DMaxPool2D 层块组成,后面跟着一些 Dense 层。我们可以通过填写适当的参数,将这个图示基本上直接转化为一个 Keras Sequential 模型。

from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential([

    # 第一个卷积块
    layers.Conv2D(filters=32, kernel_size=5, activation="relu", padding='same',
                  # 在第一层中提供输入维度
                  # [高度, 宽度, 颜色通道(RGB)]
                  input_shape=[128, 128, 3]),
    layers.MaxPool2D(),

    # 第二个卷积块
    layers.Conv2D(filters=64, kernel_size=3, activation="relu", padding='same'),
    layers.MaxPool2D(),

    # 第三个卷积块
    layers.Conv2D(filters=128, kernel_size=3, activation="relu", padding='same'),
    layers.MaxPool2D(),

    # 分类头部
    layers.Flatten(),
    layers.Dense(units=6, activation="relu"),
    layers.Dense(units=1, activation="sigmoid"),
])
model.summary()

这个代码定义了一个包含三个卷积块的模型,每个块都由一个 Conv2D 层和一个 MaxPool2D 层组成。最后的分类头部包括一个扁平化层(Flatten),两个密集连接层(Dense),用于输出预测结果。模型总共有大约50万个参数,它们将根据训练数据进行调整以进行有效的特征提取和分类。

输出模型Summary:
在这里插入图片描述
注意:这里每个块的过滤器数量都是逐块翻倍增加的:32、64、128。这是一种常见的模式。由于MaxPool2D层在每个块中降低了特征图的尺寸,因此我们可以逐块增加我们创建的特征图数量。

步骤3 - 训练模型

# 加入损失函数和准确率
model.compile(
    optimizer=tf.keras.optimizers.Adam(epsilon=0.01),
    loss='binary_crossentropy',
    metrics=['binary_accuracy']
)

#模型训练
history = model.fit(
    ds_train,
    validation_data=ds_valid,
    epochs=40,
    verbose=0,
)
import pandas as pd

#结果查看
history_frame = pd.DataFrame(history.history)
history_frame.loc[:, ['loss', 'val_loss']].plot()
history_frame.loc[:, ['binary_accuracy', 'val_binary_accuracy']].plot();

在这里插入图片描述
这个模型只有简单的3个卷积层,尽管如此,它仍然能够相当好地适应这个数据集。
见上图,蓝色曲线快速地变化,同时黄色曲线(验证集val)在10轮训练过程后就进入波动稳定的状态, 这表明模型容易过拟合,需要正则化处理。
后续可以添加更多的卷积层、或者Dropout层来进行正则化,继续改进这个模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/916267.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

多客户企业选择拥有哪些功能的CRM系统?

管理海量客户信息对于每一家企业都是巨大的挑战。粗放式的管理客户资料是对资源的一种浪费,让很多有意向的高价值客户流失。客户比较多,有什么CRM系统推荐吗?帮助企业轻松地跟进客户,提高销售效率? 1.易于使用 首先是…

昨晚做梦面试官问我三色标记算法

本文已收录至GitHub,推荐阅读 👉 Java随想录 微信公众号:Java随想录 原创不易,注重版权。转载请注明原作者和原文链接 文章目录 三色标记算法增量更新原始快照 某天,爪哇星球上,一个普通的房间&#xff0c…

技术的巅峰演进:深入解析算力网络的多层次技术设计

在数字化时代的浪潮中,网络技术正以前所未有的速度演进,而算力网络作为其中的一颗明星,以其多层次的技术设计引领着未来的网络构架。本文将带您深入探索算力网络独特的技术之旅,从底层协议到分布式控制,为您呈现这一创…

chatgpt官方支持微调了!

前言 刚刚openai在官网宣布chatgpt支持微调了!具体支持微调的模型有: 并且GPT-3.5-Turbo-16k和GPT4在今年晚些也会支持微调。 其在官网也介绍了一些微调和准备数据的实战经验,可以学习~ 官方文档: https://platfor…

常用的数据可视化工具有哪些?要操作简单的

随着数据量的剧增,对分析效率和数据信息传递都带来了不小的挑战,于是数据可视化工具应运而生,通过直观形象的图表来展现、传递数据信息,提高数据分析报表的易读性。那么,常用的操作简单数据可视化工具有哪些&#xff1…

高并发保证接口幂等性方案

接口幂等的解决方案 什么是接口幂等性 接口幂等性是指无论调用多少次相同的接口请求,对系统的状态和数据产生的影响都是一致的。简而言之,幂等性保证了对同一个接口请求的重复调用不会产生额外的副作用或改变系统的状态。 在设计和实现接口时&#xf…

SQL Server 执行报错: “minus“ 附近有语法错误。

sql server 执行带 minus 的语句一直报错,如下图: 找了好久才知道minus是Oracle里面的语法,SQL server 应用 EXCEPT。

PCL中的ISS特征点检测

ISS是基于内部形态描述子(ISS) 的特征点。 算法检测流程(参考论文:基于ISS 特征点结合改进ICP 的点云配准算法): PCL中的实现: template<typename PointInT, typename PointOutT, typename NormalT> void pcl::ISSKeypoint3D<PointInT, PointOutT, NormalT>…

斯坦福大学医学院教授:几年内ChatGPT之类的AI将纳入日常医学实践

注意&#xff1a;本信息仅供参考&#xff0c;分享此内容旨在传递更多信息之目的&#xff0c;并不意味着赞同其观点或证实其说法。 在一项新研究中&#xff0c;斯坦福大学研究人员发现&#xff0c;ChatGPT在复杂临床护理考试题中可以胜过一、二年级的医学生。此项研究显示&#…

组件库的使用和自定义组件

目录 一、组件库介绍 1、什么是组件 2、组件库介绍 3、arco.design 二、组件库的使用 1、快速上手 2、主题定制 3、暗黑模式 4、语言国际化 5、业务常见问题 三、自定义组件 2、组件开发规范 3、示例实践guide-tip 4、业务组件快速托管 一、组件库介绍 1、什么是…

allegro输出.IPC文件

1、ipc文件的导出 板厂会使用cam软件生产一个网表文件&#xff1b;如果传递给板厂的数据中也有一个IPC文件&#xff0c;板厂将对两个网表文件进行对比&#xff1b;提高生产的安全性&#xff0c;准确性&#xff1b; 1&#xff0c;PCB软件输出的光绘文件&#xff0c;有时会变异&a…

从LeakCanary看如何生成内存快照

前面我们已经完成了生命周期监控并且可以通过ReferenceQueue和WeakHashMap的比较确定哪些对象发生泄漏了&#xff0c;那么接下来需要考虑的就是如何确定这个对象是被谁持有导致泄漏的呢&#xff1f; 从内存泄漏一文中可知&#xff0c;当我们使用Android Studio或MAT分析内存泄…

【从零学习python 】75. TCP协议:可靠的面向连接的传输层通信协议

文章目录 TCP协议TCP通信的三个步骤TCP特点TCP与UDP的区别TCP通信模型进阶案例 TCP协议 TCP协议&#xff0c;传输控制协议&#xff08;英语&#xff1a;Transmission Control Protocol&#xff0c;缩写为 TCP&#xff09;是一种面向连接的、可靠的、基于字节流的传输层通信协议…

java八股文面试[数据结构]——List和Set的区别

List和Set是用来存放集合的接口&#xff0c;并且二者都继承自接接口Collection List 中的元素存放是有序的&#xff0c;可以存放重复的元素&#xff0c;检索效率较高&#xff0c;插入删除效率较低。 Set 没有存放顺序不能存放重复元素检索效率较低&#xff0c;插入删除效率较…

【前端】深入理解CSS盒子模型与浮动

目录 一、前言二、盒子模型1、盒子模型组成1.1、border边框1.1.1、边框的三部分组成1.1.2、边框复合简写1.1.3、边框分开写1.1.4、表格的细线边框 1.2、padding内边距1.3、margin外边距1.3.1、外边距水平居中1.3.2、外边距合并1.3.3、嵌套块元素垂直 外边距的塌陷1.3.3.1、解决…

全流程R语言Meta分析核心技术应用

Meta分析是针对某一科研问题&#xff0c;根据明确的搜索策略、选择筛选文献标准、采用严格的评价方法&#xff0c;对来源不同的研究成果进行收集、合并及定量统计分析的方法&#xff0c;最早出现于“循证医学”&#xff0c;现已广泛应用于农林生态&#xff0c;资源环境等方面。…

LMLCCS_UPDATEFO2 LCL DB 方法 get_normvector 头寸 A 中RC 1 内部错误,过账时报错<转载>

原文链接&#xff1a;https://blog.csdn.net/XFYBB/article/details/129174579 物料的成本中心&#xff0c;作业价格没有维护 再用FCMLHELP&#xff0c;重新创建一下 se37&#xff0c;FCMLHELP_CHECK_TESTFLAG&#xff0c;打断点&#xff0c;跳过PW

外围信息收集

一、查询域名信息 1、安装whois sudo apt update sudo apt install whois2、使用 whois [域名]也可以通过在线网站进行查询网站 3、反查 4、网站在线查询 4.1、网站一 通过使用网站去查询&#xff1a;网址 &#xff0c;这个网站只会记录他所知道的域名&#xff0c;不全 4.…

网络综合布线实训室方案(2023版)

综合布线实训室概述 随着智慧城市的蓬勃发展,人工智能、物联网、云计算、大数据等新兴行业也随之崛起,网络布线系统作为现代智慧城市、智慧社区、智能建筑、智能家居、智能工厂和现代服务业的基础设施和神经网络,发挥着重要作用。实践表明,网络系统故障的70%发生在布线系统,直接…

centos7 忘记密码需要重置密码

第一步进入系统加载条之前 按 e 键 第二步到了这个界面 找到 linux16 开头的 将 ro 改成 rw init/sysroot/bin/sh 修改完之后 按下按键 ctrlx 或者是 F10 第三步输入 命令 chroot /sysroot第四步 重置root用户密码 这里重置密码是叫你输入两遍密码 passwd root第五步 更…