深度学习训练营之优化器对比

news2025/1/15 21:02:53

深度学习训练营之优化器对比

  • 原文链接
  • 环境介绍
  • 前置工作
    • 设置GPU
  • 数据处理
    • 导入数据
    • 数据集处理
    • 数据集可视化
  • 模型构造
  • 模型训练
  • 结果可视化

原文链接

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:365天深度学习训练营-第11周-优化器对比实验
  • 🍖 原作者:K同学啊|接辅导、项目定制

环境介绍

  • 语言环境:Python3.9.13
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2

前置工作

设置GPU

如果

import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0],"GPU")

from tensorflow          import keras
import matplotlib.pyplot as plt
import pandas            as pd
import numpy             as np
import warnings,os,PIL,pathlib

warnings.filterwarnings("ignore")             #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False    # 用来正常显示负号

数据处理

导入数据

import pathlib

data_dir = "./29-data/"
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:",image_count)

图片总数为: 1462

数据集处理

batch_size = 16#批量大小
img_height = 224
img_width  = 224
##在导入数据的过程当中打乱数据位置
train_ds=tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=24,#随机数种子
    image_size=(img_height,img_width),
    batch_size=batch_size
)

Found 1462 files belonging to 9 classes.
Using 1170 files for training.

##在导入数据的过程当中打乱数据位置
val_ds=tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=24,#随机数种子
    image_size=(img_height,img_width),
    batch_size=batch_size
)

Found 1462 files belonging to 9 classes.
Using 292 files for validation.

class_names=train_ds.class_names
print("数据类型有:",class_names)
print("需要识别的船有%d类"%len(class_names))

数据类型有: [‘buoy’, ‘cruise ship’, ‘ferry boat’, ‘freight boat’, ‘gondola’, ‘inflatable boat’, ‘kayak’, ‘paper boat’, ‘sailboat’]
需要识别的船有9类

for image_batch,labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

(16, 224, 224, 3)
(16,)

数据集可视化

AUTOTUNE = tf.data.AUTOTUNE

def train_preprocessing(image,label):
    return (image/255.0,label)

train_ds = (
    train_ds.cache()
    .map(train_preprocessing)    # 这里可以设置预处理函数
    .prefetch(buffer_size=AUTOTUNE)
)

val_ds = (
    val_ds.cache()
    .map(train_preprocessing)    # 这里可以设置预处理函数
    .prefetch(buffer_size=AUTOTUNE)
)
plt.figure(figsize=(10, 8))  # 图形的宽为10高为5
plt.suptitle("数据展示")

for images, labels in train_ds.take(1):
    for i in range(15):
        plt.subplot(4, 5, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)

        # 显示图片
        plt.imshow(images[i])
        # 显示标签
        plt.xlabel(class_names[labels[i]-1])

plt.show()

请添加图片描述

模型构造

##对比不同优化器
from tensorflow.keras.layers import Dropout,Dense,BatchNormalization
from tensorflow.keras.models import Model

def create_model(optimizer='adam'):
    # 加载预训练模型
    vgg16_base_model = tf.keras.applications.vgg16.VGG16(weights='imagenet',
                                                                include_top=False,
                                                                input_shape=(img_width, img_height, 3),
                                                                pooling='avg')
    for layer in vgg16_base_model.layers:
        layer.trainable = False

    X = vgg16_base_model.output
    
    X = Dense(170, activation='relu')(X)
    X = BatchNormalization()(X)
    X = Dropout(0.5)(X)

    output = Dense(len(class_names), activation='softmax')(X)
    vgg16_model = Model(inputs=vgg16_base_model.input, outputs=output)

    vgg16_model.compile(optimizer=optimizer,
                        loss='sparse_categorical_crossentropy',
                        metrics=['accuracy'])
    return vgg16_model

