第T11周:使用TensorFlow进行优化器对比实验

news2025/1/11 21:05:34
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

    文章目录

    • 一、前期工作
      • 1.设置GPU(如果使用的是CPU可以忽略这步)
    • 二、导入数据
      • 1、导入数据
      • 2、检查数据
      • 3、配置数据集
      • 4、数据可视化
    • 三、构建模型
    • 四、训练模型
    • 五、评估模型
      • 1、Accuracy与Loss图
      • 2.、模型评估
    • 六、总结

电脑环境:
语言环境:Python 3.8.0
编译器:Jupyter Notebook
深度学习环境:tensorflow 2.17.0

一、前期工作

1.设置GPU(如果使用的是CPU可以忽略这步)

import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0],"GPU")

from tensorflow          import keras
import matplotlib.pyplot as plt
import pandas            as pd
import numpy             as np
import warnings,os,PIL,pathlib

warnings.filterwarnings("ignore")             #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False    # 用来正常显示负号

二、导入数据

1、导入数据

data_dir    = "./365-8-data/"
data_dir    = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:",image_count)
batch_size = 16
img_height = 336
img_width  = 336

"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)

class_names = train_ds.class_names

2、检查数据

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

(16, 336, 336, 3)
(16,)

3、配置数据集

AUTOTUNE = tf.data.AUTOTUNE

def train_preprocessing(image,label):
    return (image/255.0,label)

train_ds = (
    train_ds.cache()
    .shuffle(1000)
    .map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)           # 在image_dataset_from_directory处已经设置了batch_size
    .prefetch(buffer_size=AUTOTUNE)
)

val_ds = (
    val_ds.cache()
    .shuffle(1000)
    .map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)         # 在image_dataset_from_directory处已经设置了batch_size
    .prefetch(buffer_size=AUTOTUNE)
)

4、数据可视化

plt.figure(figsize=(10, 8))  # 图形的宽为10高为5
plt.suptitle("数据展示")

for images, labels in train_ds.take(1):
    for i in range(15):
        plt.subplot(4, 5, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)

        # 显示图片
        plt.imshow(images[i])
        # 显示标签
        plt.xlabel(class_names[labels[i]])

plt.show()

在这里插入图片描述

三、构建模型

model1为Adam优化器,model2为SGD优化器,网络结构都一致。

from tensorflow.keras.layers import Dropout,Dense,BatchNormalization
from tensorflow.keras.models import Model

def create_model(optimizer='adam'):
    # 加载预训练模型
    vgg16_base_model = tf.keras.applications.vgg16.VGG16(weights='imagenet',
                                                                include_top=False,
                                                                input_shape=(img_width, img_height, 3),
                                                                pooling='avg')
    for layer in vgg16_base_model.layers:
        layer.trainable = False

    X = vgg16_base_model.output
    
    X = Dense(170, activation='relu')(X)
    X = BatchNormalization()(X)
    X = Dropout(0.5)(X)

    output = Dense(len(class_names), activation='softmax')(X)
    vgg16_model = Model(inputs=vgg16_base_model.input, outputs=output)

    vgg16_model.compile(optimizer=optimizer,
                        loss='sparse_categorical_crossentropy',
                        metrics=['accuracy'])
    return vgg16_model

model1 = create_model(optimizer=tf.keras.optimizers.Adam())
model2 = create_model(optimizer=tf.keras.optimizers.SGD())
model2.summary()

四、训练模型

NO_EPOCHS = 50

history_model1  = model1.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)
history_model2  = model2.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)
Epoch 1/50
90/90 ━━━━━━━━━━━━━━━━━━━━ 99s 872ms/step - accuracy: 0.1097 - loss: 3.1620 - val_accuracy: 0.0750 - val_loss: 2.7208
..............
Epoch 50/50
90/90 ━━━━━━━━━━━━━━━━━━━━ 20s 197ms/step - accuracy: 0.8159 - loss: 0.5525 - val_accuracy: 0.5806 - val_loss: 1.4261

五、评估模型

1、Accuracy与Loss图

from matplotlib.ticker import MultipleLocator
plt.rcParams['savefig.dpi'] = 300 #图片像素
plt.rcParams['figure.dpi']  = 300 #分辨率

acc1     = history_model1.history['accuracy']
acc2     = history_model2.history['accuracy']
val_acc1 = history_model1.history['val_accuracy']
val_acc2 = history_model2.history['val_accuracy']

