Python深度学习-Keras》精华笔记4:解决深度学习回归问题

news2025/1/10 20:30:51

公众号:尤而小屋
作者:Peter
编辑:Peter

持续更新《Python深度学习》一书的精华内容,仅作为学习笔记分享。

本文是第4篇:基于Keras解决深度学习中的回归问题

Keras内置数据集

回归问题中使用的是内置的波士顿房价数据集。在keras中有多个内置的数据集:

  • 波士顿房价数据集
  • CIFAR10数据集(包含10种类别的图片集)
  • CIFAR100数据集(包含100种类别的图片集)
  • MNIST数据集(手写数字图片集)
  • Fashion-MNIST数据集(10种时尚类别的图片集)
  • IMDB电影点评数据集
  • 路透社新闻数据集

其中IMDB数据集在二分类问题中被使用过,路透社新闻数据集在多分类问题中被使用。

In [1]:

import numpy as np
np.random.seed(1234)  # 设置随机种子

import matplotlib.pyplot as plt
%matplotlib inline

import warnings 
warnings.filterwarnings("ignore")  # 忽略notebook中的警告

导入数据

In [2]:

from  keras.datasets import boston_housing

(train_data,train_targets), (test_data, test_targets) = boston_housing.load_data()  

查看数据的基本信息:

In [3]:

train_data[:3]

Out[3]:

array([[1.23247e+00, 0.00000e+00, 8.14000e+00, 0.00000e+00, 5.38000e-01,
        6.14200e+00, 9.17000e+01, 3.97690e+00, 4.00000e+00, 3.07000e+02,
        2.10000e+01, 3.96900e+02, 1.87200e+01],
       [2.17700e-02, 8.25000e+01, 2.03000e+00, 0.00000e+00, 4.15000e-01,
        7.61000e+00, 1.57000e+01, 6.27000e+00, 2.00000e+00, 3.48000e+02,
        1.47000e+01, 3.95380e+02, 3.11000e+00],
       [4.89822e+00, 0.00000e+00, 1.81000e+01, 0.00000e+00, 6.31000e-01,
        4.97000e+00, 1.00000e+02, 1.33250e+00, 2.40000e+01, 6.66000e+02,
        2.02000e+01, 3.75520e+02, 3.26000e+00]])

In [4]:

train_data.shape, test_data.shape

Out[4]:

((404, 13), (102, 13))

数据标准化

在机器学习中,对数据进行标准化是非常重要的,主要有以下原因:

  • 消除量纲影响:不同特征的数值大小可能相差很大,例如重量和价格。如果这些特征之间具有很大的差异,那么某些特征的值可能会主导模型的训练过程,从而削弱其他特征的重要性。通过标准化,可以将所有特征的值缩放到相同的尺度上,消除量纲的影响。
  • 加速收敛:在机器学习算法中,梯度下降是一种常用的优化算法。当数据存在较大的尺度差异时,梯度更新可能会变得非常慢,导致算法收敛速度变慢。通过标准化,可以减少尺度差异,从而加速梯度下降算法的收敛速度。
  • 提高模型性能:标准化可以使数据分布更加均匀,避免出现极端值或离群点。这有助于提高模型的泛化能力和性能。 因此,对数据进行标准化是机器学习中一个重要的预处理步骤,可以提升模型的训练效果和预测性能。

首先求出训练集的均值和标准差,进行标准化;再使用训练集的均值和标准差对测试集进行标准化。

In [5]:

# 手动标准化

mean = train_data.mean(axis=0) # 均值
train_data -= mean

std = train_data.std(axis=0)
train_data /= std 

对测试集的标准化:

In [6]:

# 仍然使用训练集的mean std

test_data -= mean
test_data /= std

搭建网络

In [7]:

from keras import models, layers

In [8]:

def build_model():
    model = models.Sequential()
    # 输入层
    model.add(layers.Dense(64,activation="relu",input_shape=(train_data.shape[1],)))
    # 隐藏层
    model.add(layers.Dense(64,activation="relu"))
    # 输出层:回归问题预测一个值,最终只有一个单元
    model.add(layers.Dense(1))
    model.compile(optimizer="rmsprop",  # 优化器
                  loss="mse",   # 损失
                  metrics=["mae"])  # 评价指标
    return model

k折交叉验证

