深度学习笔记11-优化器对比实验(Tensorflow)

news2025/1/18 2:34:40
  • 🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊

目录

一、导入数据并检查

二、配置数据集

三、数据可视化

四、构建模型

五、训练模型

六、模型对比评估

七、总结


一、导入数据并检查

import pathlib,PIL
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签

data_dir    = pathlib.Path("./T6")
image_count = len(list(data_dir.glob('*/*')))
batch_size = 16
img_height = 336
img_width  = 336
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)

class_names = train_ds.class_names
print(class_names)

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

二、配置数据集

AUTOTUNE = tf.data.AUTOTUNE
#归一化处理
def train_preprocessing(image,label):
    return (image/255.0,label)

train_ds = (
    train_ds.cache()
    .shuffle(1000)
    .map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)           # 在image_dataset_from_directory处已经设置了batch_size
    .prefetch(buffer_size=AUTOTUNE)
)

val_ds = (
    val_ds.cache()
    .shuffle(1000)
    .map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)         # 在image_dataset_from_directory处已经设置了batch_size
    .prefetch(buffer_size=AUTOTUNE)
)

三、数据可视化

plt.figure(figsize=(10, 8))  # 图形的宽为10高为5
plt.suptitle("数据展示")

for images, labels in train_ds.take(1):
    for i in range(15):
        plt.subplot(4, 5, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        # 显示图片
        plt.imshow(images[i])
        # 显示标签
        plt.xlabel(class_names[labels[i]-1])

plt.show()

四、构建模型

from tensorflow.keras.layers import Dropout,Dense,BatchNormalization
from tensorflow.keras.models import Model

def create_model(optimizer='adam'):
    # 加载预训练模型
    vgg16_base_model = tf.keras.applications.vgg16.VGG16(weights='imagenet',
                                                                include_top=False,#不包含顶层的全连接层
                                                                input_shape=(img_width, img_height, 3),
                                                                pooling='avg')#平均池化层替代顶层的全连接层
    for layer in vgg16_base_model.layers:
        layer.trainable = False  #将 trainable属性设置为 False 意味着在训练过程中,这些层的权重不会更新

    X = vgg16_base_model.output
    
    X = Dense(170, activation='relu')(X)
    X = BatchNormalization()(X)
    X = Dropout(0.5)(X)

    output = Dense(len(class_names), activation='softmax')(X)#神经元数量等于类别数
    vgg16_model = Model(inputs=vgg16_base_model.input, outputs=output)

    vgg16_model.compile(optimizer=optimizer,
                        loss='sparse_categorical_crossentropy',
                        metrics=['accuracy'])
    return vgg16_model

model1 = create_model(optimizer=tf.keras.optimizers.Adam())
model2 = create_model(optimizer=tf.keras.optimizers.SGD())#随机梯度下降(SGD)优化器的
model2.summary()

五、训练模型

NO_EPOCHS = 20

history_model1  = model1.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)
history_model2  = model2.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)

六、模型对比评估

from matplotlib.ticker import MultipleLocator
plt.rcParams['savefig.dpi'] = 300 #图片像素
plt.rcParams['figure.dpi']  = 300 #分辨率

acc1     = history_model1.history['accuracy']
acc2     = history_model2.history['accuracy']
val_acc1 = history_model1.history['val_accuracy']
val_acc2 = history_model2.history['val_accuracy']

loss1     = history_model1.history['loss']
loss2     = history_model2.history['loss']
val_loss1 = history_model1.history['val_loss']
val_loss2 = history_model2.history['val_loss']

epochs_range = range(len(acc1))

plt.figure(figsize=(16, 4))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, acc1, label='Training Accuracy-Adam')
plt.plot(epochs_range, acc2, label='Training Accuracy-SGD')
plt.plot(epochs_range, val_acc1, label='Validation Accuracy-Adam')
plt.plot(epochs_range, val_acc2, label='Validation Accuracy-SGD')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss1, label='Training Loss-Adam')
plt.plot(epochs_range, loss2, label='Training Loss-SGD')
plt.plot(epochs_range, val_loss1, label='Validation Loss-Adam')
plt.plot(epochs_range, val_loss2, label='Validation Loss-SGD')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
   
# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))

plt.show()

可以看出,在这个实例中,Adam优化器的效果优于SGD优化器

七、总结

      通过本次实验,学会了比较不同优化器(Adam和SGD)在训练过程中的性能表现,可视化训练过程的损失曲线和准确率等指标。这是一项非常重要的技能,在研究论文中,可以通过这些优化方法可以提高工作量。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2276573.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MySQL 16 章——变量、流程控制和游标

一、变量 在MySQL数据库的存储过程和存储函数中,可以使用变量来存储查询或计算的中间结果数据,或者输出最终的结果数据 在MySQL数据库中,变量分为系统变量和用户自定义变量 (1)系统变量 1.1.1系统变量分类 变量由…

【HTML+CSS+JS+VUE】web前端教程-13-Form表单

表单在web网页中用来给用户填写信息,从而能采用户信息,使网页具有交互的功能, 所有的用户输入内容的地方都用表单来写,如登录注册、搜索框。 表单是由容器和控件组成的,一个表单一般应该包含用户填写信息的输入框,提交按钮等,这些输入框,按钮叫做控件,表单就是容器,他…

LabVIEW滤波器功能

程序通过LabVIEW生成一个带噪声的正弦波信号,并利用滤波器对其进行信号提取。具体来说,它生成一个正弦波信号,叠加高频噪声后形成带噪信号,再通过低通滤波器滤除噪声,提取原始正弦波信号。整个过程展示了信号生成、噪声…

linux: 文本编辑器vim

文本编辑器 vi的工作模式 (vim和vi一致) 进入vim的方法 方法一:输入 vim 文件名 此时左下角有 "文件名" 文件行数,字符数量 方法一: 输入 vim 新文件名 此时新建了一个文件并进入vim,左下角有 "文件名"[New File] 灰色的长方形就是光标,输入文字,左下…

Java Web开发进阶——错误处理与日志管理

错误处理和日志管理是任何生产环境中不可或缺的一部分。在 Spring Boot 中,合理的错误处理机制不仅能够提升用户体验,还能帮助开发者快速定位问题;而有效的日志管理能够帮助团队监控应用运行状态,及时发现和解决问题。 1. 常见错误…

《零基础Go语言算法实战》【题目 2-25】goroutine 的执行权问题

