21- 神经网络模型_超参数搜索 (TensorFlow系列) (深度学习)

news2025/1/17 4:57:37

知识要点

  • fetch_california_housing:加利福尼亚的房价数据,总计20640个样本,每个样本8个属性表示,以及房价作为target

  • 超参数搜索的方式: 网格搜索, 随机搜索, 遗传算法搜索, 启发式搜索

  • 函数式添加神经网络:

    • model.add(keras.layers.Dense(layer_size, activation = 'relu'))

    • model.compile(loss = 'mse', optimizer = optimizer)    # optimizer = keras.optimizers.SGD (learning_rate)

    • sklearn_model = KerasRegressor(build_fn = build_model)

from tensorflow.keras.wrappers.scikit_learn import KerasRegressor   # 回归神经网络
# 搜索最佳学习率
def build_model(hidden_layers = 1, layer_size = 30, learning_rate = 3e-3):
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(layer_size, activation = 'relu', input_shape = x_train.shape[1:]))
    for _ in range(hidden_layers - 1):
        model.add(keras.layers.Dense(layer_size, activation = 'relu'))
    model.add(keras.layers.Dense(1))
    optimizer = keras.optimizers.SGD(learning_rate)
    model.compile(loss = 'mse', optimizer = optimizer)
    # model.summary()
    return model
sklearn_model = KerasRegressor(build_fn = build_model)
  • callbacks = [keras.callbacks.EarlyStopping(patience = 5, min_delta = 1e-3)]  # 回调函数设置

  • gv = GridSearchCV(sklearn_model, param_grid = params, n_jobs = 1, cv= 5,verbose = 1)  # 找最佳参数

  • gv.fit(x_train_scaled, y_train)


1 导包

from tensorflow import keras
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
cpu=tf.config.list_physical_devices("CPU")
tf.config.set_visible_devices(cpu)
print(tf.config.list_logical_devices())

2 导入数据

from sklearn.model_selection import train_test_split
from sklearn.datasets import fetch_california_housing

housing = fetch_california_housing()
x_train_all, x_test, y_train_all, y_test = train_test_split(housing.data,
                                                            housing.target,
                                                            random_state= 7)
x_train, x_valid, y_train, y_valid = train_test_split(x_train_all, y_train_all,
                                                      random_state = 11)

3 标准化处理数据

from sklearn.preprocessing import StandardScaler, MinMaxScaler

scaler =StandardScaler()
x_train_scaled = scaler.fit_transform(x_train)
x_valid_scaled = scaler.transform(x_valid)
x_test_scaled = scaler.transform(x_test)

4 函数式定义模型

from tensorflow.keras.wrappers.scikit_learn import KerasRegressor   # 回归神经网络
# 搜索最佳学习率
def build_model(hidden_layers = 1, layer_size = 30, learning_rate = 3e-3):
    model = keras.models.Sequential()
    model.add(keras.layers.Dense(layer_size, activation = 'relu', input_shape = x_train.shape[1:]))
    for _ in range(hidden_layers - 1):
        model.add(keras.layers.Dense(layer_size, activation = 'relu'))
    model.add(keras.layers.Dense(1))
    optimizer = keras.optimizers.SGD(learning_rate)
    model.compile(loss = 'mse', optimizer = optimizer)
    # model.summary()
    return model
sklearn_model = KerasRegressor(build_fn = build_model)

 

5 模型训练

callbacks = [keras.callbacks.EarlyStopping(patience = 5, min_delta = 1e-3)]
history = sklearn_model.fit(x_train_scaled, y_train, epochs = 10,
                            validation_data = (x_valid_scaled, y_valid), callbacks = callbacks)

 6 超参数搜索

超参数搜索的方式:

  • 网格搜索

    • 定义n维方格

    • 每个方格对应一组超参数

    • 一组一组参数尝试

  • 随机搜索

  • 遗传算法搜索

    • 对自然界的模拟

    • A: 初始化候选参数集合 --> 训练---> 得到模型指标作为生存概率

    • B: 选择 --> 交叉--> 变异 --> 产生下一代集合

    • C: 重新到A, 循环.

  • 启发式搜索

    • 研究热点-- AutoML的一部分

    • 使用循环神经网络来生成参数

    • 使用强化学习来进行反馈, 使用模型来训练生成参数.

# 使用sklearn 的网格搜索, 或者随机搜索
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV

params = {
    'learning_rate' : [1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2],
    'hidden_layers': [2, 3, 4, 5], 
    'layer_size': [20, 60, 100]}

gv = GridSearchCV(sklearn_model, param_grid = params, n_jobs = 1, cv= 5,verbose = 1)
gv.fit(x_train_scaled, y_train)
  • 输出最佳参数
# 最佳得分
print(gv.best_score_)    # -0.47164334654808043
# 最佳参数
print(gv.best_params_)  # {'hidden_layers': 5,'layer_size': 100,'learning_rate':0.01}
# 最佳模型
print(gv.estimator)
'''<keras.wrappers.scikit_learn.KerasRegressor object at 0x0000025F5BB12220>'''
gv.score

