Week-T11-优化器对比试验

news2024/9/24 5:29:42

文章目录

  • 一、准备环境
  • 二、准备数据
  • 三、搭建训练网络
  • 三、训练模型
    • (1)VSCode训练情况:
    • (2)`jupyter notebook`训练情况:
  • 四、模型评估 & 模型预测
    • 1、绘制Accuracy-Loss图
    • 2、显示model2的预测效果
  • 五、总结
    • 1、`plt.savefig("./数据展示.jpg")`保存的图片在文件夹内打开是空白的,如下图所示:
    • 2. 优化器是什么?包括哪些?

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

本文主要探究不同优化器、以及不同参数配置对模型的影响,最终对Adam、SGD优化器进行比较,并绘制比较结果。

使用的数据集为咖啡豆数据集,共有四类。

优化器常用的有Adam、SGD。优化器的归纳将放在文末的总结部分。

本文将使用Adam优化器的模型命名为"model1",使用SGD优化器的模型命名为"model2",然后根据模型训练结果绘制各自的Accuracy-Loss图。比较得出,在运行环境、epoch次数相同、模型结构相同等条件下,Adam优化器的整体情况要优于SGD优化器。

一、准备环境

# 1. 设置环境
import sys
import tensorflow as tf
from datetime import datetime

from tensorflow          import keras
import matplotlib.pyplot as plt
import pandas            as pd
import numpy             as np
import warnings,os,PIL,pathlib

print("---------------------1.配置环境------------------")
print("Start time: ", datetime.today())
print("tensorflow version: " + tf.__version__)
print("Python version: " + sys.version)

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0],"GPU")
        # 打印显卡信息,确认GPU可用
    print("GPU: " + gpus)
else:
    print("Using CPU")

warnings.filterwarnings("ignore")             #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False    # 用来正常显示负号

在这里插入图片描述

Q1: VSCode虚拟环境安装pandas
在这里插入图片描述

二、准备数据

# 2.导入数据
# 本次使用咖啡豆数据集(共4类)
print("---------------------2.1 从本地读取数据------------------")
data_dir    = "D:/jupyter notebook/DL-100-days/datasets/coffebeans-data"
data_dir    = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:",image_count)

batch_size = 16
img_height = 336
img_width  = 336

"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
print("---------------------2.2 划分训练数据------------------")
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)

"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
print("---------------------2.3 划分验证数据------------------")
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)

print("---------------------2.4 打印数据类别 && 数据的shape------------------")
class_names = train_ds.class_names
print(class_names)

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

print("---------------------2.5 配置数据集------------------")
AUTOTUNE = tf.data.AUTOTUNE

def train_preprocessing(image,label):
    return (image/255.0,label)

train_ds = (
    train_ds.cache()
    .shuffle(1000)
    .map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)           # 在image_dataset_from_directory处已经设置了batch_size
    .prefetch(buffer_size=AUTOTUNE)
)

val_ds = (
    val_ds.cache()
    .shuffle(1000)
    .map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)         # 在image_dataset_from_directory处已经设置了batch_size
    .prefetch(buffer_size=AUTOTUNE)
)

print("---------------------2.6 数据可视化,显示部分样本图片------------------")
plt.figure(figsize=(10, 8))  # 图形的宽为10高为5
plt.suptitle("数据展示")

for images, labels in train_ds.take(1):
    for i in range(15):
        plt.subplot(4, 5, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)

        # 显示图片
        plt.imshow(images[i])
        # 显示标签
        plt.xlabel(class_names[labels[i]-1])

plt.show()
plt.savefig("./数据展示.jpg")

在这里插入图片描述
在这里插入图片描述

Q2:plt.savefig("./数据展示.jpg")保存的图片在文件夹内打开是空白的

三、搭建训练网络

print("---------------------3. 搭建训练网络,此处预训练模型调用VGG-16官方模型------------------")
# 自定义一个创建模型的函数,形参是优化器类型,预训练模型是VGG-16,但屏蔽了自带的训练部分以及顶层,然后对输出进行处理
# 在此处创建了两个网络,拥有不同的优化器类型
from tensorflow.keras.layers import Dropout,Dense,BatchNormalization
from tensorflow.keras.models import Model