model1 = create_model(optimizer=tf.keras.optimizers.Adam())
model2 = create_model(optimizer=tf.keras.optimizers.SGD())
model2.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5
58889256/58889256 [==============================] - 60s 1us/step
Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_2 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 block1_conv1 (Conv2D)       (None, 224, 224, 64)      1792      
                                                                 
 block1_conv2 (Conv2D)       (None, 224, 224, 64)      36928     
                                                                 
 block1_pool (MaxPooling2D)  (None, 112, 112, 64)      0         
                                                                 
 block2_conv1 (Conv2D)       (None, 112, 112, 128)     73856     
                                                                 
 block2_conv2 (Conv2D)       (None, 112, 112, 128)     147584    
                                                                 
 block2_pool (MaxPooling2D)  (None, 56, 56, 128)       0         
                                                                 
 block3_conv1 (Conv2D)       (None, 56, 56, 256)       295168    
                                                                 
 block3_conv2 (Conv2D)       (None, 56, 56, 256)       590080    
                                                                 
 block3_conv3 (Conv2D)       (None, 56, 56, 256)       590080    
...
Total params: 14,804,117
Trainable params: 89,089
Non-trainable params: 14,715,028
_________________________________________________________________
Output is truncated. View as a scrollable element or open in a text editor. Adjust cell output settings...

模型训练

开始进行模型训练

NO_EPOCHS = 50

history_model1  = model1.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)
history_model2  = model2.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)
Epoch 1/50
74/74 [==============================] - 82s 1s/step - loss: 1.6497 - accuracy: 0.4966 - val_loss: 1.4824 - val_accuracy: 0.6507
Epoch 2/50
74/74 [==============================] - 78s 1s/step - loss: 0.9829 - accuracy: 0.7043 - val_loss: 1.1832 - val_accuracy: 0.6952
Epoch 3/50
74/74 [==============================] - 78s 1s/step - loss: 0.8367 - accuracy: 0.7316 - val_loss: 0.9519 - val_accuracy: 0.7089
Epoch 4/50
74/74 [==============================] - 78s 1s/step - loss: 0.7420 - accuracy: 0.7684 - val_loss: 0.8481 - val_accuracy: 0.7021
Epoch 5/50
74/74 [==============================] - 79s 1s/step - loss: 0.6643 - accuracy: 0.7880 - val_loss: 0.8094 - val_accuracy: 0.7568
Epoch 6/50
74/74 [==============================] - 81s 1s/step - loss: 0.6044 - accuracy: 0.8060 - val_loss: 0.7265 - val_accuracy: 0.7705
Epoch 7/50
74/74 [==============================] - 81s 1s/step - loss: 0.5680 - accuracy: 0.8094 - val_loss: 0.7506 - val_accuracy: 0.7226
Epoch 8/50
74/74 [==============================] - 83s 1s/step - loss: 0.5108 - accuracy: 0.8333 - val_loss: 0.7361 - val_accuracy: 0.7534
Epoch 9/50
74/74 [==============================] - 84s 1s/step - loss: 0.4895 - accuracy: 0.8316 - val_loss: 0.8021 - val_accuracy: 0.7603
Epoch 10/50
74/74 [==============================] - 82s 1s/step - loss: 0.4669 - accuracy: 0.8470 - val_loss: 0.7546 - val_accuracy: 0.7568
Epoch 11/50
74/74 [==============================] - 82s 1s/step - loss: 0.4323 - accuracy: 0.8521 - val_loss: 0.8549 - val_accuracy: 0.7226
Epoch 12/50
74/74 [==============================] - 88s 1s/step - loss: 0.4015 - accuracy: 0.8778 - val_loss: 0.7263 - val_accuracy: 0.7671
Epoch 13/50
...
Epoch 49/50
74/74 [==============================] - 82s 1s/step - loss: 0.3593 - accuracy: 0.8880 - val_loss: 0.7675 - val_accuracy: 0.7808
Epoch 50/50
74/74 [==============================] - 81s 1s/step - loss: 0.3503 - accuracy: 0.8872 - val_loss: 0.7342 - val_accuracy: 0.7979
Output is truncated. View as a scrollable element or open in a text editor. Adjust cell output settings...

结果可视化

绘制两种不同训练方法的结果的图片

from matplotlib.ticker import MultipleLocator
plt.rcParams['savefig.dpi'] = 300 #图片像素
plt.rcParams['figure.dpi']  = 300 #分辨率

acc1     = history_model1.history['accuracy']
acc2     = history_model2.history['accuracy']
val_acc1 = history_model1.history['val_accuracy']
val_acc2 = history_model2.history['val_accuracy']