loss1     = history_model1.history['loss']
loss2     = history_model2.history['loss']
val_loss1 = history_model1.history['val_loss']
val_loss2 = history_model2.history['val_loss']

epochs_range = range(len(acc1))

plt.figure(figsize=(16, 4))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, acc1, label='Training Accuracy-Adam')
plt.plot(epochs_range, acc2, label='Training Accuracy-SGD')
plt.plot(epochs_range, val_acc1, label='Validation Accuracy-Adam')
plt.plot(epochs_range, val_acc2, label='Validation Accuracy-SGD')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss1, label='Training Loss-Adam')
plt.plot(epochs_range, loss2, label='Training Loss-SGD')
plt.plot(epochs_range, val_loss1, label='Validation Loss-Adam')
plt.plot(epochs_range, val_loss2, label='Validation Loss-SGD')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
   
# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))

plt.show()

在这里插入图片描述

2.、模型评估

def test_accuracy_report(model):
    score = model.evaluate(val_ds, verbose=0)
    print('Loss function: %s, accuracy:' % score[0], score[1])
    
test_accuracy_report(model2)

Loss function: 1.4261106252670288, accuracy: 0.5805555582046509

六、总结

我们可以从两个不同的图中分别观察到训练和验证的准确率以及损失值的变化,Adam优化器在训练集上表现优异,迅速提高准确率并减少损失值,但在验证集上的表现不稳定,可能会导致过拟合。SGD优化器在训练集上学习速度较慢,但验证集上的表现较为稳定,适合在更长期的训练中提升模型的泛化能力。具体使用哪种优化器还要看具体任务。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2111297.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

淘宝和微信支付“好”上了,打翻了支付宝的“醋坛子”?

文:互联网江湖 作者:刘致呈 最近,淘宝将全面接入微信支付的消息,在整个互联网圈里炸开了锅。 虽说阿里系平台与腾讯之间“拆墙”的消息,早就不算是啥新鲜事了。而且进一步互联互通,无论是对广大用户&…

43. 1 ~ n 整数中 1 出现的次数【难】

comments: true difficulty: 中等 edit_url: https://github.com/doocs/leetcode/edit/main/lcof/%E9%9D%A2%E8%AF%95%E9%A2%9843.%201%EF%BD%9En%E6%95%B4%E6%95%B0%E4%B8%AD1%E5%87%BA%E7%8E%B0%E7%9A%84%E6%AC%A1%E6%95%B0/README.md 面试题 43. 1 ~ n 整数中 1 …

(postman)接口测试进阶实战

1.内置和自定义的动态参数 内置的动态参数有哪些? ---{{$}}--是内置动态参数的标志 //自定义的动态参数 此处date.now()的作用就相当于上面的timestamp 2.业务闭环及文件接口测试 返回的url地址可以在网页中查询得到。 3. 常规断言,动态参数断言&#xf…

Linux进程初识:OS基础、fork函数创建进程、进程排队和进程状态讲解

目录 1、冯诺伊曼体系结构 问题一:为什么在体系结构中存在存储器(内存)? 存储单元总结: 问题二:为什么程序在运行的时候,必须把程序先加载到内存? 问题三:请解释&am…

爆改YOLOv8|利用yolov10的SCDown改进yolov8-下采样

1, 本文介绍 YOLOv10 的 SCDown 方法来优化 YOLOv8 的下采样过程。SCDown 通过点卷积调整通道维度,再通过深度卷积进行空间下采样,从而减少了计算成本和参数数量。这种方法不仅降低了延迟,还在保持下采样过程信息的同时提供了竞争性的性能。…

使用Python通过字节串或字节数组加载和保存PDF文档

处理PDF文件的可以直接读取和写入文件系统中的PDF文件,然而,通过字节串(byte string)或字节数组(byte array)来加载和保存PDF文档在某些情况下更高效。这种方法不仅可以提高数据处理的灵活性,允…

Mysql8客户端连接异常:Public Key Retrieval is not allowed

mysql 8.0 默认使用 caching_sha2_password 身份验证机制 (即从原来mysql_native_password 更改为 caching_sha2_password。) 从 5.7 升级 8.0 版本的不会改变现有用户的身份验证方法,但新用户会默认使用新的 caching_sha2_password 。 客户…

ISO26262和Aspice之间的关联