def create_model(optimizer='adam'):
    # 加载预训练模型
    vgg16_base_model = tf.keras.applications.vgg16.VGG16(weights='imagenet',
                                                                include_top=False,
                                                                input_shape=(img_width, img_height, 3),
                                                                pooling='avg')
    for layer in vgg16_base_model.layers:
        layer.trainable = False

    X = vgg16_base_model.output
    
    X = Dense(170, activation='relu')(X)
    X = BatchNormalization()(X)
    X = Dropout(0.5)(X)

    output = Dense(len(class_names), activation='softmax')(X)
    vgg16_model = Model(inputs=vgg16_base_model.input, outputs=output)

    vgg16_model.compile(optimizer=optimizer,
                        loss='sparse_categorical_crossentropy',
                        metrics=['accuracy'])
    return vgg16_model

model1 = create_model(optimizer=tf.keras.optimizers.Adam())
model2 = create_model(optimizer=tf.keras.optimizers.SGD())
model2.summary()

在这里插入图片描述

三、训练模型

print("---------------------4.启动训练,epoch==50------------------")
# try:加入早停试一下,一个epoch跑完要220s,时间还是有点久
NO_EPOCHS = 50

history_model1  = model1.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)
history_model2  = model2.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)

(1)VSCode训练情况:

model1.fit():Adam优化器
在这里插入图片描述
model2.fit():SGD优化器
在这里插入图片描述

(2)jupyter notebook训练情况:

model1.fit():即Adam优化器
在这里插入图片描述
model2.fit():即SGD优化器
在这里插入图片描述

四、模型评估 & 模型预测

1、绘制Accuracy-Loss图

print("---------------------5.1 模型评估,绘制Accuracy-Loss图------------------")
from matplotlib.ticker import MultipleLocator
plt.rcParams['savefig.dpi'] = 300 #图片像素
plt.rcParams['figure.dpi']  = 300 #分辨率

acc1     = history_model1.history['accuracy']
acc2     = history_model2.history['accuracy']
val_acc1 = history_model1.history['val_accuracy']
val_acc2 = history_model2.history['val_accuracy']

loss1     = history_model1.history['loss']
loss2     = history_model2.history['loss']
val_loss1 = history_model1.history['val_loss']
val_loss2 = history_model2.history['val_loss']

epochs_range = range(len(acc1))

plt.figure(figsize=(16, 4))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, acc1, label='Training Accuracy-Adam')
plt.plot(epochs_range, acc2, label='Training Accuracy-SGD')
plt.plot(epochs_range, val_acc1, label='Validation Accuracy-Adam')
plt.plot(epochs_range, val_acc2, label='Validation Accuracy-SGD')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss1, label='Training Loss-Adam')
plt.plot(epochs_range, loss2, label='Training Loss-SGD')
plt.plot(epochs_range, val_loss1, label='Validation Loss-Adam')
plt.plot(epochs_range, val_loss2, label='Validation Loss-SGD')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
   
# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))
plt.savefig("./Accuracy-Loss图.jpg")
plt.show()

plt.show()显示的图片:
请添加图片描述
比较Accuracy图表,可以看出训练时Adam优化器的表现要稍优于SGD优化器,而验证时则相反。

Q: VSCode绘制出来的图咋这么奇怪?
改变plt.savefig("./Accuracy-Loss图.jpg")的位置后所保存的图片,比直接plt.show()的图片比例要好些。
在这里插入图片描述

2、显示model2的预测效果

print("---------------------5.2 模型预测------------------")
def test_accuracy_report(model):
    score = model.evaluate(val_ds, verbose=0)
    print('Loss function: %s, accuracy:' % score[0], score[1])
    
test_accuracy_report(model2)

VSCode环境下的预测结果:
在这里插入图片描述
jupyter notebook环境下的预测结果:
在这里插入图片描述

五、总结

1、plt.savefig("./数据展示.jpg")保存的图片在文件夹内打开是空白的,如下图所示:

在这里插入图片描述
将保存的语句放在plt.show()之前,因为plt.show()之后会默认打开一个空白画板。

