【深度学习】实验09 使用Keras完成线性回归

news2024/12/24 9:32:39

文章目录

  • 使用Keras完成线性回归
    • 1. 导入Keras库
    • 2. 创建数据集
    • 3. 划分数据集
    • 4. 构造神经网络模型
    • 5. 训练模型
    • 6. 测试模型
    • 7. 分析模型
  • 附:系列文章

使用Keras完成线性回归

Keras是一款基于Python的深度学习框架,以Tensorflow、Theano和CNTK作为后端,由François Chollet开发和维护,其目标是使深度学习模型的实现变得快速、简单。它的设计理念是用户友好、可扩展、易于调试和实验。

Keras提供了一系列高级API和便捷的工具,使得用户可以快速构建和训练深度学习模型,而不必关注底层的细节。Keras支持各种类型的网络结构,包括卷积神经网络、循环神经网络、自编码器等,并且可以轻松地在不同的数据集上进行训练和测试。

Keras的主要特点有:

  1. 简单易用,快速上手:Keras提供了简单易用的API,用户只需几行代码就能实现复杂的深度学习模型。

  2. 支持多种后端:Keras可以用Tensorflow、Theano和CNTK作为后端,用户可以根据自己的需要选择合适的后端。

  3. 高度可扩展:Keras提供了模块化的API,用户可以根据需要添加自定义层和函数,以及修改现有的代码。

  4. 方便的调试和实验:Keras提供了实时可视化的工具,方便用户查看模型的训练情况和测试结果,并且支持各种回调函数,例如早期停止、学习率调整等。

  5. 支持GPU加速:Keras可以利用GPU进行计算,加速深度学习模型的训练和推断过程。

总之,Keras是一款优秀的深度学习框架,它使得深度学习模型的构建和训练变得更加简单和快速,可以帮助用户更加专注于模型的设计和应用。

1. 导入Keras库

import warnings
warnings.filterwarnings("ignore")

import numpy as np
np.random.seed(1337)

from keras.models import Sequential
from keras.layers import Dense
from sklearn.metrics import r2_score
import matplotlib.pyplot as plt
Using TensorFlow backend.

2. 创建数据集

# 创建数据集
# 在[-1,1]的区间内等间隔创建200个样本数
X = np.linspace(-1, 1, 200)
X
   array([-1.        , -0.98994975, -0.9798995 , -0.96984925, -0.95979899,
          -0.94974874, -0.93969849, -0.92964824, -0.91959799, -0.90954774,
          -0.89949749, -0.88944724, -0.87939698, -0.86934673, -0.85929648,
          -0.84924623, -0.83919598, -0.82914573, -0.81909548, -0.80904523,
          -0.79899497, -0.78894472, -0.77889447, -0.76884422, -0.75879397,
          -0.74874372, -0.73869347, -0.72864322, -0.71859296, -0.70854271,
          -0.69849246, -0.68844221, -0.67839196, -0.66834171, -0.65829146,
          -0.64824121, -0.63819095, -0.6281407 , -0.61809045, -0.6080402 ,
          -0.59798995, -0.5879397 , -0.57788945, -0.5678392 , -0.55778894,
          -0.54773869, -0.53768844, -0.52763819, -0.51758794, -0.50753769,
          -0.49748744, -0.48743719, -0.47738693, -0.46733668, -0.45728643,
          -0.44723618, -0.43718593, -0.42713568, -0.41708543, -0.40703518,
          -0.39698492, -0.38693467, -0.37688442, -0.36683417, -0.35678392,
          -0.34673367, -0.33668342, -0.32663317, -0.31658291, -0.30653266,
          -0.29648241, -0.28643216, -0.27638191, -0.26633166, -0.25628141,
          -0.24623116, -0.2361809 , -0.22613065, -0.2160804 , -0.20603015,
          -0.1959799 , -0.18592965, -0.1758794 , -0.16582915, -0.15577889,
          -0.14572864, -0.13567839, -0.12562814, -0.11557789, -0.10552764,
          -0.09547739, -0.08542714, -0.07537688, -0.06532663, -0.05527638,
          -0.04522613, -0.03517588, -0.02512563, -0.01507538, -0.00502513,
           0.00502513,  0.01507538,  0.02512563,  0.03517588,  0.04522613,
           0.05527638,  0.06532663,  0.07537688,  0.08542714,  0.09547739,
           0.10552764,  0.11557789,  0.12562814,  0.13567839,  0.14572864,
           0.15577889,  0.16582915,  0.1758794 ,  0.18592965,  0.1959799 ,
           0.20603015,  0.2160804 ,  0.22613065,  0.2361809 ,  0.24623116,
           0.25628141,  0.26633166,  0.27638191,  0.28643216,  0.29648241,
           0.30653266,  0.31658291,  0.32663317,  0.33668342,  0.34673367,
           0.35678392,  0.36683417,  0.37688442,  0.38693467,  0.39698492,
           0.40703518,  0.41708543,  0.42713568,  0.43718593,  0.44723618,
           0.45728643,  0.46733668,  0.47738693,  0.48743719,  0.49748744,
           0.50753769,  0.51758794,  0.52763819,  0.53768844,  0.54773869,
           0.55778894,  0.5678392 ,  0.57788945,  0.5879397 ,  0.59798995,
           0.6080402 ,  0.61809045,  0.6281407 ,  0.63819095,  0.64824121,
           0.65829146,  0.66834171,  0.67839196,  0.68844221,  0.69849246,
           0.70854271,  0.71859296,  0.72864322,  0.73869347,  0.74874372,
           0.75879397,  0.76884422,  0.77889447,  0.78894472,  0.79899497,
           0.80904523,  0.81909548,  0.82914573,  0.83919598,  0.84924623,
           0.85929648,  0.86934673,  0.87939698,  0.88944724,  0.89949749,
           0.90954774,  0.91959799,  0.92964824,  0.93969849,  0.94974874,
           0.95979899,  0.96984925,  0.9798995 ,  0.98994975,  1.        ])