《零基础Go语言算法实战》 【题目 2-25】goroutine 的执行权问题 请说明以下这段代码为什么会卡死。 package main import ( "fmt" "runtime" ) func main() { var i byte go func() { for i 0; i < 255; i { } }() fmt.Println("start&quo…

Ubuntu系统Qt的下载、安装及入门使用,图文详细,内容全面

文章目录 说明1 在线安装2 离线安装3 使用Qt Creator创建Qt应用程序并构建运行补充补充一&#xff1a;注册Qt账号 说明 本文讲解Ubuntu系统下安装Qt&#xff0c;包括在线安装和离线安装两种方式&#xff0c;内容充实细致&#xff0c;话多但是没有多余&#xff08;不要嫌我啰嗦…

线形回归与小批量梯度下降实例

1、准备数据集 import numpy as np import matplotlib.pyplot as pltfrom torch.utils.data import DataLoader from torch.utils.data import TensorDataset######################################################################### #################准备若干个随机的x和…

P3884 [JLOI2009] 二叉树问题

题目描述&#xff1a; 如下图所示的一棵二叉树的深度、宽度及结点间距离分别为&#xff1a; - 深度&#xff1a;4 - 宽度&#xff1a;4 - 结点 8 和 6 之间的距离&#xff1a;8 - 结点 7 和 6 之间的距离&#xff1a;3 其中宽度表示二叉树上同一层最多的结点个数&#xff0c;节…

ssm旅游攻略网站设计+jsp

系统包含&#xff1a;源码论文 所用技术&#xff1a;SpringBootVueSSMMybatisMysql 需要源码或者定制看文章最下面或看我的主页 目 录 目 录 III 1 绪论 1 1.1 研究背景 1 1.2 目的和意义 1 1.3 论文结构安排 2 2 相关技术 3 2.1 SSM框架介绍 3 2.2 B/S结构介绍 3 …

Qt类的提升(Python)

from PyQt5.QtWidgets import QPushButtonclass apushbutton(QPushButton):def __init__(self, parentNone):super().__init__(parent)self.setText("Custom Button")self.setStyleSheet("background-color: yellow;")上述为一个“模板类”&#xff0c;命名…

kubernetes上安装kubesphere

准备工作 需要配置三台虚拟机 关闭防火墙 systemctl stop firewalldsystemctl disable firewalld 临时关闭selinux setenforce 0 永久关闭selinux vi /etc/selinux/config 安装docker rpm -qa|grep docker yum remove docker* -y rpm -qa|grep docker yum install -y yum-u…

Windows图形界面(GUI)-QT-C/C++ - QT控件创建管理初始化

公开视频 -> 链接点击跳转公开课程博客首页 -> ​​​链接点击跳转博客主页 目录 控件创建 包含对应控件类型头文件 实例化控件类对象 控件设置 设置父控件 设置窗口标题 设置控件大小 设置控件坐标 设置文本颜色和背景颜色 控件排版 垂直布局 QVBoxLayout …

分页工具代码重构

文章目录 1.common-mybatis-plus-starter1.目录2.PageInfo.java3.PageResult.java4.SunPageHelper.java 1.common-mybatis-plus-starter 1.目录 2.PageInfo.java package com.sunxiansheng.mybatis.plus.page;import lombok.EqualsAndHashCode; import lombok.ToString;impor…

Vue学习二——创建登录页面

前言 以一个登录页面为例子&#xff0c;这篇文章简单介绍了vue&#xff0c;element-plus的一些组件使用&#xff0c;vue-router页面跳转&#xff0c;pinia及持久化存储&#xff0c;axios发送请求的使用。后面的页面都大差不差&#xff0c;也都这么实现&#xff0c;只是内容&am…

初始Django框架

初识Django Python知识点&#xff1a;函数、面向对象。前端开发&#xff1a;HTML、CSS、JavaScript、jQuery、BootStrap。MySQL数据库。Python的Web框架&#xff1a; Flask&#xff0c;自身短小精悍 第三方组件。Django&#xff0c;内部已集成了很多组件 第三方组件。【主要…

深度学习每周学习总结R4(LSTM-实现糖尿病探索与预测)

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客R6中的内容&#xff0c;为了便于自己整理总结起名为R4&#x1f356; 原作者&#xff1a;K同学啊 | 接辅导、项目定制 目录 0. 总结1. LSTM介绍LSTM的基本组成部分如何理解与应用LSTM 2. 数据预处理3. 数…

虚假星标:GitHub上的“刷星”乱象与应对之道

在开源软件的世界里&#xff0c;GitHub无疑是最重要的平台之一。它不仅是一个代码托管平台&#xff0c;也是一个社交网络&#xff0c;允许开发者通过“点赞”&#xff08;即加星&#xff09;来表达对某个项目的喜爱和支持&#xff0c;“星标”&#xff08;Star&#xff09;则成…

前端笔记----

在我的理解里边一切做页面的代码都是属于前端代码。 之前用过qt框架&#xff0c;也是用来写界面的&#xff0c;但是那是用来写客户端的&#xff0c;而html是用来写web浏览器的&#xff0c;相较之下htmlcssJavaScript写出来的界面是更加漂亮的。这里就记录我自个学习后的一些笔…

【面试题】技术场景 4、负责项目时遇到的棘手问题及解决方法

工作经验一年以上程序员必问问题 面试题概述 问题为在负责项目时遇到的棘手问题及解决方法&#xff0c;主要考察开发经验与技术水平&#xff0c;回答不佳会影响面试印象。提供四个回答方向&#xff0c;准备其中一个方向即可。 1、设计模式应用方向 以登录为例&#xff0c;未…