【优化算法】使用遗传算法优化MLP神经网络参数(TensorFlow2)

news2024/12/26 23:30:58

文章目录

  • 任务
  • 查看当前的准确率情况
  • 使用遗传算法进行优化
  • 完整代码

任务

使用启发式优化算法遗传算法对多层感知机中中间层神经个数进行优化,以提高模型的准确率。

待优化的模型:
基于TensorFlow2实现的Mnist手写数字识别多层感知机MLP

# MLP手写数字识别模型,待优化的参数为layer1、layer2
model = tf.keras.Sequential([
   tf.keras.layers.Flatten(input_shape=(28, 28, 1)), 
   tf.keras.layers.Dense(layer1, activation='relu'),
   tf.keras.layers.Dense(layer2, activation='relu'),
   tf.keras.layers.Dense(10,activation='softmax') # 对应0-9这10个数字
])

查看当前的准确率情况

设置随机树种子,避免相同结构的神经网络其结果不同的影响。

random_seed = 10
np.random.seed(random_seed)
tf.random.set_seed(random_seed)
random.seed(random_seed)
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)), 
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(10,activation='softmax') # 对应0-9这10个数字
])
optimizer = tf.keras.optimizers.Adam()
loss_func = tf.keras.losses.SparseCategoricalCrossentropy()
model.compile(optimizer=optimizer,loss=loss_func,metrics=['acc'])
history = model.fit(dataset, validation_data=test_dataset, epochs=5, verbose=1)
score = model.evaluate(test_dataset)
print("测试集准确率:",score) # 输出 [损失率,准确率]
Epoch 1/5
235/235 [==============================] - 1s 5ms/step - loss: 0.4663 - acc: 0.8703 - val_loss: 0.2173 - val_acc: 0.9361
Epoch 2/5
235/235 [==============================] - 1s 4ms/step - loss: 0.1882 - acc: 0.9465 - val_loss: 0.1604 - val_acc: 0.9521
Epoch 3/5
235/235 [==============================] - 1s 4ms/step - loss: 0.1362 - acc: 0.9608 - val_loss: 0.1278 - val_acc: 0.9629
Epoch 4/5
235/235 [==============================] - 1s 4ms/step - loss: 0.1084 - acc: 0.9689 - val_loss: 0.1086 - val_acc: 0.9681
Epoch 5/5
235/235 [==============================] - 1s 4ms/step - loss: 0.0878 - acc: 0.9740 - val_loss: 0.1077 - val_acc: 0.9675
40/40 [==============================] - 0s 2ms/step - loss: 0.1077 - acc: 0.9675
测试集准确率: [0.10773944109678268, 0.9674999713897705]

准确率为96.7%

使用遗传算法进行优化

使用scikit-opt提供的遗传算法库进行优化。(pip install scikit-opt
官方文档:https://scikit-opt.github.io/scikit-opt/#/zh/README

# 遗传算法调用
ga = GA(func=loss_func, n_dim=2, size_pop=4, max_iter=3, prob_mut=0.15, lb=[10, 10], ub=[256, 256], precision=1)

优化目标函数loss_func:1 - 模型准确率(求优化目标函数的最小值
待优化的参数数量n_dim:2
种群数量size_pop:4
迭代次数max_iter:3
变异概率prob_mut:0.15
待优化两个参数的取值范围均为10-256
精确度precision:1

# 定义多层感知机模型函数
def mlp_model(layer1, layer2):
    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28, 1)), 
        tf.keras.layers.Dense(layer1, activation='relu'),
        tf.keras.layers.Dense(layer2, activation='relu'),
        tf.keras.layers.Dense(10,activation='softmax') # 对应0-9这10个数字
    ])
    optimizer = tf.keras.optimizers.Adam()
    loss_func = tf.keras.losses.SparseCategoricalCrossentropy()
    model.compile(optimizer=optimizer,loss=loss_func,metrics=['acc'])
    return model

# 定义损失函数
def loss_func(params):
    random_seed = 10
    np.random.seed(random_seed)
    tf.random.set_seed(random_seed)
    random.seed(random_seed)
    layer1 = int(params[0])
    layer2 = int(params[1])
    model = mlp_model(layer1, layer2)
    history = model.fit(dataset, validation_data=test_dataset, epochs=5, verbose=0)
    return 1 - history.history['val_acc'][-1]


ga = GA(func=loss_func, n_dim=2, size_pop=4, max_iter=3, prob_mut=0.15, lb=[10, 10], ub=[256, 256], precision=1)
bestx, besty = ga.run()
print('bestx:', bestx, '\n', 'besty:', besty)

结果:

bestx: [165. 155.] 
besty: [0.02310002]

通过迭代,找到layer1、layer2的最好值为165、155,此时准确率为1-0.0231=0.9769。

查看每轮迭代的损失函数值图(1-准确率):

Y_history = pd.DataFrame(ga.all_history_Y)
fig, ax = plt.subplots(2, 1)
ax[0].plot(Y_history.index, Y_history.values, '.', color='red')
Y_history.min(axis=1).cummin().plot(kind='line')
plt.show()