# 将数据集随机化
np.random.shuffle(X)
X
   array([-0.70854271,  0.1758794 , -0.30653266,  0.74874372, -0.02512563,
           0.33668342, -0.85929648,  0.01507538, -0.13567839,  0.72864322,
           0.24623116, -0.74874372, -0.78894472,  0.50753769,  0.03517588,
           0.35678392, -0.55778894,  0.2361809 , -0.25628141, -0.44723618,
           0.2160804 , -0.43718593, -0.64824121,  0.69849246, -0.03517588,
          -0.45728643,  0.86934673,  0.73869347,  0.53768844, -0.67839196,
          -0.75879397,  0.55778894,  0.28643216, -0.05527638, -0.86934673,
           0.1959799 , -0.57788945, -0.9798995 , -0.6080402 , -0.63819095,
           0.84924623,  0.41708543,  0.13567839,  0.79899497, -0.47738693,
           0.46733668,  0.59798995, -0.80904523, -0.98994975, -0.36683417,
          -0.5678392 , -0.00502513, -0.53768844, -0.37688442, -0.65829146,
          -0.1959799 ,  0.06532663,  0.44723618, -0.01507538, -0.6281407 ,
           0.02512563, -0.71859296, -0.14572864, -0.46733668,  0.07537688,
           0.85929648,  0.76884422,  0.40703518, -0.68844221,  0.68844221,
          -0.29648241,  0.66834171, -0.95979899, -0.33668342,  0.26633166,
          -0.82914573,  1.        , -0.5879397 , -0.69849246, -0.20603015,
           0.63819095, -0.88944724, -0.40703518, -0.32663317,  0.15577889,
          -0.41708543,  0.10552764,  0.20603015, -0.04522613,  0.00502513,
          -0.31658291,  0.43718593,  0.42713568,  0.45728643, -0.59798995,
          -0.66834171,  0.83919598,  0.75879397, -0.24623116,  0.71859296,
          -0.92964824,  0.39698492,  0.61809045, -0.84924623, -0.87939698,
          -0.96984925,  0.87939698,  0.6281407 ,  0.25628141,  0.27638191,
           0.12562814,  0.09547739, -0.89949749,  0.80904523, -0.16582915,
          -0.12562814,  0.30653266,  0.49748744,  0.5879397 , -0.51758794,
          -0.10552764,  0.54773869, -0.94974874,  0.92964824,  0.16582915,
          -0.83919598, -0.35678392, -0.48743719,  0.08542714, -0.61809045,
           0.18592965,  0.57788945,  0.65829146,  0.38693467,  0.91959799,
          -0.26633166, -0.50753769, -1.        , -0.54773869,  0.6080402 ,
          -0.49748744, -0.22613065,  0.9798995 ,  0.98994975,  0.5678392 ,
           0.32663317,  0.64824121, -0.52763819,  0.36683417,  0.81909548,
          -0.11557789,  0.31658291, -0.2160804 ,  0.95979899,  0.77889447,
          -0.73869347, -0.81909548, -0.79899497,  0.78894472,  0.88944724,
          -0.2361809 ,  0.37688442,  0.70854271,  0.22613065, -0.28643216,
          -0.38693467,  0.90954774, -0.91959799,  0.48743719, -0.42713568,
          -0.08542714,  0.11557789, -0.18592965,  0.47738693, -0.39698492,
          -0.34673367,  0.04522613,  0.05527638,  0.93969849, -0.77889447,
          -0.93969849, -0.06532663, -0.72864322,  0.29648241,  0.52763819,
          -0.76884422,  0.94974874,  0.82914573,  0.34673367, -0.90954774,
          -0.27638191, -0.15577889, -0.1758794 ,  0.14572864, -0.09547739,
           0.96984925,  0.67839196, -0.07537688,  0.89949749,  0.51758794])