本小节的部分代码和原文有不同:

In [9]:

k = 5  # k折

num_val_samples = (len(train_data)) // k

num_epochs = 500

all_mae = []
all_val_mae = []
all_loss = []
all_val_loss = []

for i in range(k):
    
    print(f"第 {i+1} fold正在running")
    
    # 准备验证集数据
    valid_data = train_data[i*num_val_samples:(i+1)*num_val_samples]
    valid_targets = train_targets[i*num_val_samples:(i+1)*num_val_samples]
    
    # 准备训练集(除去验证集部分)
    part_train_data = np.concatenate([
        train_data[:i*num_val_samples],  # 索引i之前和i+1之后,两部分合并起来成为训练集
        train_data[(i+1)*num_val_samples:]
    ],axis=0)
    
    part_train_targets = np.concatenate([
        train_targets[:i*num_val_samples],
        train_targets[(i+1)*num_val_samples:]
    ],axis=0) 
    
    
    # 模型训练
    model = build_model()
    history = model.fit(part_train_data,
             part_train_targets,
             epochs=num_epochs,
             batch_size=1,
             verbose=0,   # 静默模式,不显示每个epochs的具体打印信息
             validation_data=[valid_data,valid_targets])
    history_dict = history.history
    
    mae = history_dict["mae"]  # 训练集的mae
    val_mae = history_dict["val_mae"]  # 训练集的mae
    loss = history_dict["loss"]
    val_loss = history_dict["val_loss"]
    
    all_mae.append(mae)
    all_val_mae.append(val_mae)
    all_loss.append(loss)
    all_val_loss.append(val_loss)
第 1 fold正在running
第 2 fold正在running
第 3 fold正在running
第 4 fold正在running
第 5 fold正在running

模型指标可视化

In [10]:

len(all_mae) # 5折

Out[10]:

5

In [11]:

# all_mae[0]  # 第1折的全部信息

第1折有500个元素(epochs=500)

In [12]:

len(all_mae[0])

Out[12]:

500

计算每个指标的平均值:

In [13]:

mae_average = [np.mean([x[i] for x in all_mae]) for i in range(num_epochs)]
mae_val_average = [np.mean([x[i] for x in all_val_mae]) for i in range(num_epochs)]

loss_average = [np.mean([x[i] for x in all_loss]) for i in range(num_epochs)]
loss_val_average = [np.mean([x[i] for x in all_val_loss]) for i in range(num_epochs)]

LOSS

In [14]:

num_epochs

Out[14]:

500

In [15]:

epochs = range(1, num_epochs+ 1)  # 作为横轴

plt.figure(figsize=(12,6))
plt.plot(epochs, loss_val_average, "blue")
plt.xlabel("Epochs")
plt.ylabel("Loss")

plt.legend()
plt.title("Validation Loss")
plt.show()

基于plotly绘制图像:

In [16]:

import plotly_express as px

px.scatter(x=epochs,y=loss_val_average)

In [17]:

epochs = range(1, num_epochs+ 1)  # 作为横轴

plt.figure(figsize=(12,6))

plt.plot(epochs, loss_average, "blue", label="Training Loss")
plt.plot(epochs, loss_val_average, "red", label="Validation Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")

plt.legend()
plt.title("Training and Validation Loss")
plt.show()

MAE

In [18]:

epochs = range(1, num_epochs+1)  # 作为横轴

plt.plot(epochs, mae_average, "blue", label="Training MAE")
plt.plot(epochs, mae_val_average, "red", label="Validation MAE")
plt.xlabel("Epochs")
plt.ylabel("MAE")

plt.legend()
plt.title("Training and Validation MAE")
plt.show()

模型优化

可以看到,在训练集或者验证集上,不管是损失Loss还是误差MAE,前面的10个数据点和其他点差异很大,属于异常值,考虑直接删除。

数据的平滑处理:将每个数据点替换为前面数据点的平均值,得到较为光滑的曲线。

In [19]:

def smooth(points, factor=0.9):
    smooth_points = []

    for point in points:
        if smooth_points: 
            previous = smooth_points[-1]
            smooth_points.append(previous * factor + point * (1 - factor)) # 前一个点 * 0.9 + 当前点 * 0.1
        else:  # 不存在元素则添加point点
            smooth_points.append(point)
                
    return smooth_points

