第T11周:优化器对比实验

news2025/1/19 20:26:26

>- **🍨 本文为[🔗365天深度学习训练营]中的学习记录博客**
>- **🍖 原作者:[K同学啊]**

本次主要是探究不同优化器、以及不同参数配置对模型的影响

🚀我的环境:

  • 语言环境:Python3.11.7
  • 编译器:jupyter notebook
  • 深度学习框架:TensorFlow2.13.0

一、设置GPU

import tensorflow as tf
gpus=tf.config.list_physical_devices("GPU")

if gpus:
    gpu0=gpus[0]
    tf.config.experimental.set_memory_growth(gpu0,True)
    tf.config.set_visible_devices([gpu0],"GPU")
    
import warnings
warnings.filterwarnings("ignore")

二、导入数据

1. 导入数据

import pathlib

data_dir="D:\THE MNIST DATABASE\P6-data"
data_dir=pathlib.Path(data_dir)
image_count=len(list(data_dir.glob('*/*')))
print("图片总数为:",image_count)

运行结果:

图片总数为: 1800

2. 加载数据

加载训练集:

train_ds=tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=12,
    image_size=(224,224),
    batch_size=16
)

运行结果:

Found 1800 files belonging to 17 classes.
Using 1440 files for training.

加载验证集:

val_ds=tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=12,
    image_size=(224,224),
    batch_size=16
)

运行结果:

Found 1800 files belonging to 17 classes.
Using 360 files for validation.

显示数据集分类情况:

class_names=train_ds.class_names
print(class_names)

运行结果:

['Angelina Jolie', 'Brad Pitt', 'Denzel Washington', 'Hugh Jackman', 'Jennifer Lawrence', 'Johnny Depp', 'Kate Winslet', 'Leonardo DiCaprio', 'Megan Fox', 'Natalie Portman', 'Nicole Kidman', 'Robert Downey Jr', 'Sandra Bullock', 'Scarlett Johansson', 'Tom Cruise', 'Tom Hanks', 'Will Smith']

3. 检查数据

for image_batch,labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

运行结果:

(16, 224, 224, 3)
(16,)

4. 配置数据集

AUTOTUNE=tf.data.AUTOTUNE

def train_preprocessing(image,label):
    return (image/255.0,label)

train_ds=(train_ds.cache().shuffle(1000).map(train_preprocessing).prefetch(buffer_size=AUTOTUNE))
val_ds=(val_ds.cache().shuffle(1000).map(train_preprocessing).prefetch(buffer_size=AUTOTUNE))

5. 数据可视化

import matplotlib.pyplot as plt

plt.figure(figsize=(15,8))
plt.suptitle("数据展示")

