Keras人工智能神经网络 Regressor 回归 神经网络搭建

news2024/11/24 2:57:16

前期分享了使用tensorflow来进行神经网络的回归,tensorflow构建神经网络

本期我们来使用Keras来搭建一个简单的神经网络

Keras神经网络可以用来模拟回归问题 (regression),例如给下面一组数据,用一条线来对数据进行拟合,并可以预测新输入 x 的输出值,也就是线性回归问题(Y=w*X+b)

1创建随机数据

import numpy as np
np.random.seed(1337) # 随机数
from keras.models import Sequential # models.Sequential,用来一层一层的建立神经层
from keras.layers import Dense # layers.Dense 个神经层是全连接层
import matplotlib.pyplot as plt # 可视化模块

# 创建数据

X = np.linspace(-1, 1, 300)
np.random.shuffle(X) # 数据随机化
Y = 0.5 * X + 2 + np.random.normal(0, 0.05, (300, ))
# 显示创建的数据
plt.scatter(X, Y)
plt.show()
X_train, Y_train = X[:260], Y[:260] # train 前 260 data points
X_test, Y_test = X[260:], Y[260:] # test 后 40 data points

数据创建完后,我们查看一下生成的随机数据

建立神经网络模型

Sequential 建立 model

model.add 添加神经层,添加的是 Dense 全连接神经层。

参数有两个,一个是输入数据和输出数据的维度,本代码的例子中 x 和 y 是一维的。

如果需要添加下一个神经层的时候,不用再定义输入的纬度,因为它默认就把前一层的输出作为当前层的输入(这一点大大简化了神经网络的搭建过程)

model = Sequential()
model.add(Dense(output_dim=1, input_dim=1))
model.compile(loss='mse', optimizer='sgd') 
# 参数中,误差函数用的是 mse 均方误差;优化器用的是 sgd 随机梯度下降法

以上3行代码便是keras神经网络的搭建过程(比tensorflow减少了很多),构建完成神经网络后,开始训练

keras神经网络训练

for step in range(401):
cost = model.train_on_batch(X_train, Y_train)
if step % 100 == 0:
print('train cost: ', cost)
'''
train cost: 4.0291815
train cost: 0.076484405
train cost: 0.004810586
train cost: 0.0029513359
train cost: 0.002760151
'''

训练400步,每100步打印一下训练的结果,使用model.train_on_batch 一批一批的训练 X_train, Y_train

keras 神经网络的验证

使用 model.evaluate,输入测试集的x和y, 输出 cost,weights 和 biases。其中 weights 和 biases 是取在模型的第一层 model.layers[0] 学习到的参数

从训练的结果看出, weights 比较接近0.5,bias 接近 2,符合我们输入的模型

cost = model.evaluate(X_test, Y_test, batch_size=40)
print('test cost:', cost)
W, b = model.layers[0].get_weights()
print('Weights=', W, '\nbiases=', b)
40/40 [==============================] - 0s 450us/step
test cost: 0.003141355235129595
Weights= [[0.51579475]]
biases= [1.9971616]

可视化模型

验证完成后,我们可以可视化模型,看看神经网络预测的数据与实际数据的差异

Y_pred = model.predict(X_test)
plt.scatter(X_test, Y_test)
plt.plot(X_test, Y_pred)
plt.show()

以上便是keras构建回归神经网络的步骤,我们下期分享一下如何使用keras构建分类模型的神经网络

 动画详解transformer  

更多transformer,VIT,swin tranformer
参考头条号:人工智能研究所
v号:启示AI科技
微信中复制如下链接,打开,免费体验chatgpt
 
https://wx2.expostar.cn/qz/pages/manor/index?id=1137&share_from_id=79482&sid=24

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1171228.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

GPT学习笔记

百度的文心一言 阿里的通义千问 通过GPT能力,提升用户体验和产品力 GPT的出现是AI的iPhone时刻。2007年1月9日,第一代iPhone发布,开启移动互联网时代。新一轮的产业革命。 GPT模型发展时间线: Copilot - 副驾驶 应用&#xf…

Angular-07:组件生命周期

三个阶段: ① 挂载阶段1.1 constructor1.2 ngOnInit ② 更新阶段2.1 ngOnChanges2.2 ngAfterViewInit2.3 ngAfterContentInit2.4 ngDoCheck ③ 卸载阶段3.1 onOnDestroy ④ 在组件中添加所有方法并打印 该表按照执行顺序编写 编号函数名实现名说明1constructorcons…

基于单片机的智能感应监控系统的设计

收藏和点赞,您的关注是我创作的动力 文章目录 概要 一、系统分析2.1 整个控制系统的设计要求2.2 总体设计方案 二、系统硬件电路设计3.1 硬件电路介绍3.2 控制电路分析3.2.1 复位电路 三 软件设计原理图 四、 结论五、 文章目录 概要 因为人们在生活中对安全防范的…

【JAVA学习笔记】61 - 线程入门、常用方法、同步机制,以及本章作业(难点)

项目代码 https://github.com/yinhai1114/Java_Learning_Code/tree/main/IDEA_Chapter17/src/com/yinhai 线程 一、线程相关概念 1.程序 是为完成特定任务、用某种语言编写的一组指令的集合。简单的说:就是我们写的代码 2.进程 1)进程是指运行中的程序&#x…

劳易测扫码条码分段读取实现方法