2. 优化器是什么?包括哪些?

(参考文章也是来自训练营文章)

优化器是什么?

  • 优化器是一种算法,它在模型优化过程中,动态地调整梯度的大小和方向,使模型能够收敛到更好的位置,或者用更快的速度进行收敛。
  • 各类优化器方法总结如下:
    在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1263749.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java使用263和qq邮箱发邮件

一、添加依赖 <dependency><groupId>com.sun.mail</groupId><artifactId>javax.mail</artifactId><version>1.6.2</version></dependency>二、263邮箱 1&#xff0c;邮箱配置 public static void sendEmail(String host, in…

你敢信?四行Python代码就能知道你那的天气!

今天给大家带来的Python实战项目是四行Python代码获取所在城市的天气预报&#xff0c;我们隐隐听到唏嘘声&#xff0c;不信四行Python代码可以获取是吗?那我们一起来看看&#xff1a; 四行Python代码就能知道你那的天气&#xff0c;你敢信&#xff1f; 使用Python获取天气预报…

Linux内存回收:LRU算法

linux操作系统再内存不足时会使用Swap机制&#xff0c;将一些不经常使用的匿名内存页放到磁盘当中&#xff0c;等下次需要时再读取到内存当中&#xff0c;而这个LRU算法就是用来选择把哪些不常使用的匿名内存页放到磁盘当中的。 LRU&#xff08;Least Recently Used&#xff09…

Gossip协议理解

概述 Gossip协议&#xff0c;又称epidemic协议&#xff0c;基于流行病传播方式的节点或进程之间信息交换的协议&#xff0c;在分布式系统中被广泛使用。 在1987年8月由施乐-帕洛阿尔托研究中心发表ACM上的论文《Epidemic Algorithms for Replicated Database Maintenance》中…

CorelDRAW Graphics Suite2023破解版含2024最新注册机下载

CorelDRAW Graphics Suite2023是Corel公司的平面设计软件&#xff1b;该软件是Corel出品的矢量图形制作工具软件&#xff0c;这个图形工具给设计师提供了矢量动画、页面设计、网站制作、位图编辑和网页动画等多种功能。在日常科研绘图中&#xff0c;若较为轻量&#xff0c;通常…

【Redis实现全局唯一ID】

一、全局唯一ID的需求产生。 在订单业务中&#xff0c;我们需要保证id是绝对唯一的。 使用数据库自增长的id在分布式的情况下把表做了拆分处理后有可能会出现id重复的情况&#xff0c;这就违背了唯一性。而且数据自增长的id有很强的规律性&#xff0c;可以根据id推断出订单的数…

人工智能_机器学习053_支持向量机SVM目标函数推导_SVM条件_公式推导过程---人工智能工作笔记0093

然后我们再来看一下支持向量机SVM的公式推导情况 来看一下支持向量机是如何把现实问题转换成数学问题的. 首先我们来看这里的方程比如说,中间的黑线我们叫做l2 那么上边界线我们叫l1 下边界线叫做l3 如果我们假设l2的方程是上面这个方程WT.x+b = 0 那么这里 我们只要确定w和…

Anaconda离线下载torch与安装包

一、下载离线安装包 命令&#xff1a; pip download 安装包名 -d 安装到文件夹名 -i https://pypi.tuna.tsinghua.edu.cn/simple执行这样的命令就会把安装包的离线文件下载到指定文件夹中。 操作&#xff1a; 打开cmd命令行&#xff0c;并进入相应的目录中。 如果是tor…

hutool工具连接数据库实现数据处理重新入库

1 引入依赖 <dependency><groupId>cn.hutool</groupId><artifactId>hutool-all</artifactId><version>5.7.18</version></dependency><!--mysql驱动包--><dependency><groupId>mysql</groupId><ar…

详解原生Spring中的控制反转和依赖注入-构造注入和Set注入

&#x1f609;&#x1f609; 学习交流群&#xff1a; ✅✅1&#xff1a;这是孙哥suns给大家的福利&#xff01; ✨✨2&#xff1a;我们免费分享Netty、Dubbo、k8s、Mybatis、Spring...应用和源码级别的视频资料 &#x1f96d;&#x1f96d;3&#xff1a;QQ群&#xff1a;583783…