for images,labels in train_ds.take(1):
    for i in range(15):
        plt.subplot(4,5,i+1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        
        #显示图片
        plt.imshow(images[i])
        #显示标签
        plt.xlabel(class_names[labels[i]-1])
        
plt.show()

运行结果:

三、构建模型 

from tensorflow.keras.layers import Dropout,Dense,BatchNormalization
from tensorflow.keras.models import Model

def create_model(optimizer='adam'):
    #加载预训练模型
    vgg16_base_model=tf.keras.applications.vgg16.VGG16(
        weights='imagenet',
        include_top=False,
        input_shape=(224,224,3),
        pooling='avg'
        )
    for layer in vgg16_base_model.layers:
        layer.trainable=False
        
    x=vgg16_base_model.output
        
    x=Dense(170,activation='relu')(x)
    x=BatchNormalization()(x)
    x=Dropout(0.5)(x)
        
    output=Dense(len(calss_names),activation='softmax')(x)
    vgg16_model=Model(inputs=vgg16_base_model.input,outputs=output)
    
    vgg16_model.compile(optimizer=optimizer,
                        loss='sparse_categorical_crossentropy',
                        metrics=['accuracy'])
    return vgg16_model

model1=create_model(optimizer=tf.keras.optimizers.Adam())
model2=create_model(optimizer=tf.keras.optimizers.SGD())
model2.summary()

运行结果:

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_2 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 block1_conv1 (Conv2D)       (None, 224, 224, 64)      1792      
                                                                 
 block1_conv2 (Conv2D)       (None, 224, 224, 64)      36928     
                                                                 
 block1_pool (MaxPooling2D)  (None, 112, 112, 64)      0         
                                                                 
 block2_conv1 (Conv2D)       (None, 112, 112, 128)     73856     
                                                                 
 block2_conv2 (Conv2D)       (None, 112, 112, 128)     147584    
                                                                 
 block2_pool (MaxPooling2D)  (None, 56, 56, 128)       0         
                                                                 
 block3_conv1 (Conv2D)       (None, 56, 56, 256)       295168    
                                                                 
 block3_conv2 (Conv2D)       (None, 56, 56, 256)       590080    
                                                                 
 block3_conv3 (Conv2D)       (None, 56, 56, 256)       590080    
                                                                 
 block3_pool (MaxPooling2D)  (None, 28, 28, 256)       0         
                                                                 
 block4_conv1 (Conv2D)       (None, 28, 28, 512)       1180160   
                                                                 
 block4_conv2 (Conv2D)       (None, 28, 28, 512)       2359808   
                                                                 
 block4_conv3 (Conv2D)       (None, 28, 28, 512)       2359808   
                                                                 
 block4_pool (MaxPooling2D)  (None, 14, 14, 512)       0         
                                                                 
 block5_conv1 (Conv2D)       (None, 14, 14, 512)       2359808   
                                                                 
 block5_conv2 (Conv2D)       (None, 14, 14, 512)       2359808   
                                                                 
 block5_conv3 (Conv2D)       (None, 14, 14, 512)       2359808   
                                                                 
 block5_pool (MaxPooling2D)  (None, 7, 7, 512)         0         
                                                                 
 global_average_pooling2d_1  (None, 512)               0         
  (GlobalAveragePooling2D)                                       
                                                                 
 dense_2 (Dense)             (None, 170)               87210     
                                                                 
 batch_normalization_1 (Bat  (None, 170)               680       
 chNormalization)                                                
                                                                 
 dropout_1 (Dropout)         (None, 170)               0         
                                                                 
 dense_3 (Dense)             (None, 17)                2907      
                                                                 
=================================================================
Total params: 14805485 (56.48 MB)
Trainable params: 90457 (353.35 KB)
Non-trainable params: 14715028 (56.13 MB)
_________________________________________________________________

四、训练模型

no_epochs=50

history_model1=model1.fit(train_ds,epochs=no_epochs,verbose=1,validation_data=val_ds)
history_model2=model2.fit(train_ds,epochs=no_epochs,verbose=1,validation_data=val_ds)

运行结果:

Epoch 1/50
90/90 [==============================] - 131s 1s/step - loss: 2.8359 - accuracy: 0.1500 - val_loss: 2.6786 - val_accuracy: 0.1083
Epoch 2/50
90/90 [==============================] - 127s 1s/step - loss: 2.1186 - accuracy: 0.3243 - val_loss: 2.3780 - val_accuracy: 0.3361
Epoch 3/50
90/90 [==============================] - 127s 1s/step - loss: 1.7924 - accuracy: 0.4229 - val_loss: 2.1311 - val_accuracy: 0.4000
Epoch 4/50
90/90 [==============================] - 128s 1s/step - loss: 1.6097 - accuracy: 0.4750 - val_loss: 1.9252 - val_accuracy: 0.4028
……
Epoch 46/50
90/90 [==============================] - 129s 1s/step - loss: 0.1764 - accuracy: 0.9465 - val_loss: 2.7244 - val_accuracy: 0.5528
Epoch 47/50
90/90 [==============================] - 131s 1s/step - loss: 0.1833 - accuracy: 0.9410 - val_loss: 2.3910 - val_accuracy: 0.5278
Epoch 48/50
90/90 [==============================] - 131s 1s/step - loss: 0.2151 - accuracy: 0.9340 - val_loss: 2.8985 - val_accuracy: 0.4389
Epoch 49/50
90/90 [==============================] - 130s 1s/step - loss: 0.1725 - accuracy: 0.9458 - val_loss: 2.3219 - val_accuracy: 0.5306
Epoch 50/50
90/90 [==============================] - 130s 1s/step - loss: 0.1764 - accuracy: 0.9375 - val_loss: 2.9708 - val_accuracy: 0.4972
Epoch 1/50
90/90 [==============================] - 130s 1s/step - loss: 3.0062 - accuracy: 0.1125 - val_loss: 2.7298 - val_accuracy: 0.1778
Epoch 2/50
90/90 [==============================] - 129s 1s/step - loss: 2.4726 - accuracy: 0.2271 - val_loss: 2.5667 - val_accuracy: 0.2250
Epoch 3/50
90/90 [==============================] - 129s 1s/step - loss: 2.2530 - accuracy: 0.2917 - val_loss: 2.3789 - val_accuracy: 0.2972
Epoch 4/50
90/90 [==============================] - 129s 1s/step - loss: 2.0593 - accuracy: 0.3458 - val_loss: 2.1837 - val_accuracy: 0.3194
…………
Epoch 47/50
90/90 [==============================] - 146s 2s/step - loss: 0.6468 - accuracy: 0.8021 - val_loss: 1.5983 - val_accuracy: 0.5194
Epoch 48/50
90/90 [==============================] - 146s 2s/step - loss: 0.6093 - accuracy: 0.8111 - val_loss: 1.6223 - val_accuracy: 0.4972
Epoch 49/50
90/90 [==============================] - 146s 2s/step - loss: 0.6051 - accuracy: 0.7979 - val_loss: 1.6518 - val_accuracy: 0.5139
Epoch 50/50
90/90 [==============================] - 150s 2s/step - loss: 0.6074 - accuracy: 0.8007 - val_loss: 1.6507 - val_accuracy: 0.5167

五、评估模型

1. Accuracy与Loss图

from matplotlib.ticker import MultipleLocator
plt.rcParams['savefig.dpi']=300 #图片像素
plt.rcParams['figure.dpi']=300 #图片分辨率

acc1=history_model1.history['accuracy']
acc2=history_model2.history['accuracy']
val_acc1=history_model1.history['val_accuracy']
val_acc2=history_model2.history['val_accuracy']

loss1=history_model1.history['loss']
loss2=history_model2.history['loss']
val_loss1=history_model1.history['val_loss']
val_loss2=history_model2.history['val_loss']

epochs_range=range(len(acc1))

plt.figure(figsize=(16,4))

plt.subplot(1,2,1)
plt.plot(epochs_range,acc1,label="Training Accuracy-Adam")
plt.plot(epochs_range,acc2,label="Training Accuracy-SGD")
plt.plot(epochs_range,val_acc1,label="Validation Accuracy-Adam")
plt.plot(epochs_range,val_acc2,label="Validation Accuracy-SGD")
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
#设置刻度间隔,x轴每1一个刻度
ax=plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))