7 最佳参数建模

model = keras.models.Sequential()
model.add(keras.layers.Dense(100, activation = 'relu', input_shape = x_train.shape[1:]))
for _ in range(4):
    model.add(keras.layers.Dense(100, activation = 'relu'))
model.add(keras.layers.Dense(1))
optimizer = keras.optimizers.SGD(0.01)
model.compile(loss = 'mse', optimizer = optimizer)
model.summary()

callbacks = [keras.callbacks.EarlyStopping(patience = 5, min_delta = 1e-3)]
history = model.fit(x_train_scaled, y_train, epochs = 10,
                    validation_data = (x_valid_scaled, y_valid), callbacks = callbacks)

 8 手动实现超参数搜索

  • 根据参数进行多次模型的训练, 然后记录 loss
# 搜索最佳学习率
learning_rates = [1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2]
histories = []
for lr in learning_rates:
    model = keras.models.Sequential([
        keras.layers.Dense(30, activation = 'relu', input_shape = x_train.shape[1:]),
        keras.layers.Dense(1)
    ])
    optimizer = keras.optimizers.SGD(lr)
    model.compile(loss = 'mse', optimizer = optimizer, metrics = ['mse'])
    
    callbacks = [keras.callbacks.EarlyStopping(patience = 5, min_delta = 1e-2)]
    history = model.fit(x_train_scaled, y_train, 
                        validation_data = (x_valid_scaled, y_valid), 
                        epochs = 100, 
                        callbacks = callbacks)
    histories.append(history)

 

# 画图
import pandas as pd
def plot_learning_curves(history):
    pd.DataFrame(history.history).plot(figsize = (8, 5))
    plt.grid(True)
    plt.gca().set_ylim(0, 1)
    plt.show()

for lr, history in zip(learning_rates, histories): 
    print(lr)
    plot_learning_curves(history)   

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/378776.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python可视化界面编程入门

Python可视化界面编程入门具体实现代码如所示&#xff1a; &#xff08;1&#xff09;普通可视化界面编程代码入门&#xff1a; import sys from PyQt5.QtWidgets import QWidget,QApplication #导入两个类来进行程序界面编程if __name__"__main__":#创建一个Appl…

探索ChatGPT背后的网络基础设施

ChatGPT是OpenAI公司开发的一款聊天机器人应用&#xff0c;自2022年11月推出以来以迅雷不及掩耳盗铃之势火爆全球。ChatGPT不仅可以模仿人类对话&#xff0c;还可以创建音乐、电视剧、童话故事和学生论文&#xff0c;甚至是编写和调试计算机程序。 截至2023年1月&#xff0c;C…

如何打造自己的小程序生态?

2021 年全网小程序数量就已超 700 万&#xff0c;从微信开始&#xff0c;到其他各大平台&#xff0c;如抖音、支付宝&#xff0c;小程序发展迅猛&#xff0c;2023年小程序仍有着巨大的发展潜力。 现在。人们逐渐发现&#xff0c;日常的生活、出行、购物各个方面都越来越离不开…

CAN工具-VSpy(ValueCAN) - Panel面板

在介绍CANoe工具的时候&#xff0c;有介绍过Panel面板的使用&#xff0c;同样&#xff0c;在VSpy软件工具中&#xff0c;也有同类型的工具可供使用 - Graphical Panels&#xff0c;同样也能提供一个控制面板&#xff0c;然后我们通过连接信号实现不同的控件&#xff0c;已达到我…

极验4参数分析

目标链接 aHR0cHM6Ly9ndDQuZ2VldGVzdC5jb20v接口分析 开发者人员工具进行抓包&#xff0c;刷新页面&#xff0c;抓到了一个名为 load?captcha_idxxx 的包&#xff0c;Query String Parameters 包含了一些参数 captcha_id&#xff1a;验证码 id&#xff0c;固定值&#xff0c…

如何使用AzureGraph通过Microsoft Graph收集Azure活动目录信息

关于AzureGraph AzureGraph是一款针对Azure活动目录的信息收集工具&#xff0c;该工具基于Microsoft Graph实现其功能。多亏了Microsoft Graph技术&#xff0c;AzureGraph才能从Azure活动目录获取各种信息&#xff0c;如用户、设备、应用程序、域等。 此应用程序允许我们通过…

一次性搞定 `SHOW SLAVE STATUS` 的解读

一次性搞定 SHOW SLAVE STATUS 的解读 解析日志文件的位置 诚然, GTID(全局事务标识符)已经在 MySQL 5.6中得到支持, 此外,还可以通过 Tungsten replicator 软件来实现(2009年以后一直有谷歌在维护,不是吗?)。 但有一部分人还在使用MySQL 5.5的标准副本方式, 那么这些二进制日…

20道经典自动化测试面试题

概述 觉得自动化测试很难&#xff1f; 是的&#xff0c;它确实不简单。但是学会它&#xff0c;工资高啊&#xff01; 担心面试的时候被问到自动化测试&#xff1f; 嗯&#xff0c;你担心的没错&#xff01;确实会被经常问到&#xff01; 现在应聘软件测试工程师的岗位&…