ASPICE 介绍: ASPICE(Automotive Software Process Improvement and Capability dEtermination)是汽车软件过程改进及能力评定的模型,它侧重于汽车软件的开发过程。ASPICE 定义了一系列的过程和活动,包括需求管理、软…

基于yolov8的抽烟检测系统python源码+onnx模型+评估指标曲线+精美GUI界面

【算法介绍】 基于YOLOv8的抽烟检测系统是一种利用先进深度学习技术实现的实时目标检测系统。该系统采用YOLOv8算法,该算法以其高速度和高精度在目标检测领域脱颖而出。该系统通过训练大量标注好的抽烟行为数据集,使模型能够自动识别和定位视频或图像中…

使用YOLOv10训练自定义数据集之二(数据集准备)

0x00 前言 经过上一篇环境部署的介绍【传送门】,我们已经得到了一个基本可用的YOLOv10的运行环境,还需要我们再准备一些数据,用于模型训练。 0x01 准备数据集 1. 图像标注工具 数据集是训练模型基础素材。 对于小白来说,一般…

如何判断小程序是运行在“企业微信”中的还是运行在“微信”中的?

如何判断小程序是运行在“企业微信”中的还是运行在“微信”中的? 目录 如何判断小程序是运行在“企业微信”中的还是运行在“微信”中的? 一、官方开发文档 1.1、“微信小程序”开发文档的说明 1.2、“企业微信小程序”开发文档的说明 1.3、在企业…

redis缓存预热、缓存穿透的详细教程

前言 作此篇主要在于关于redis的缓存预热、缓存雪崩、缓存击穿和缓存穿透在面试中经常遇到,工作中也是经常遇到。中级程序员基本上不可避免要克服的几个问题,希望一次性解释清楚 缓存预热 MySQL加入新增100条记录,一般默认以MySQL为准为底单…

5-2 检测内存容量

1 使用的是bios 中断, 每次进行检测都会返回一块 内容。并且标志上,这块内存是否可用。 接下来是代码: 首先是构建 一个文件夹, 两个文件。 types.h 的内容。 #ifndef TYPES_H #define TYPES_H// 基本整数类型,下面的…

C++系统教程002-数据类型(01)

一、数据类型 学习一门编程语言,首先要掌握它的数据类型。不同的数据类型占用的内存空间不同,定义数据类型合理在一定程度上可以优化程序的运行。本次主要介绍C中常见的数据类型及数据的输入与输出格式。本章知识架构及重难点如下: &#xf…

linux监听网速

方法一 tcpdump -i ens33 -w - | pv -bert > /dev/null方法二

问题 J: 数据结构基础33-查找二叉树

题目描述 已知一棵二叉树用邻接表结构存储&#xff0c;中序查找二叉树中值为x的结点&#xff0c;并指出是第几个结点。例&#xff1a;如图二叉树的数据文件的数据格式如下 输入 第一行n为二叉树的结点个树&#xff0c;n<100&#xff1b;第二行x表示要查找的结点的值&…

windows环境安装OceanBase数据库并创建表、插入数据

windows环境安装OceanBase数据库并创建表、插入数据 前言&#xff1a;OceanBase数据库目前不支持直接在Windows环境下安装&#xff0c;安装比较麻烦&#xff0c;记录一下安装过程 1.安装方案 根据官方文档&#xff1a;https://www.oceanbase.com/docs/common-oceanbase-databa…

实验六 异常处理

实验目的及要求 目的&#xff1a;了解异常的概念&#xff0c;掌握异常处理的方法&#xff0c;掌握throws与throw关键字的区别与联系&#xff0c;掌握自定义异常的方法及用途。 要求&#xff1a; &#xff08;1&#xff09;编写程序了解程序中可能出现的运行时异常与非运行时…

摆花 NOIP2012普及组

目录 思路 代码 思路 代码 #include <iostream> #include <algorithm>using namespace std; using LL long long;const int N 1e2 9; const int mod 1e6 7;int n,m; LL a[N]; LL f[N][N];void solve() {cin >> n >> m;f[0][0] 1;for (int i 1;…

Jmeter模拟用户登录时获取token如何跨线程使用?

一、用户定义的变量 1、添加"用户定义的变量" 2、填写"host、port" 二、setUp线程组 1、添加"setUp线程组" 2、设置循环次数"100" 三、CSV 数据文件设置 1、添加"CSV 数据文件设置" 2、填写信息"用户登录数据.csv、…