plt.subplot(1,2,2)
plt.plot(epochs_range,loss1,label="Training Loss-Adam")
plt.plot(epochs_range,loss2,label="Training Loss-SGD")
plt.plot(epochs_range,val_loss1,label="Validation Loss-Adam")
plt.plot(epochs_range,val_loss2,label="Validation Loss-SGD")
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
#设置刻度间隔,x轴每1一个刻度
ax=plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))

plt.show()

运行结果:

2. 模型评估 

def test_accuracy_report(model):
    score=model.evaluate(val_ds,verbose=0)
    print('Loss function: %s,accuracy:' % score[0],score[1])
    
test_accuracy_report(model2)

运行结果:

Loss function: 1.6506924629211426,accuracy: 0.5166666507720947

六、心得体会

通过本项目的练习,学习如何在不同优化器环境下进行对比实验,通过实验可以筛选出提升准确率的优化器。

当然,以此类推,可以通过修改模型中的各项参数建立各类模型,最后对比各类模型的结果,以求达到最优模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2059542.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CSS 布局

CSS 页面布局技术允许我们拾取网页中的元素,并且控制它们相对正常布局流、周边元素、父容器或者主视口/窗口的位置。布局有一下几种 正常布局流display属性弹性盒子网格浮动定位CSS 表格布局多列布局 每种布局都有它们的用途,各有优缺点,相…