# 假设真实模型为:Y=0.5X+2
Y = 0.5 * X + 2 + np.random.normal(0, 0.05, (200,))
Y
   array([1.66851812, 2.12220988, 1.91611873, 2.38979647, 1.96473269,
          2.11662688, 1.58217043, 2.05326658, 1.95885373, 2.4277956 ,
          2.13544689, 1.68732448, 1.66384243, 2.2702853 , 2.03148986,
          2.14968674, 1.76442495, 2.10802586, 1.93269542, 1.81936289,
          2.15190248, 1.83941395, 1.71399197, 2.21820555, 1.97918099,
          1.79781646, 2.43645587, 2.31211201, 2.21764353, 1.71912829,
          1.64285239, 2.2663785 , 2.11081029, 2.09338152, 1.5614153 ,
          2.19655545, 1.72824772, 1.56444412, 1.72673075, 1.67311017,
          2.39817488, 2.12624087, 2.07791136, 2.40515644, 1.80701389,
          2.16050089, 2.30373845, 1.57656517, 1.52482139, 1.7639545 ,
          1.76787463, 2.01204511, 1.74877623, 1.86751173, 1.67509082,
          1.95941218, 2.0126989 , 2.31574759, 2.04672223, 1.73762178,
          1.97249596, 1.65257838, 1.98435822, 1.74193776, 2.05272917,
          2.41693508, 2.37609913, 2.24686996, 1.61790402, 2.37607665,
          1.82677368, 2.29512653, 1.52756173, 1.79404414, 2.08314   ,
          1.5209276 , 2.48034115, 1.7821867 , 1.60377021, 1.82345627,
          2.23840132, 1.50174227, 1.85127905, 1.92372432, 1.95433662,
          1.8146093 , 1.96513404, 2.0227501 , 1.97564664, 2.09893966,
          1.95392005, 2.2089975 , 2.26074219, 2.24742979, 1.75936195,
          1.69145596, 2.46801952, 2.40938521, 1.98369075, 2.37509171,
          1.53026033, 2.24305926, 2.33309562, 1.49913881, 1.48743005,
          1.54075518, 2.33130062, 2.37463005, 2.19387461, 2.20970603,
          2.04719149, 2.04105128, 1.48410805, 2.34714158, 1.95061571,
          1.89473245, 2.26596278, 2.22430597, 2.29984983, 1.7894671 ,
          1.85995514, 2.31688729, 1.53417344, 2.39777465, 2.12853793,
          1.47736812, 1.90180229, 1.73086567, 2.03772387, 1.67243511,
          2.10115733, 2.26944612, 2.37404859, 2.22042332, 2.4948031 ,
          1.80153666, 1.72069013, 1.44829544, 1.77678155, 2.24291992,
          1.73557503, 1.79249737, 2.52580388, 2.46810975, 2.34211232,
          2.22144569, 2.31945172, 1.72814133, 2.17318812, 2.43560932,
          1.9662451 , 2.14319385, 1.83150682, 2.48805089, 2.28374904,
          1.63645718, 1.57901687, 1.61041853, 2.40884706, 2.37339631,
          1.90728817, 2.09065413, 2.36836694, 2.05400262, 1.87764304,
          1.83547711, 2.45064964, 1.46324772, 2.2429919 , 1.75954149,
          1.97326923, 2.08379661, 2.04616096, 2.3161197 , 1.81470671,
          1.8188581 , 2.11349671, 2.05477704, 2.39622142, 1.61281075,
          1.56914576, 1.96947616, 1.56645219, 2.08002605, 2.2185357 ,
          1.54079134, 2.42384819, 2.41198434, 2.0570266 , 1.55142224,
          1.83396657, 1.92648666, 1.9143498 , 1.9372014 , 1.92794208,
          2.42698754, 2.29871021, 2.03266023, 2.42413239, 2.28286632])