添加如下3个功能块:M10,M13和M27 设置BCL参数:Code type 1 为Code128 参数:Mode为Range 参数:Number Of digits 1 为条码最小长度 Number Of digits 2 为条码最大长度。 设置M10:Mode(With …

嵌入式系统的元素

注意:关于嵌入式系统的元素这一块儿内容,定义太多了。例如:吉姆莱丁 著,陈会翔 译,由清华大学出版社出版的《构建高性能嵌入式系统》中提到:嵌入式系统通常由电源、时基、数字处理、内存、软件和固件、专用…

JavaScript执行上下文和调用栈

上节课我们已经说过了,JavaScript的代码执行是发生在js引擎中的调用堆栈的,但是具体是如何运行的,我们来详细剖析一下 如何执行上下文 执行上下文: 执行上下文是指在JavaScript中代码被执行时所创建的环境。它包含了变量、函数、…

京东大数据平台-第三方京东平台数据查询分析软件系统

对于电商商家来说,做好电商数据分析是电商运营中的重要一环,且能为电商商家带来诸多好处,例如: 1、提高销售额:通过数据分析可以更好地把握消费者的购买行为,从而更好地推出营销活动,提高销售额…

7.SpringBoot集成Mybats-plus且安装代码生成插件

一、背景 项目需要集成Mybatis-plus用作服务的ORM。 二、实现 2.1 pom.xml引入 <dependency><groupId>com.baomidou</groupId><artifactId>mybatis-plus-boot-starter</artifactId><version>${mybatis-plus.version}</version>&l…

QT 实现解密m3u8文件

文章目录 概要如何解密M3U8文件呢实现思路和代码序列图网络请求解密 结论 概要 视频文件很多已M3U8文件格式来提供&#xff0c;先复习下什么是M3U8文件&#xff01;用QT的 mutimedia框架来播放视频时&#xff0c;有的视频加载慢&#xff0c;有的视频加载快&#xff0c;为啥&am…

python 深度学习 解决遇到的报错问题9

本篇继python 深度学习 解决遇到的报错问题8-CSDN博客 目录 一、can only concatenate str (not "int") to str 二、cant convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, in…

Python基础入门例程32-NP32 牛牛的加减器(运算符)

最近的博文&#xff1a; Python基础入门例程31-NP31 团队分组&#xff08;列表&#xff09;-CSDN博客 Python基础入门例程30-NP30 用列表实现队列&#xff08;列表&#xff09;-CSDN博客 Python基础入门例程29-NP29 用列表实现栈&#xff08;列表&#xff09;-CSDN博客 目录…

牛客网刷题-(11)

&#x1f308;个人主页: Aileen_0v0&#x1f525;系列专栏:PYTHON学习系列专栏&#x1f4ab;"没有罗马,那就自己创造罗马~" 目录 (1)输出1-100的所有奇数 (2)计算输入6个数字中正数的个数 (3)递增序列 (4)PUM (1)输出1-100的所有奇数 #输出1-100的所有奇数 x…

使用趋动云部署ChatGLM3-6B模型

使用趋动云部署ChatGLM3-6B模型 1 创建项目2 配置环境 修改代码3 运行代码 1 创建项目 创建项目 进入项目 -> 运行代码 -> 选择资源&#xff08;B1.large&#xff09; 2 配置环境 修改代码 等待开发者工具加载完成 -> 点击 JupyterLab 进入开发环境 打开 termin…

2023.11.4 Idea 配置国内 Maven 源

目录 配置国内 Maven 源 重新下载 jar 包 配置国内 Maven 源 <mirror><id>alimaven</id><name>aliyun maven</name><url>http://maven.aliyun.com/nexus/content/groups/public/</url><mirrorOf>central</mirrorOf> …

velero 集群备份实战

文章目录 velero 集群备份实战velero 架构velero 安装备份mysql集群备份命令查看备份列表 如何恢复&#xff1f;如何卸载&#xff1f;报错处理 velero 集群备份实战 velero 架构 vmware 的产品。velero 是一个CS架构&#xff0c;服务端是一堆CRD, 监听客户端发来的请求。 优…

【多线程】龟兔赛跑

package org.example;public class Race implements Runnable {//胜利者private static String winner;Overridepublic void run() {for(int i0;i<100;i){boolean flag gameOver(i);//如果flag>100,结束比赛if(flag){break;}System.out.println(Thread.currentThread().g…

批量剪辑:高效处理视频文件的图文解析,AI智剪方法

随着视频文件的数量和种类不断增加&#xff0c;传统的视频剪辑方法往往效率低下且费时费力。为了解决这个问题&#xff0c;批量剪辑和AI智剪技术应运而生。在剪辑过程中&#xff0c;AI智剪可自动调整画面质量、音效、色彩等参数&#xff0c;以保证视频质量。它们可以帮助我们高…

uniapp原生插件之安卓文件操作原生插件

插件介绍 安卓文件操作原生插件&#xff0c;读写文件&#xff0c;文件下载等&#xff0c;支持读取移动设备路径等外部存储设备路径&#xff0c;如U盘路径 插件地址 安卓文件操作原生插件 - DCloud 插件市场 超级福利 uniapp 插件购买超级福利 详细使用文档 uniapp 安卓文…

WPF中依赖属性及附加属性的概念及用法

完全来源于十月的寒流&#xff0c;感谢大佬讲解 依赖属性 由依赖属性提供的属性功能 与字段支持的属性不同&#xff0c;依赖属性扩展了属性的功能。 通常&#xff0c;添加的功能表示或支持以下功能之一&#xff1a; 资源数据绑定样式动画元数据重写属性值继承WPF 设计器集成 …