CSS伪类选择器和伪元素

伪类(Pseudo-classes) 伪类用于定义元素的特殊状态。它们被添加到选择器中以指定元素在其生命周期的特定状态下的样式。伪类不创建新的文档内容,也不创建新的文档树中的元素。相反,它们提供了一种方法来根据元素的状态来应用样式…

统信UOS系统连接打印机操作步骤

系统版本 操作步骤 首先点击开始菜单 搜索框输入打印,点击打印管理器 点击图下所示的号 按照图下所示,手动查找->输入打印机的ip地址->点击查找 等到如图下所示,出现打印机的时候,选择打印机,然后选择驱动&…

嵌入式AI快速入门课程-K510篇 (第三篇 环境搭建及开发板操作)

第三篇 环境搭建及开发板操作 文章目录 第三篇 环境搭建及开发板操作1.配置VMware使用桥接网卡1.1 vmware设置1.2 虚拟网络编辑器设置 2.安装软件2.2 安装 Windows 软件2.3 使用MobaXterm远程登录Ubuntu2.4 使用FileZilla在Windows和Ubuntu之间传文件2.5编程示例:Ub…

迎接“云+AI”智算时代!生态案例分论坛议程一览 | 2024 龙蜥大会

2024 龙蜥操作系统大会由中国计算机学会开源发展委员会、中关村科学城委员会、海淀区委网信办、中国开源软件推进联盟指导,龙蜥社区主办,阿里云、中兴通讯、Intel、浪潮信息、Arm、中科方德等 24 家理事单位共同承办,主题为“进化重构赴未来”…

海南云亿商务咨询有限公司助力抖音商家破浪前行

在当下这个短视频与直播电商风起云涌的时代,抖音作为头部平台,正以其庞大的用户基数和强大的算法推荐机制,成为众多品牌与商家竞相追逐的新蓝海。而在这片波澜壮阔的海洋中,海南云亿商务咨询有限公司如同一艘稳健的航船&#xff0…

软件测试 —— JMeter 参数化4种方式!

一、JMeter参数化简介 1.JMeter参数化的概念 当使用JMeter进行测试时,测试数据的准备是一项重要的工作。若要求每次迭代的数据不一样时,则需进行参数化,然后从参数化的文件中来读取测试数据。 参数化:是自动化测试脚本的一种常…

【Prettier】代码格式化工具Prettier的使用和配置介绍

前言 前段时间,因为项目的prettier的配置和eslint格式检查有些冲突,在其prettier官网和百度了一些配置相关的资料,在此做一些总结,以备不时之需。 Prettier官网 Prettier Prettier 是一种前端代码格式化工具,支持ja…

从ESG尽职调查、ESG立法与ESG诉讼谈ESG营销(01)

哈佛大学2024年中回顾全球ESG发展近况 作者:哈佛大学 编辑:数字化营销工兵 2024年上半年,环境、社会和治理(ESG)问题以及对方法的不同意见继续成为全球头条新闻。今年年初,公司及其利益相关者在ESG的支持…

AppenTalk | 不止于赛场,巴黎奥运会上的中国AI科技