# 绘制数据集(X, Y)
plt.scatter(X, Y)
plt.show()

1

3. 划分数据集

# 划分训练集和测试集
X_train, Y_train = X[:160], Y[:160]
X_test, Y_test = X[160:], Y[160:]

4. 构造神经网络模型

# 定义一个model
# Keras有两种类型的模型,序列模型和函数式模型
# 比较常用的是Sequential,它是单输入单输出的
model = Sequential()

# 通过add()方法一层层添加模型
# Dense是全连接层,第一层需要定义输入
model.add(Dense(output_dim=1,input_dim=1))

# 定义完成模型就要训练了,不过训练之前我们需要指定一些训练参数
# 通过compile()方法选择损失函数和优化器
# 这里我们用均方差作为损失函数,随机梯度下降作为优化方法
model.compile(loss='mse', optimizer='sgd')

5. 训练模型

# 开始训练
print('Training ----------')

# Keras有很多开始训练的函数,这里用train_on_batch()
for step in range(301):
    cost = model.train_on_batch(X_train,Y_train)
    if step%100 == 0:
        print('train cost: ', cost)
Training ----------
WARNING:tensorflow:From /home/nlp/anaconda3/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.

train cost:  4.0225005
train cost:  0.073238626
train cost:  0.00386274
train cost:  0.002643449

6. 测试模型

# 测试训练好的模型
print('Testing ----------')
cost = model.evaluate(X_test, Y_test, batch_size = 40)
print('test cost: ',cost)
Testing ----------
40/40 [==============================] - 0s 508us/step
test cost:  0.0031367032788693905

7. 分析模型

# 查看训练出的网络参数
# 由于我们网络只有一层,且每次训练的输入只有一个,输出只有一个
# 因此第一层训练出Y=WX+B这个模型,其中W,b为训练出的参数
W, b = model.layers[0].get_weights()
print('Weights = ', W, '\nbiases = ', b)
Weights =  [[0.4922711]] 
biases =  [1.9995022]
# 画出预测图
Y_pred = model.predict(X_test)
plt.scatter(X_test, Y_test)
plt.plot(X_test, Y_pred)
plt.show()

2

#使用r2 score评估准确度
pred_acc = r2_score(Y_test, Y_pred)
print('pred_acc',pred_acc)
pred_acc 0.9591211310535933
#保存模型
model.save('keras_linear.h5')

附:系列文章