loss1     = history_model1.history['loss']
loss2     = history_model2.history['loss']
val_loss1 = history_model1.history['val_loss']
val_loss2 = history_model2.history['val_loss']

epochs_range = range(len(acc1))

plt.figure(figsize=(16, 4))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, acc1, label='Training Accuracy-Adam')
plt.plot(epochs_range, acc2, label='Training Accuracy-SGD')
plt.plot(epochs_range, val_acc1, label='Validation Accuracy-Adam')
plt.plot(epochs_range, val_acc2, label='Validation Accuracy-SGD')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss1, label='Training Loss-Adam')
plt.plot(epochs_range, loss2, label='Training Loss-SGD')
plt.plot(epochs_range, val_loss1, label='Validation Loss-Adam')
plt.plot(epochs_range, val_loss2, label='Validation Loss-SGD')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
   
# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))

plt.show()

请添加图片描述

def test_accuracy_report(model):
    score = model.evaluate(val_ds, verbose=0)
    print('Loss function: %s, accuracy:' % score[0], score[1])
    
test_accuracy_report(model2)
test_accuracy_report(model1)

Loss function: 0.7341989278793335, accuracy: 0.7979452013969421
Loss function: 1.1129000186920166, accuracy: 0.7739726305007935

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/628698.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

deque(简单介绍一下)

deque的基本情况: 简单的来说deque是一个双头队列。且两边的尺寸可以动态收缩或者扩张。 其底层实现相当复杂,而且效率并不高。大多数时候都不会使用。 deque诞生的原因是vector和list的优缺点不可分割。 正好复习一下vector和list的优缺点。 vector的…

手机抓包fiddler配置及使用教程

本文基于Fiddler4讲解基本使用 fiddler抓包原理 注意:Fiddler 是以代理web服务器的形式工作的,它使用代理地址:127.0.0.1,端口:8888。当Fiddler退出的时候它会自动注销,这样就不会影响别的 程序。不过如果Fiddler非正常退出&…

学校热水供应系统方案

学校热水供应系统是现代化校园建设的重要组成部分。一套高效、可靠、安全、环保的热水供应系统,不仅能够满足学生、教职工的日常生活需求,也能提高学校形象和竞争力。 在设计学校热水供应系统方案时,需要考虑以下几个方面: 一、热…

【计算机网络复习之路】运输层(谢希仁第八版)万字详解 主打基础

运输层是OSI七层模型中最重要最关键的一层,是唯一负责总体数据传输和控制的一层。运输层要达到两个主要目的:第一,提供可靠的端到端的通信(“端到端的通信” 是应用进程之间的通信);第二,向会话…

【css】box-sizing属性

box-sizing 是一个 CSS 属性,用于指定元素的总宽度和高度的计算方式。它影响内容框的大小,并可以包括或排除元素的填充、边框和外边距。 box-sizing 属性接受两个值: content-box:这是默认值。它指定元素的宽度和高度只包括内容区…

培训班出来拿17K,入职后8天就被裁了....

最近翻了一些网站的招聘信息,把一线大厂和大型互联网公司看了个遍,发现市场还是挺火热的,虽说铜三铁四,但是软件测试岗位并没有削减多少,建议大家有空还是多关注和多投简历,不要闭门造车,错过好…

电脑重装系统后无法开机是什么原因导致的

电脑重装系统是一种常见的解决问题和提升性能的方法,但有时候重装系统后可能会遇到无法开机的问题。本文将介绍一些常见原因和解决方法,帮助您解决电脑重装系统后无法开机的困扰。 工具/原料: 系统版本:windows7系统 品牌型号&…

HNU-操作系统OS-作业3(26-31章)

OS_homework_3 这份文件是OS_homework_3 by计科210X wolf 202108010XXX 文档设置了目录,可以通过目录快速跳转至答案部分。 第26章 运行该程序wolf/OS-homework/threads-intro/x86.py 通过README可知一些重要的标识符如下 -p:指定程序 -t:线程数 -i:指定每执行i个指令…

BT8918D----GPIO模块