前端经典react面试题及答案

为什么 React 元素有一个 $$typeof 属性 目的是为了防止 XSS 攻击。因为 Synbol 无法被序列化&#xff0c;所以 React 可以通过有没有 $$typeof 属性来断出当前的 element 对象是从数据库来的还是自己生成的。 如果没有 $$typeof 这个属性&#xff0c;react 会拒绝处理该元素。…

docker搭建redis集群模式

目录docker 安装redis1.创建redis.conf开启redis验证(开启密码)允许redis外地连接后台启动开启redis持久化2.启动redis容器3.进入容器redis集群3主3从1.新建6个redis容器2.构建主从关系3.查询集群信息4.主从扩容5.主从缩容docker 安装redis 1.创建redis.conf 开启redis验证(开…

第四阶段-12关于Spring Security框架,RBAC,密码加密原则

关于csmall-passport项目 此项目主要用于实现“管理员”账号的后台管理功能&#xff0c;主要实现&#xff1a; 管理员登录添加管理员删除管理员显示管理员列表启用 / 禁用管理员 关于RBAC RBAC&#xff1a;Role-Based Access Control&#xff0c;基于角色的访问控制 在涉及…

Feign Ribbon Hystrix 三者关系

在微服务架构的应用中&#xff0c; Feign、Hystrix&#xff0c;Ribbon 三者都是必不可少的&#xff0c;可以说已经成为铁三角。 Feign 介绍 Feign 是一款Java语言编写的 HttpClient 绑定器&#xff0c;在 Spring Cloud 微服务中用于实现微服务之间的声明式调用。Feign 可以定…

IIC子系统

文章目录引言一、I2C 总线驱动框架二、I2C驱动框图(重点)三、I2C 子系统软件框架3.1 I2C子系统的4个关键结构体3.2 I2C总线与平台总线的结合3.3 在设备树信息添加i2c从设备3.4 新增加i2c从设备四、i2c driver驱动的编写4.1 陀螺仪和加速度工作原理4.2 mpu6050的寄存器信息和设置…

Synchronized的锁升级过程

Synchronized的锁升级过程 synchronized锁升级过程&#xff1a;在synchronized中引入了偏向锁、轻量级锁、重量级锁之后&#xff0c;当前具体使用的是synchronzed中的那种类型锁&#xff0c;是根据线程竞争激烈程度来决定的。 偏向锁&#xff1a;在锁对象的对象头中记录一下当…

中间件之Kafka实用篇

目录标题一、一些定义&#xff08;一&#xff09;设计kafka的初衷&#xff08;二&#xff09;消息的持久化&#xff08;三&#xff09;sendfile 技术&#xff08;零拷贝&#xff09;二、获取kafka三、卡夫卡客户端工具四、kafka核心API&#xff08;功能&#xff09;五、spring …

阶段十:总结专题(第三章:虚拟机篇)

阶段十&#xff1a;总结专题&#xff08;第三章&#xff1a;虚拟机篇&#xff09;Day-第三章&#xff1a;虚拟机篇1. JVM 内存结构2. JVM 内存参数3. JVM 垃圾回收4. 内存溢出5. 类加载6. 四种引用7. finalizeDay-第三章&#xff1a;虚拟机篇 1. JVM 内存结构 要求 掌握 JVM…

Spring Cloud Alibaba全家桶(三)——微服务负载均衡器Ribbon与LoadBalancer

前言 本文为 微服务负载均衡器Ribbon与LoadBalancer 相关知识&#xff0c;下边将对什么是Ribbon&#xff08;包括&#xff1a;客户端的负载均衡、服务端的负载均衡、常见负载均衡算法&#xff09;&#xff0c;Nacos使用Ribbon&#xff0c;Ribbon内核原理&#xff08;包括&#…

Qt::QOpenGLWidget 渲染天空壳

在qt窗口中嵌入opengl渲染天空壳和各种立方体一 学前知识天空壳的渲染学前小知识1 立方体贴图 天空壳的渲染就是利用立方体贴图来实现渲染流程2 基础光照 光照模型3 opengl帧缓冲 如何自定义帧缓冲实现后期特效4 glsl常见的shader内置函数 glsl编程常用的内置函数二 shader代码…

部署运行ai智障写作记录【ChatRWKV】

文章目录前言一、环境安装1.python环境&#xff1a;Python 3.10。2.安装一些 pip 库numpy 、tokenizers 、prompt_toolkit3.安装pytorch 1.13.1CUDA 11.7二、运行记录1、下载代码2、下载训练参数3、编辑代码运行总结前言 看到知乎一篇教程&#xff0c; 大佬自己弄得ai小说续写…

AI环境搭建步骤(Windows环境)

1. 安装好Anaconda3版本(1) 安装链接&#xff1a;https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/?CM&OD本文使用Anaconda3下载链接&#xff1a;Anaconda5(2) 注意安装anaconda时一定要把环境变量加入windows环境中。要没有勾选&#xff0c;安装完后还有手动加入…