序号文章目录直达链接
1波士顿房价预测https://want595.blog.csdn.net/article/details/132181950
2鸢尾花数据集分析https://want595.blog.csdn.net/article/details/132182057
3特征处理https://want595.blog.csdn.net/article/details/132182165
4交叉验证https://want595.blog.csdn.net/article/details/132182238
5构造神经网络示例https://want595.blog.csdn.net/article/details/132182341
6使用TensorFlow完成线性回归https://want595.blog.csdn.net/article/details/132182417
7使用TensorFlow完成逻辑回归https://want595.blog.csdn.net/article/details/132182496
8TensorBoard案例https://want595.blog.csdn.net/article/details/132182584
9使用Keras完成线性回归https://want595.blog.csdn.net/article/details/132182723
10使用Keras完成逻辑回归https://want595.blog.csdn.net/article/details/132182795
11使用Keras预训练模型完成猫狗识别https://want595.blog.csdn.net/article/details/132243928
12使用PyTorch训练模型https://want595.blog.csdn.net/article/details/132243989
13使用Dropout抑制过拟合https://want595.blog.csdn.net/article/details/132244111
14使用CNN完成MNIST手写体识别(TensorFlow)https://want595.blog.csdn.net/article/details/132244499
15使用CNN完成MNIST手写体识别(Keras)https://want595.blog.csdn.net/article/details/132244552
16使用CNN完成MNIST手写体识别(PyTorch)https://want595.blog.csdn.net/article/details/132244641
17使用GAN生成手写数字样本https://want595.blog.csdn.net/article/details/132244764
18自然语言处理https://want595.blog.csdn.net/article/details/132276591

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1003396.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AlwaysOn-关于读写分离的误区(一)

前言 很多人认为AlwaysOn在同步提交模式下数据是实时同步的,也就是说在主副本写入数据后可以在辅助副本立即查询到。因此期望实现一个彻底的读写分离策略,即所有的写语句在主副本上,所有的只读语句分离到辅助副本上。这是一个认知误区&#x…

多要素气象站:推动气象监测进入智能化新时代

一、多要素气象站概述 多要素气象站是一种集成了多种气象监测要素的自动化气象站,可实现对温度、湿度、风速、风向、气压、太阳辐射等多项气象参数的实时监测。相较于传统气象站,多要素气象站体积更小、能耗更低,且具备更高的测量精度和更广…

高精度加法[大整数运算]

这里只有大整数运算,浮点数相对来说很少使用到 前言 如果使用C内置的类型来存储大整数(位数有几万位),是会溢出的,得不到正确的值,即使是long long int类型的范围也只是[-9*10^19,9*10^19](无符号是[0,10^20]),所能存储的最大数字也就20位,所以需要高精度算法,高精度加法具体…

人脸签到系统 pyQT+数据库+深度学习

一、简介 人脸签到系统是一种基于人脸识别技术的自动签到和认证系统。它利用计算机视觉和深度学习算法来检测、识别和验证个体的面部特征,以确定其身份并记录其出现的时间。这个系统通常用于各种场景,包括企业、学校、会议、活动和公共交通等&#xff0c…

海格里斯HEGERLS托盘式四向穿梭车系统核心技术有哪些?

四向穿梭车立体仓库是常见的自动化立体库解决方案,可以应用在不规则、异型、长宽比较大或者少品种大批量、多品种大批量的仓库。其核心设备四向穿梭车克服了环形穿梭车的缺点,具有很高的灵活性和柔性。对于大型的立体库系统,四向穿梭车具有很…

2023最新计算机大数据毕业设计选题推荐100例

文章目录 0 前言1 如何选题1.1 选题技巧:如何避坑(重中之重)1.2 为什么这么说呢?1.3 难度把控1.4 题目名称1.5 最后 2 大数据 - 选题推荐2.1 大数据挖掘类2.2 大数据处理、云计算、区块链 毕设选题2.3 大数据安全类2.4 python大数据 游戏设计、动画设计类…

人工智能基础-趋势-架构

在过去的几周里,我花了一些时间来了解生成式人工智能基础设施的前景。在这篇文章中,我的目标是清晰概述关键组成部分、新兴趋势,并重点介绍推动创新的早期行业参与者。我将解释基础模型、计算、框架、计算、编排和矢量数据库、微调、标签、合…

HTTPS原理(证书验证+数据传输)

HTTPS协议相关的概念有SSL、非对称加密、CA证书等 为什么用了HTTPS就是安全的?HTTPS底层原理如何实现?用了HTTPS就一定安全吗? HTTPS实现原理 HTTPS在内容传输上的加密使用的是对称加密,证书验证阶段使用非对称加密 中间人攻…

无涯教程-JavaScript - INTRATE函数