当地时间8月11日,第33届夏季奥林匹克运动会在巴黎法兰西体育场落下帷幕。本届奥运会,中国体育代表团收获令人振奋的40金27银24铜总计91枚奖牌,其中金牌数更是创下了境外参加奥运会的最佳成绩。 在中国健儿闪耀奥运赛场时,中国AI科…

Transformer系列-4丨DETR模型和代码解析

1 前言 往期的文章中,笔者从网络结构和代码实现角度较为深入地和大家解析了Transformer模型、Vision Transformer模型(ViT)以及BERT模型,其具体的链接如下: 基础Transformer解析 ViT模型与代码解析 BERT模型与代码解…

嵌入式AI快速入门课程-K510篇 (第七篇 系统BSP开发)

第七篇 系统BSP开发 文章目录 第七篇 系统BSP开发1. 嵌入式Linux系统介绍嵌入式Linux系统组成产品形态嵌入式芯片启动流程Linux系统Linux系统框架嵌入式编译环境 2.嵌入式Linux开发准备手册文档开发工具配套硬件工程源码 3.嵌入式Linux开发组成概述编译工具链什么是工具链什么是…

[Linux#43][线程] 死锁 | 同步 | 基于 BlockingQueue 的生产者消费者模型

目录 1. 死锁 解决死锁问题 2. 同步 2.1 条件变量函数 cond 2.2 条件变量的使用: 3.CP 问题--理论 4. 基于 BlockingQueue 的生产者消费者模型 1. 基本概念 2.BlockQueue.hpp 基本设置: 生产关系控制: 消费关系的控制 ⭕思考点 …

公开整理-全国各省AI算力数据集(2000-2024年)

数据来源:本数据来源于,根据显卡HS编码筛选统计后获得时间跨度:2000-2024年数据范围:省级层面数据指标: 由于未发布2015至2016年的数据,因此该年份数据存在缺失。下表仅展示了部分指标及数据 年份 省份…

Mac apache 配置

命令 sudo apachectl -v //查看apache 版本 sudo apachectl -k start //启动apache sudo apachectl -k stop //停止apache sudo apachectl -k restart //重启apache配置 apache 的配置在 /etc/apache2/httpd.conf 默认情况下httpd.conf 为锁定状态,无法编辑 使用…

SAP B1 三大基本表单标准功能介绍-业务伙伴主数据(三)

背景 在 SAP B1 中,科目表、业务伙伴主数据、物料主数据被称为三大基本表单,其中的标准功能是实施项目的基础。本系列文章将逐一介绍三大基本表单各个字段的含义、须填内容、功能等内容。 附上 SAP B1 10.0 的帮助文档:SAP Business One 10…

单片机外部中断+定时器实现红外遥控NEC协议解码

单片机外部中断定时器实现红外遥控NEC协议解码 概述解码过程参考代码 概述 红外(Infrared,IR)遥控,是一种通过调制红外光实现的无线遥控器,常用于家电设备:电视机、机顶盒等等。NEC协议采用PPM(Pulse Position Modulation&#x…

敏感词替换为星号

编写一个函数,接收一个字符串参数,将其中 的敏感词替换为星号,并返回替换后的结果。 def getReplace(s):wordList["阿里巴巴","苹果","亚马逊","京东","字节","脸书"]for word …

月圆之夜梦儿时 贡秋竹唱响游子心声

自今年年初贡秋竹的首支单曲《逐梦》发布以来,其人气和传唱度便一直屡创新高,口碑上佳表现良好,网友们纷纷隔空喊话贡秋竹再发新作。时至今日,久经打磨的贡秋竹全新力作《低头思故乡》在千呼万唤中终于震撼首发! 贡秋竹…

500以内开放式耳机哪款好?五款高性价比开放式耳机推荐

现在很多人会利用休闲时间进行锻炼,增强体质,在锻炼之前很多人会先入手一些运动设备,像慢跑鞋,还有臂环,运动手表等~当然运动耳机肯定也不能少,边运动边听音乐真的是一大享受!但是哪种耳机比较适…