KMP算法【数据结构】

KMP算法 KMP算法是一种改进的字符串匹配算法 Next[j] k :一个用来存放子串返回位置的数组&#xff0c;回溯的位置用字母k来表示。其实就是从匹配失败位置&#xff0c;找到他前面的字符串的最大前后相等子串长度。默认第一个k值为-1(Next[0] -1),第二个k值为0(Next[1] 0),我…

C++ 背包理论基础01 + 滚动数组

背包问题的重中之重是01背包 01背包 有n件物品和一个最多能背重量为w 的背包。第i件物品的重量是weight[i]&#xff0c;得到的价值是value[i] 。每件物品只能用一次&#xff0c;求解将哪些物品装入背包里物品价值总和最大。 每一件物品其实只有两个状态&#xff0c;取或者不…

E云管家开发自动转发朋友圈

简要描述&#xff1a; 转发朋友圈&#xff0c;直接xml数据。(对谁不可见) 请求URL&#xff1a; http://域名地址/forwardSns 请求方式&#xff1a; POST 请求头Headers&#xff1a; Content-Type&#xff1a;application/jsonAuthorization&#xff1a;login接口返回 参…

手机便签app哪个比较好用?

手机便签类软件的种类是比较多的&#xff0c;不管是安卓手机品牌还是苹果手机品牌的手机&#xff0c;在手机的应用商店中搜索“便签”&#xff0c;大家会找到很多便签类软件。那么&#xff0c;手机便签APP哪个比较好用呢&#xff1f; 在选择手机便签APP时&#xff0c;大家比较…

STM32F103C8T6第7天:

1. 智能小车&#xff1a;让小车动起来&#xff08;360.64&#xff09; 硬件接线 B-2A – PB0B-1A – PB1A-1B – PB2A-1A – PB10其余接线参考上官一号小车项目。 cubemx配置 代码&#xff08;28.smartCar_project1/MDK-ARM&#xff09; 2. 智能小车&#xff1a;串口控制小…

python爬虫进阶教程之如何正确的使用cookie

文章目录 前言一、获取cookie二、程序实现三、动态获取cookie四、其他关于Python爬虫技术储备一、Python所有方向的学习路线二、Python基础学习视频三、精品Python学习书籍四、Python工具包项目源码合集①Python工具包②Python实战案例③Python小游戏源码五、面试资料六、Pytho…

自学成为android framework工程师需要准备哪些装备-千里马车载车机系统开发学习

背景 hi&#xff0c;粉丝朋友们&#xff1a; 大家好&#xff01;经常有很多学员买课同学都会问到需要准备哪些装备&#xff0c;我也回答了很多学员了&#xff0c;今天就搞一篇文章来统一说明一下&#xff0c;告诉一下大家如果你想从一个framework新手变成一个framework开发的高…

微服务--04--SpringCloudGateway 网关

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 1.网关路由1.1 认识网关在SpringCloud当中&#xff0c;提供了两种网关实现方案&#xff1a; 1.2.快速入门1.3.路由过滤 2.网关登录校验2.1.鉴权思路分析2.2.网关过滤…

vue+echarts实现依赖关系无向网络拓扑结图节点折叠展开策略

目录 引言 一、设计 1. 树状图&#xff08;不方便呈现节点之间的关系&#xff0c;次要考虑&#xff09; 2. 力引导依赖关系图 二、力引导关系图 三、如何实现节点的Open Or Fold 1. 设计逻辑 节点展开细节 节点收缩细节 代码实现 四、结果呈现 五、完整代码 引言 我…

使用Kafka、Flink、Druid构建实时数据系统架构

1. 背景 对于很多数据团队来说&#xff0c;要满足实时需求并不容易。为什么&#xff1f;因为作流程&#xff08;数据采集、预处理、分析、结果保存&#xff09;涉及大量等待。等待数据发送到 ETL 工具&#xff0c;等待数据批量处理&#xff0c;等待数据加载到数据仓库中&#…