描述 INTRATE函数返回完全投资证券的利率。 语法 INTRATE (settlement, maturity, investment, redemption, [basis])争论 Argument描述Required/OptionalSettlement 证券的结算日期。 证券结算日期是指在发行日期之后将证券交易给买方的日期。 RequiredMaturity 证券的到期…

innovus:如何设置clock cell list(icg logic_cells buffer inverter)

我正在「拾陆楼」和朋友们讨论有趣的话题,你⼀起来吧? 拾陆楼知识星球入口 set_ccopt_property inverter_cells [list CKINV*LVT] set_ccopt_property buffer_cells [list CKBUF*LVT] set_ccopt_property clock_gating_cells [list CLKNQ*LVT set_cc…

数据结构与算法(一)数组的相关概念和底层java实现

一、前言 从今天开始,笔者也开始从0学习数据结构和算法,但是因为这次学习比较捉急,所以记录的内容并不会过于详细,会从基础和底层代码实现以及力扣相关题目去写相关的文章,对于详细的概念并不会过多讲解 二、数组基础…

数字化互联网数字孪生:重塑未来的创新思维

数字孪生是当今数字化互联网时代的一个引人注目的概念,它正在改变着各个行业的方式和节奏。数字孪生不仅是一种技术,更是一种思维方式。 数字化互联网的崛起 数字化互联网正在以前所未有的速度和规模改变着我们的生活和工作方式。从智能手机到物联网设备…

按图搜索淘宝商品(拍立淘)API接口 搜爆款商品 图片搜索功能api 调用示例

接口名称:item_search_img 公共参数 请求地址: 测试item_search_img 名称类型必须描述keyString是调用key(必须以GET方式拼接在URL中)secretString是调用密钥api_nameString是API接口名称(包括在请求地址中)[item_s…

品牌为什么要做价格管控

价格管控的目的其实是为了治理低价,低价的存在会使渠道变得不可控,比如经销商低价跟价,消费者因为低价而转投其他品牌,这些无形中都会影响品牌的销量,阻碍品牌发展,所以做价格管控,就是在做好低…

智慧公厕助推城市管理智能化和治理精细化

随着城市化进程的不断加快,城市管理面临着诸多挑战和问题。而智慧公厕作为城市数字化赋能的重要一环,正成为推动城市管理智能化和治理精细化的关键力量。本文将以智慧公厕头部厂家广州中期科技有限公司,所实施大量精品案例项目的实景实图&…

(源码版)2023全国大学生数学建模竞赛E题黄河水沙监测数据分析详解+Python代码源码SARIMA模型

前言 比赛结束了不知道大家情况如何,就我个人而言的话,由于工作任务比较繁重仅完成了对D题和E题的思路解答和建模,还是比较遗憾的。一个人要完成多题的建模和分析确实不是一件容易的事情,当然我向大家做出承诺历年的建模比赛我都…

修改el-card的header的背景颜色

修改el-card的header的背景颜色 1.修改默认样式 好处是当前页面的所有的el-card都会变化 页面卡片&#xff1a; <el-card class"box-card" ><div slot"header" class"clearfix"><span>卡片名称</span><el-button s…

华为数通方向HCIP-DataCom H12-821题库(单选题:341-360)

第341题 在BGP中代表对等体之间已经建立连接的状态是以下哪一种? A、Active B、Connect C、Established D、Open 答案:C 第342题 以下关于路由选择工具的描述,错误的是哪一项? A、访问控制列表用于匹配路由信息或者数据包的地址,过滤不符合条件的路由信息或数据包 …

构造函数注入指定bean名称

配置类 Configuration public class ThreadPoolTaskExecutorConfig {Beanpublic ThreadPoolTaskScheduler syncScheduler() {ThreadPoolTaskScheduler syncScheduler new ThreadPoolTaskScheduler();syncScheduler.setPoolSize(10);syncScheduler.setThreadGroupName("s…

企业网络革命:连接和访问的智慧选项

近年来&#xff0c;企业网络通信需求可谓五花八门&#xff0c;变幻莫测。它不仅为企业的生产、办公、研发、销售提供全面赋能&#xff0c;同时也让企业业务规模变大成为了可能。今天&#xff0c;我们来聊聊广域网中两个不可忽视的概念&#xff1a;连接&#xff08;Connection&a…