排除前10个点后进行平滑处理:

In [20]:

smooth_mae = smooth(mae_average[10:]) 
smooth_mae_val = smooth(mae_val_average[10:]) 
smooth_loss = smooth(loss_average[10:]) 
smooth_loss_val = smooth(loss_val_average[10:]) 

新LOSS

In [21]:

epochs = range(1, len(smooth_mae)+ 1)  # 作为横轴

plt.figure(figsize=(12,6))

plt.plot(epochs, smooth_loss, "blue", label="Training Loss")
plt.plot(epochs, smooth_loss_val, "red", label="Validation Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")

plt.legend()
plt.title("Smoothed Training and Validation Loss")
plt.show()

# fig = px.scatter(x=epochs, y=smooth_mae)
fig = px.scatter(x=epochs, y=smooth_loss_val)

fig.show()

新MAE

In [23]:

epochs = range(1, len(smooth_mae)+ 1)  # 作为横轴

plt.figure(figsize=(12,6))

plt.plot(epochs, smooth_mae, "blue", label="Training MAE")
plt.plot(epochs, smooth_mae_val, "red", label="Validation MAE")
plt.xlabel("Epochs")
plt.ylabel("MAE")

plt.legend()
plt.title("Smoothed Training and Validation MAE")
plt.show()

# fig = px.scatter(x=epochs, y=smooth_mae)
fig = px.scatter(x=epochs, y=smooth_mae_val) 

fig.show()

可以看到:在训练集上loss和mae随着轮次的进行,都在逐渐变小;但是在验证集上,并非如此,在50轮左右降到最低;

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1001640.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JavaScript中的事件捕获(event capturing)和事件冒泡(event bubbling)

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ 事件捕获和事件冒泡⭐ 事件捕获(Event Capturing)示例: ⭐ 事件冒泡(Event Bubbling)示例: ⭐ 应用场景⭐ 写在最后 ⭐ 专栏简介 前端入门之旅:探索Web开…

苹果电脑版虚拟机推荐 VMware Fusion Pro for mac(vm虚拟机)

VMware Fusion Pro是一款功能强大的虚拟化软件,专为Mac用户设计。它允许用户在Mac上创建、运行和管理虚拟机,以便同时运行多个操作系统和应用程序。 以下是VMware Fusion Pro的一些主要特点和功能: 1. 多操作系统支持:VMware Fu…

CocosCreator3.6.2图片导入到工程,没办法拖动到场景中

解决方案:将资源的属性类型由texture调整为sprite-frame

字节、华为、美团软件测试面试真题(超详细~)

前言 最近已经算是秋招了,所以最近博主会努力给大家搜集整理一些各大公司测试岗测开岗的面经,希望能帮助到大家更好的入职想去的公司哦,关注我,一个每日分享软件测试知识的日更博主。 同时,我也准备了一份软件测试面…

Git 概述命令、idea中的使用

目录 Git概述 Git代码托管服务 Git常用命令 Git 全局设置 获取 Git 仓库 ​编辑Git 工作区中文件的状态 本地仓库操作 远程仓库操作 ​编辑分支操作 标签操作 在IDEA中使用Git 1.获取Git仓库 .gitignore 表示忽略 2.本地仓库操作 3.远程仓库操作 4.分支操作 Git是…

C++项目实战——基于多设计模式下的同步异步日志系统-⑤-实用工具类设计与实现

文章目录 专栏导读获取系统时间time介绍 getTime函数设计判断文件是否存在stat介绍exists函数设计 获取文件所在路径find_last_of介绍path函数设计 创建文件所在目录mkdir介绍find_first_of介绍函数createDirectory设计 实用工具类整理 专栏导读 🌸作者简介&#xf…

智慧公厕助力数字强市建设,打造善感知新型信息化公共厕所

随着城市建设的不断发展,智慧公厕作为一个重要的基础设施,正逐渐受到人们的重视。智慧公厕不仅为人们提供舒适的使用环境,更是通过数字化技术的应用,为城市发展注入新的动力。本文将以智慧公厕源头厂家广州中期科技有限公司&#…

我是如何用 redis 分布式锁来解决线上历史业务问题的