在这里插入图片描述
上图为三次迭代种群中,种群每个个体的损失函数值(每个种群4个个体)。
下图为三次迭代种群中,种群个体中的最佳损失函数值。

可以看出,通过遗传算法,其准确率有一定的提升。

增加种群数量及迭代次数效果可更加明显。

完整代码

# python3.6.9
import tensorflow as tf # 2.3
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random
from sko.GA import GA

# 数据导入,获取训练集和测试集
(train_image, train_labels), (test_image, test_labels) = tf.keras.datasets.mnist.load_data()

# 增加通道维度
train_image = tf.expand_dims(train_image, -1)
test_image = tf.expand_dims(test_image, -1)

# 归一化 类型转换
train_image = tf.cast(train_image/255, tf.float32)
test_image = tf.cast(test_image/255, tf.float32)
train_labels = tf.cast(train_labels, tf.int64)
test_labels = tf.cast(test_labels, tf.int64)

# 创建Dataset
dataset = tf.data.Dataset.from_tensor_slices((train_image, train_labels)).shuffle(60000).batch(256)
test_dataset = tf.data.Dataset.from_tensor_slices((test_image, test_labels)).batch(256)

# 定义多层感知机模型函数
def mlp_model(layer1, layer2):
    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28, 1)), 
        tf.keras.layers.Dense(layer1, activation='relu'),
        tf.keras.layers.Dense(layer2, activation='relu'),
        tf.keras.layers.Dense(10,activation='softmax') # 对应0-9这10个数字
    ])
    optimizer = tf.keras.optimizers.Adam()
    loss_func = tf.keras.losses.SparseCategoricalCrossentropy()
    model.compile(optimizer=optimizer,loss=loss_func,metrics=['acc'])
    return model

# 定义损失函数
def loss_func(params):
    random_seed = 10
    np.random.seed(random_seed)
    tf.random.set_seed(random_seed)
    random.seed(random_seed)
    layer1 = int(params[0])
    layer2 = int(params[1])
    model = mlp_model(layer1, layer2)
    history = model.fit(dataset, validation_data=test_dataset, epochs=5, verbose=0)
    return 1 - history.history['val_acc'][-1]


ga = GA(func=loss_func, n_dim=2, size_pop=4, max_iter=3, prob_mut=0.15, lb=[10, 10], ub=[256, 256], precision=1)
bestx, besty = ga.run()
print('bestx:', bestx, '\n', 'besty:', besty)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/421874.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java支付SDK接口远程调试 - 支付宝沙箱环境【公网地址调试】

文章目录1.测试环境2.本地配置3. 内网穿透3.1 下载安装cpolar内网穿透3.2 创建隧道4. 测试公网访问5. 配置固定二级子域名5.1 保留一个二级子域名5.2 配置二级子域名6. 使用固定二级子域名进行访问转发自CSDN远程穿透的文章:Java支付宝沙箱环境支付,SDK接…

Linux命令·traceroute

通过traceroute我们可以知道信息从你的计算机到互联网另一端的主机是走的什么路径。当然每次数据包由某一同样的出发点(source)到达某一同样的目的地(destination)走的路径可能会不一样,但基本上来说大部分时候所走的路由是相同的。linux系统…

移动端项目开发总结(一)

移动端项目开发总结(一) 前阵子做租赁项目,风风火火的上线,趁现在还没忘,把用到的东西整理以下,算是对于这个项目的回顾吧。 特效一 : 移动端适配 需求 移动端适配,采用rem单位。…

深入理解Java虚拟机——Java内存区域

1.前言 Java内存区域也叫运行时数据区域,要记得把Java内存模型(JMM区分开来)。 根据线程是否共享可以把运行时数据区如上图所分。 线程共享 堆内存方法区 线程私有 栈内存 本地方法栈虚拟机栈 程序计数器 接下来,将逐个介绍…

什么是文件传输协议,文件传输协议又是怎么工作的

文件传输协议FTP是一种仍在使用的协议,在上载和下载文件时仍然比较流行,通常是那些太大的文件,需要花费很长时间才能通过常规电子邮件程序作为附件下载进行传输。 从技术上讲,它是“文件传输实用程序”,是许多TCP / I…

腾讯云4核8G12M轻量服务器配置性能评测

腾讯云轻量4核8G12M服务器,之前是4核8G10M配置,现在公网带宽和月流量包整体升级,12M公网带宽下载速度可达1536KB/秒,系统盘为180GB SSD盘,每月2000GB免费流量,腾讯云百科来详细说下4核8G12M轻量应用服务器配…

碳化硅材料在功率半导体中的优劣

开关电源工作频率的提高受到开关损耗的制约 开关电源的工作频率是指开关变换器操作的频率。在开关电源中,一个开关变换器被用来将直流(DC)能源转换为可用于电子设备的交流(AC)能源。开关变换器的基本原理是通过对开关…

3.4 函数的单调性和曲线的凹凸性