0 Preface/Foreword 1 宏定义和数据结构 1.1 gpio_t typedef struct { psfr_t sfr; //GPIO SFR ADDR u8 num; u8 type; //type 1,高压IO,没有300R的强上下拉电阻。 type 0, 普通IO, 有内部300R上下拉电阻。 } gpio_t; 1.…

LeetCode - 18 四数之和

目录 题目来源 题目描述 示例 提示 题目解析 算法源码 题目来源 18. 四数之和 - 力扣(LeetCode) 题目描述 给你一个由 n 个整数组成的数组 nums ,和一个目标值 target 。请你找出并返回满足下述全部条件且不重复的四元组 [nums[a], n…

selenium.chrome怎么写扩展拦截或转发请求?

Selenium WebDriver 是一组开源 API,用于自动测试 Web 应用程序,利用它可以通过代码来控制chrome浏览器! 有时候我们需要mock接口的返回,或者拦截和转发请求,今天就来实现这个功能。 代码已开源: https:/…

Spring AOP(面向切面编程)详解

文章目录 前言AOP 概述AOP的基本概念AOP 应用场景AOP 常用的接口 Spring AOP 实现方式基于XML的配置基于注解的配置 Spring AOP 的使用一、Spring 的 API 接口实现AOP二、自定义类实现AOP三、注解实现AOP 模拟实现 Spring AOP 前言 Spring 提供了功能强大IOC、AOP等功能&#…

文心一言 VS 讯飞星火 VS chatgpt (36)-- 算法导论5.3 6题

六、请解释如何实现算法 PERMUTE-BY-SORTING,以处理两个或更多优先级相同的情形。也就是说,即使有两个或更多优先级相同,你的算法也应该产生一个均匀随机排列。 文心一言: 算法 PERMUTE-BY-SORTING 是一种基于排序的随机排列算法…

【第二章:物理层】

目录 知识框架No.0 : 引言No.1 :通信基础一、基本概念二、两个公式lim求极限1、失真(失去真实性) 三、编码和调制四、数据交换方式 No.2 :传输介质一、导向型介质二、非导向型介质 No.3 :物理层设备一、中继器二、集线…

selenium:元素定位之xpath、css

元素定位是在做UI自动化测试中最重要的一环,要牢牢掌握定位的方法,才能更有效率的进行UI自动化测试。 常见的元素定位方式: idnametag_nameclass_namelink_textpartial_link_textxpathcss 其中id,name是具有唯一性的&#xff0…

高考攀升小红书热榜!互动量破千万,品牌如何毕业季营销?

光影间,又是一年毕业季,弹指之间,那些青葱岁月如同白驹过隙般悄然从指缝溜走。近期,一年一度的高考、大学毕业又来袭,登上各大社媒平台热搜,成为热门话题;本期,随小编一起运用小红书…

【C++】智能指针 学习总结 |std::shared_ptr |std::unique_ptr | std::weak_ptr

文章目录 前言一、智能指针介绍二、普通指针和智能指针的比较案例三、std::shared_ptr四、std::unique_ptr五、std::weak_ptr六、std::shared_ptr |std::unique_ptr | std::weak_ptr三大智能指针的区别 前言 参考答案:chatgpt 一、智能指针介绍 智能指针是C的一种…

chatgpt赋能python:Python循环执行一个函数:简单而高效的代码实现

Python循环执行一个函数:简单而高效的代码实现 Python是一种高级编程语言,非常流行,不仅因为它易于学习和使用,而且因为它的灵活性。Python编程语言有很多特性,其中包括使用函数模块化编程,这在大型项目中…

图文验证码怎么测试及自动化测试怎么解决验证码问题?

目录 前言 首先,来简单认识下验证码 1、验证码的由来和作用 2、验证码的存储 3、验证码的原理 如何测试验证码? 1、手动测试 2、自动化测试 总结: 前言 在对安全性有要求的软件(系统)中都存在验证码&#xf…

高级查询 — 连接查询

关于northwind 示例数据库 查询数据库中的表 show tables;查询表的表属性 desc xxx(表名);连接查询 1.概述 若一个查询同时涉及两个及以上的表,则称之为连接查询。也可以叫做多表查询。使用join关键字进行多表连接。 2.分类(按功能) 内连…