近期发现,开发功能的时候发现了一个 mq 消费顺序错乱(历史遗留问题),导致业务异常的问题,看看我是如何解决的 问题抛出 首先,简单介绍一下情况: 线上 k8s 有多个 pod 会去消费 mq 中的消息&a…

编写更嵌入式软件代码的10个技巧

代码维护是应用程序开发的重要方面,而为了缩短上市时间,通常会忽略代码维护。对于某些应用程序,这可能不会造成重大问题,因为这些应用程序的寿命很短,或者已部署该应用程序,并且再也不会碰它。 但是&#x…

UIScrollView setContentOffset: animated:

项目中遇到感觉一切都设置对了,但是看到的效果和预想的不一样。 后来查询了一番,才知道问题所在,现在记录一下,担心过后又忘了。 最初的问题是这样的,这个热度只有在评论里有,点击赞的时候,热度…

视频号的视频怎么下载,有什么下载工具推荐

视频下载助手去水印小工具是一款方便实用的工具,可以帮助用户在下载视频的时候可以一键去除视频中的水印。 该工具支持多种视频平台的去水印功能,如抖音、快手、小红书、视频号、公众号文字视频、西瓜视频、哔哩哔哩、微博视频、多多视频等。 经过亲自测…

为什么女程序员那么稀缺?女程序员吃不吃香?

程序员脱单一直是个难题,这里的一个客观原因就是程序员群体的男女比例严重失衡(比如我司达到了2:8),身边的工作环境缺少异性,大老爷们天天混在一起,脱单自然也就更加困难了。 女程序员那么稀缺&#xff0c…

《Python深度学习-Keras》精华笔记3:解决深度学习多分类问题

公众号:机器学习杂货店作者:Peter编辑:Peter 持续更新《Python深度学习》一书的精华内容,仅作为学习笔记分享。 本文是第三篇:介绍如何使用Keras解决Python深度学习中的多分类问题。 多分类问题和二分类问题的区别注意…

180页的Python完全版电子书

大家好,我是涛哥。 Python学习有很多方式,可以从基础一步步看语法, 可以从案例一步步学习,本篇内容就是通过案例进行讲解,方便大家一步一步进行学习实战。 整个内容经过几个月总结《Python之路2.0.pdf》&#xff0c…

基于 Python 的音乐流派分类

音乐就像一面镜子,它可以告诉人们很多关于你是谁,你关心什么,不管你喜欢与否。我们喜欢说“you are what you stream” - Spotify Spotify 拥有 260 亿美元的净资产,是如今很受欢迎的音乐流媒体平台。它目前在其数据库中拥有数百…

Java拓展--空间复杂度和时间复杂度

空间复杂度和时间复杂度 文章目录 空间复杂度和时间复杂度空间复杂度时间复杂度**评价排序算法****时间频度****什么是时间频度****忽略常数项****忽略低次项****忽略系数** **时间复杂度****什么是时间复杂度****计算时间复杂度的方法****常见的时间复杂度** **常见的时间复杂…

正中优配:证券融资融券是什么意思?

证券融资融券(简称“融资融券”)是一种股票出资办法,是指出资者经过融券生意和融资生意来进行股票出资。它在出资商场上具有重要的作用,因为经过这种办法,出资者能够使用假贷资金进行股票生意,能够进步出资…

腾讯云4核8G服务器选CVM还是轻量比较好?价格对比

腾讯云4核8G云服务器可以选择轻量应用服务器或CVM云服务器标准型S5实例,轻量4核8G12M服务器446元一年,CVM S5云服务器935元一年,相对于云服务器CVM,轻量应用服务器性价比更高,轻量服务器CPU和CVM有区别吗?性…

23062C++QTday4

仿照string类&#xff0c;完成myString 类 代码&#xff1a; #include <iostream> #include <cstring> using namespace std; class myString {private:char *str; //记录c风格的字符串int size; //记录字符串的实际长度public://无参构造my…

华为云云耀云服务器L实例评测 | 由于自己原因导致MySQL数据库被攻击 【更新中。。。】

目录 引出起因&#xff08;si因&#xff09;解决报错诶嘿&#xff0c;连上了 不出意外&#xff0c;就出意外了打开数据库what&#xff1f;&#xff1f;&#xff1f; 找华为云求助教训&#xff1a;备份教训&#xff1a;密码 解决1.改密码2.新建一个MySQL&#xff0c;密码设置复杂…