学习目标: 如果我要学习函数的单调性和曲线的凹凸性,我会采取以下几个步骤: 理解概念和定义:首先,我会学习单调性和凹凸性的定义和概念。单调性是指函数的增减性质,可以分为单调递增和单调递减&#xff1b…

Python使用PyQt5实现指定窗口置顶

文章目录前言一、网上找到的代码二、尝试与借鉴后的代码——加入PyQt界面1.引入库2.主代码3.完整主代码4.UI界面代码总结前言 工作中,同事随口提了一句:要是能让WPS窗口置顶就好了,老是将窗口切换来切换去的太麻烦了。 然后,这个…

docker-compose 安装nginx php mysql phpadmin

一 摘要 本文主要介绍基于docker docker-compose 安装 lnmp 三件套,以及用phpmysadmin 验证下部署可正确。 二 环境信息 2.1 操作系统 [root2023001 ~]# cat /etc/centos-release CentOS Linux release 7.9.2009 (Core) [root2023001 ~]#2.2 docker [root20230…

【opencv】图像数字化——认识OpenCV中的Mat类( 7 访问多通道Mat对象中的值)

7 访问多通道Mat对象中的值 7.1使用成员函数at() #include <opencv2/core/core.hpp> #include<iostream> using namespace std; using namespace cv; int main() {Mat mm (Mat_<Vec3f>(2, 2) << Vec3f(1, 11, 21), Vec3f(2, 12, 32), Vec3f(3, …

C++【深入理解多态】

文章目录一、多态概念与实现&#xff08;1&#xff09;多态的概念&#xff08;2&#xff09;怎么构成多态&#xff08;3&#xff09;虚函数重写的2个例外&#xff08;4&#xff09;经典剖析巩固知识点&#xff08;5&#xff09; override 和 final&#xff08;6&#xff09;小总…

YOLO算法改进指南【初阶改进篇】:2.改进DIoU-NMS,SIoU-NMS,EIoU-NMS,CIoU-NMS,GIoU-NMS

非极大值抑制(Non-maximum Suppression (NMS))的作用简单说就是模型检测出了很多框,我应该留哪些。 本篇将演示如何修改:NMS、Merge-NMS、Soft-NMS、CIoU-NMS、DIoU-NMS、GIoU-NMS、EIoU-NMS、SIoU-NMS 1. NMS过程 NMS过程 For a prediction bounding box B, the model c…

基于JDK11从源码角度剖析可重入锁ReentrantLock的获取锁和解锁

ReentrantLock是可重入的独占锁&#xff0c;同时只能有一个线程可以获取该锁&#xff0c;其他获取该锁的线程会被阻塞而被放入该锁的AQS阻塞队列里面。 ReentrantLock是JUC包提供的显式锁的一个基础实现类&#xff0c;实现了Lock接口。我们先来看下ReentrantLock的类图&#x…

SpringBoot WebSocket服务端创建

引入maven <!--websocket--><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-websocket</artifactId></dependency>新建WebSocket配置文件 import org.springframework.context.annotatio…

【蓝桥杯嵌入式】第十四届蓝桥杯嵌入式省赛(第一场)客观题及详细题解

题1 解析  编码器&#xff0c;具有编码功能的逻辑电路&#xff0c;能将每一个编码输入信号变换为不同的二进制的代码输出&#xff0c;是一个组合逻辑电路。 答案 ABC 题2 解析   减法计数器的计数值到0时&#xff0c;会产生一个重装载值&#xff0c;此处重载后就会变成111…

改进YOLO系列:CVPR2023最新 PConv |提供 YOLOv5 / YOLOv8 模型 YAML 文件

论文链接:https://arxiv.org/pdf/2303.03667v2.pdf 一、论文介绍 为了设计快速神经网络,许多工作都集中在减少浮点运算(FLOPs)的数量上。然而,作者观察到FLOPs的这种减少不一定会带来延迟的类似程度的减少。这主要源于每秒低浮点运算(FLOPS)效率低下。 为了实现更快的…

buildSrc + gradle插件:多项目共享gradle依赖管理

自定义gradle 插件&#xff0c;配合 buildSrc 形式的组件库版本管理&#xff0c; 用于实现多 project 项目共享一套版本管理信息 前言 随着组件化越来越常见&#xff0c;module数量越来越多&#xff0c;依赖管理的混乱问题大家想必是都遇到过甚至正在经历着。 对于依赖管理的…

iOS - 接入 Live2D

1.安装 Cmake 1.1 从官方下载 https://cmake.org/download/ 下载成功以后,在终端输入 sudo "/Applications/CMake.app/Contents/bin/cmake-gui" --install校验是否成功 cmake --version1.2 从 Homebrew 安装 (这个方法没有成功) brew install cmake如果提示 co…

简单的配置Sawgger+knife4j完成API测试功能

目的&#xff1a;减少postman的使用&#xff0c;以及生成对应的接口文档 1、添加依赖 基于自身spring boot 版本2.7.X 我选择的是&#xff1a; <dependency><groupId>io.springfox</groupId><artifactId>springfox-boot-starter</artifactId>…