深度学习 精选笔记(10)简单案例:房价预测

news2024/11/17 12:42:11

学习参考:

  • 动手学深度学习2.0
  • Deep-Learning-with-TensorFlow-book
  • pytorchlightning

①如有冒犯、请联系侵删。
②已写完的笔记文章会不定时一直修订修改(删、改、增),以达到集多方教程的精华于一文的目的。
③非常推荐上面(学习参考)的前两个教程,在网上是开源免费的,写的很棒,不管是开始学还是复习巩固都很不错的。

深度学习回顾,专栏内容来源多个书籍笔记、在线笔记、以及自己的感想、想法,佛系更新。争取内容全面而不失重点。完结时间到了也会一直更新下去,已写完的笔记文章会不定时一直修订修改(删、改、增),以达到集多方教程的精华于一文的目的。所有文章涉及的教程都会写在开头、一起学习一起进步。

在这里案例里面,测试集主要作用仅仅为评估模型泛化效果。并不具有“预测未来售价”的实际作用。因为测试集的其它特征都已经是假设已知的,而预测未来的时候这些都是相当于没发生的未知列,所以该案例中作用是相当于只是评估模型的作用。

一、加载数据集

将下载不同的数据集,并训练和测试模型。

download函数用来下载数据集, 将数据集缓存在本地目录(默认情况下为…/data)中, 并返回下载文件的名称。 如果缓存目录中已经存在此数据集文件,并且其sha-1与存储在DATA_HUB中的相匹配, 将使用缓存的文件,以避免重复的下载。

并实现两个实用函数: 一个将下载并解压缩一个zip或tar文件, 另一个是将本书中使用的所有数据集从DATA_HUB下载到缓存目录中。

import hashlib
import os
import tarfile
import zipfile
import requests
# 如果没有安装pandas,请取消下一行的注释
# !pip install pandas

%matplotlib inline
import numpy as np
import pandas as pd
import tensorflow as tf
from d2l import tensorflow as d2l

#@save
DATA_HUB = dict()
DATA_URL = 'http://d2l-data.s3-accelerate.amazonaws.com/'

def download_extract(name, folder=None):  #@save
    """下载并解压zip/tar文件"""
    fname = download(name)
    base_dir = os.path.dirname(fname)
    data_dir, ext = os.path.splitext(fname)
    if ext == '.zip':
        fp = zipfile.ZipFile(fname, 'r')
    elif ext in ('.tar', '.gz'):
        fp = tarfile.open(fname, 'r')
    else:
        assert False, '只有zip/tar文件可以被解压缩'
    fp.extractall(base_dir)
    return os.path.join(base_dir, folder) if folder else data_dir

def download_all():  #@save
    """下载DATA_HUB中的所有文件"""
    for name in DATA_HUB:
        download(name)

1.下载数据

数据分为训练集和测试集。 每条记录都包括房屋的属性值和属性,如街道类型、施工年份、屋顶类型、地下室状况等。 这些特征由各种数据类型组成。 例如,建筑年份由整数表示,屋顶类型由离散类别表示,其他特征由浮点数表示。 这就是现实让事情变得复杂的地方:例如,一些数据完全丢失了,缺失值被简单地标记为“NA”。

使用上面定义的脚本下载并缓存Kaggle房屋数据集。

DATA_HUB['kaggle_house_train'] = (  #@save
    DATA_URL + 'kaggle_house_pred_train.csv',
    '585e9cc93e70b39160e7921475f9bcd7d31219ce')

DATA_HUB['kaggle_house_test'] = (  #@save
    DATA_URL + 'kaggle_house_pred_test.csv',
    'fa19780a7b011d9b009e8bff8e99922a8ee2eb90')

使用pandas分别加载包含训练数据和测试数据的两个CSV文件。

train_data

可以注意到,训练数据是有标签的(即价格)。

	Id	MSSubClass	MSZoning	LotFrontage	LotArea	Street	Alley	LotShape	LandContour	Utilities	...	PoolArea	PoolQC	Fence	MiscFeature	MiscVal	MoSold	YrSold	SaleType	SaleCondition	SalePrice
0	1	60	RL	65.0	8450	Pave	NaN	Reg	Lvl	AllPub	...	0	NaN	NaN	NaN	0	2	2008	WD	Normal	208500
1	2	20	RL	80.0	9600	Pave	NaN	Reg	Lvl	AllPub	...	0	NaN	NaN	NaN	0	5	2007	WD	Normal	181500
2	3	60	RL	68.0	11250	Pave	NaN	IR1	Lvl	AllPub	...	0	NaN	NaN	NaN	0	9	2008	WD	Normal	223500
3	4	70	RL	60.0	9550	Pave	NaN	IR1	Lvl	AllPub	...	0	NaN	NaN	NaN	0	2	2006	WD	Abnorml	140000
4	5	60	RL	84.0	14260	Pave	NaN	IR1	Lvl	AllPub	...	0	NaN	NaN	NaN	0	12	2008	WD	Normal	250000
...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...
1455	1456	60	RL	62.0	7917	Pave	NaN	Reg	Lvl	AllPub	...	0	NaN	NaN	NaN	0	8	2007	WD	Normal	175000
1456	1457	20	RL	85.0	13175	Pave	NaN	Reg	Lvl	AllPub	...	0	NaN	MnPrv	NaN	0	2	2010	WD	Normal	210000
1457	1458	70	RL	66.0	9042	Pave	NaN	Reg	Lvl	AllPub	...	0	NaN	GdPrv	Shed	2500	5	2010	WD	Normal	266500
1458	1459	20	RL	68.0	9717	Pave	NaN	Reg	Lvl	AllPub	...	0	NaN	NaN	NaN	0	4	2010	WD	Normal	142125
1459	1460	20	RL	75.0	9937	Pave	NaN	Reg	Lvl	AllPub	...	0	NaN	NaN	NaN	0	6	2008	WD	Normal	147500
1460 rows × 81 columns

在这里插入图片描述

test_data

可以发现测试数据是没有标签的。

Id	MSSubClass	MSZoning	LotFrontage	LotArea	Street	Alley	LotShape	LandContour	Utilities	...	ScreenPorch	PoolArea	PoolQC	Fence	MiscFeature	MiscVal	MoSold	YrSold	SaleType	SaleCondition
0	1461	20	RH	80.0	11622	Pave	NaN	Reg	Lvl	AllPub	...	120	0	NaN	MnPrv	NaN	0	6	2010	WD	Normal
1	1462	20	RL	81.0	14267	Pave	NaN	IR1	Lvl	AllPub	...	0	0	NaN	NaN	Gar2	12500	6	2010	WD	Normal
2	1463	60	RL	74.0	13830	Pave	NaN	IR1	Lvl	AllPub	...	0	0	NaN	MnPrv	NaN	0	3	2010	WD	Normal
3	1464	60	RL	78.0	9978	Pave	NaN	IR1	Lvl	AllPub	...	0	0	NaN	NaN	NaN	0	6	2010	WD	Normal
4	1465	120	RL	43.0	5005	Pave	NaN	IR1	HLS	AllPub	...	144	0	NaN	NaN	NaN	0	1	2010	WD	Normal
...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...
1454	2915	160	RM	21.0	1936	Pave	NaN	Reg	Lvl	AllPub	...	0	0	NaN	NaN	NaN	0	6	2006	WD	Normal
1455	2916	160	RM	21.0	1894	Pave	NaN	Reg	Lvl	AllPub	...	0	0	NaN	NaN	NaN	0	4	2006	WD	Abnorml
1456	2917	20	RL	160.0	20000	Pave	NaN	Reg	Lvl	AllPub	...	0	0	NaN	NaN	NaN	0	9	2006	WD	Abnorml
1457	2918	85	RL	62.0	10441	Pave	NaN	Reg	Lvl	AllPub	...	0	0	NaN	MnPrv	Shed	700	7	2006	WD	Normal
1458	2919	60	RL	74.0	9627	Pave	NaN	Reg	Lvl	AllPub	...	0	0	NaN	NaN	NaN	0	11	2006	WD	Normal
1459 rows × 80 columns

2.顺便删除无意义数据列

在每个样本中,第一个特征是ID,)这有助于模型识别每个训练样本。 虽然这很方便,但它不携带任何用于预测的信息。 因此,在将数据提供给模型之前,将其从数据集中删除,并且把训练集测试集拼接到一起、方便处理数据。

all_features = pd.concat((train_data.iloc[:, 1:-1], test_data.iloc[:, 1:]))
all_features
	MSSubClass	MSZoning	LotFrontage	LotArea	Street	Alley	LotShape	LandContour	Utilities	LotConfig	...	ScreenPorch	PoolArea	PoolQC	Fence	MiscFeature	MiscVal	MoSold	YrSold	SaleType	SaleCondition
0	60	RL	65.0	8450	Pave	NaN	Reg	Lvl	AllPub	Inside	...	0	0	NaN	NaN	NaN	0	2	2008	WD	Normal
1	20	RL	80.0	9600	Pave	NaN	Reg	Lvl	AllPub	FR2	...	0	0	NaN	NaN	NaN	0	5	2007	WD	Normal
2	60	RL	68.0	11250	Pave	NaN	IR1	Lvl	AllPub	Inside	...	0	0	NaN	NaN	NaN	0	9	2008	WD	Normal
3	70	RL	60.0	9550	Pave	NaN	IR1	Lvl	AllPub	Corner	...	0	0	NaN	NaN	NaN	0	2	2006	WD	Abnorml
4	60	RL	84.0	14260	Pave	NaN	IR1	Lvl	AllPub	FR2	...	0	0	NaN	NaN	NaN	0	12	2008	WD	Normal
...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...
1454	160	RM	21.0	1936	Pave	NaN	Reg	Lvl	AllPub	Inside	...	0	0	NaN	NaN	NaN	0	6	2006	WD	Normal
1455	160	RM	21.0	1894	Pave	NaN	Reg	Lvl	AllPub	Inside	...	0	0	NaN	NaN	NaN	0	4	2006	WD	Abnorml
1456	20	RL	160.0	20000	Pave	NaN	Reg	Lvl	AllPub	Inside	...	0	0	NaN	NaN	NaN	0	9	2006	WD	Abnorml
1457	85	RL	62.0	10441	Pave	NaN	Reg	Lvl	AllPub	Inside	...	0	0	NaN	MnPrv	Shed	700	7	2006	WD	Normal
1458	60	RL	74.0	9627	Pave	NaN	Reg	Lvl	AllPub	Inside	...	0	0	NaN	NaN	NaN	0	11	2006	WD	Normal
2919 rows × 79 columns

二、数据预处理

有各种各样的数据类型。 在开始建模之前,需要对数据进行预处理。

1.特征缩放(Z-score数据标准化)

首先,将所有缺失的值替换为相应特征的平均值。然后,为了将所有特征放在一个共同的尺度上, 通过将特征重新缩放到零均值和单位方差来标准化数据,其中 𝜇 和 𝜎分别表示均值和标准差。
在这里插入图片描述
标准化数据有两个原因: 首先,它方便优化。 其次,因为不知道哪些特征是相关的, 所以不想让惩罚分配给一个特征的系数比分配给其他任何特征的系数更大。

# 若无法获得测试数据,则可根据训练数据计算均值和标准差
numeric_features = all_features.dtypes[all_features.dtypes != 'object'].index
all_features[numeric_features] = all_features[numeric_features].apply(lambda x: (x - x.mean()) / (x.std()))

常用的预处理方法:将实值数据重新缩放为零均值和单位方法;用前后一段时间的均值替换缺失值。

2.处理缺失值

# 在标准化数据之后,所有均值消失,因此我们可以将缺失值设置为0
all_features[numeric_features] = all_features[numeric_features].fillna(0)

3.处理离散值

处理离散值。这包括诸如“MSZoning”之类的特征。 用独热编码替换它们, 方法与前面将多类别标签转换为向量的方式相同。

创建两个新的指示器特征“MSZoning_RL”和“MSZoning_RM”,其值为0或1。 根据独热编码,如果“MSZoning”的原始值为“RL”, 则:“MSZoning_RL”为1,“MSZoning_RM”为0。 pandas软件包会自动实现这一点。

# “Dummy_na=True”将“na”(缺失值)视为有效的特征值,并为其创建指示符特征
all_features = pd.get_dummies(all_features, dummy_na=True)
all_features.shape
all_features
	MSSubClass	LotFrontage	LotArea	OverallQual	OverallCond	YearBuilt	YearRemodAdd	MasVnrArea	BsmtFinSF1	BsmtFinSF2	...	SaleType_Oth	SaleType_WD	SaleType_nan	SaleCondition_Abnorml	SaleCondition_AdjLand	SaleCondition_Alloca	SaleCondition_Family	SaleCondition_Normal	SaleCondition_Partial	SaleCondition_nan
0	0.067320	-0.184443	-0.217841	0.646073	-0.507197	1.046078	0.896679	0.523038	0.580708	-0.29303	...	0	1	0	0	0	0	0	1	0	0
1	-0.873466	0.458096	-0.072032	-0.063174	2.187904	0.154737	-0.395536	-0.569893	1.177709	-0.29303	...	0	1	0	0	0	0	0	1	0	0
2	0.067320	-0.055935	0.137173	0.646073	-0.507197	0.980053	0.848819	0.333448	0.097840	-0.29303	...	0	1	0	0	0	0	0	1	0	0
3	0.302516	-0.398622	-0.078371	0.646073	-0.507197	-1.859033	-0.682695	-0.569893	-0.494771	-0.29303	...	0	1	0	1	0	0	0	0	0	0
4	0.067320	0.629439	0.518814	1.355319	-0.507197	0.947040	0.753100	1.381770	0.468770	-0.29303	...	0	1	0	0	0	0	0	1	0	0
...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...
1454	2.419286	-2.069222	-1.043758	-1.481667	1.289537	-0.043338	-0.682695	-0.569893	-0.968860	-0.29303	...	0	1	0	0	0	0	0	1	0	0
1455	2.419286	-2.069222	-1.049083	-1.481667	-0.507197	-0.043338	-0.682695	-0.569893	-0.415757	-0.29303	...	0	1	0	1	0	0	0	0	0	0
1456	-0.873466	3.884968	1.246594	-0.772420	1.289537	-0.373465	0.561660	-0.569893	1.717643	-0.29303	...	0	1	0	1	0	0	0	0	0	0
1457	0.655311	-0.312950	0.034599	-0.772420	-0.507197	0.682939	0.370221	-0.569893	-0.229194	-0.29303	...	0	1	0	0	0	0	0	1	0	0
1458	0.067320	0.201080	-0.068608	0.646073	-0.507197	0.715952	0.465941	-0.045732	0.694840	-0.29303	...	0	1	0	0	0	0	0	1	0	0
2919 rows × 331 columns

转换会将特征的总数量从79个增加到331个。 最后,通过values属性,可以 从pandas格式中提取NumPy格式,并将其转换为张量表示用于训练。

n_train = train_data.shape[0]
train_features = tf.constant(all_features[:n_train].values, dtype=tf.float32)
test_features = tf.constant(all_features[n_train:].values, dtype=tf.float32)
train_labels = tf.constant(train_data.SalePrice.values.reshape(-1, 1), dtype=tf.float32)

训练数据的特征内容如下,不包括标签列:

all_features[:n_train]
	MSSubClass	LotFrontage	LotArea	OverallQual	OverallCond	YearBuilt	YearRemodAdd	MasVnrArea	BsmtFinSF1	BsmtFinSF2	...	SaleType_Oth	SaleType_WD	SaleType_nan	SaleCondition_Abnorml	SaleCondition_AdjLand	SaleCondition_Alloca	SaleCondition_Family	SaleCondition_Normal	SaleCondition_Partial	SaleCondition_nan
0	0.067320	-0.184443	-0.217841	0.646073	-0.507197	1.046078	0.896679	0.523038	0.580708	-0.293030	...	0	1	0	0	0	0	0	1	0	0
1	-0.873466	0.458096	-0.072032	-0.063174	2.187904	0.154737	-0.395536	-0.569893	1.177709	-0.293030	...	0	1	0	0	0	0	0	1	0	0
2	0.067320	-0.055935	0.137173	0.646073	-0.507197	0.980053	0.848819	0.333448	0.097840	-0.293030	...	0	1	0	0	0	0	0	1	0	0
3	0.302516	-0.398622	-0.078371	0.646073	-0.507197	-1.859033	-0.682695	-0.569893	-0.494771	-0.293030	...	0	1	0	1	0	0	0	0	0	0
4	0.067320	0.629439	0.518814	1.355319	-0.507197	0.947040	0.753100	1.381770	0.468770	-0.293030	...	0	1	0	0	0	0	0	1	0	0
...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...
1455	0.067320	-0.312950	-0.285421	-0.063174	-0.507197	0.914028	0.753100	-0.569893	-0.968860	-0.293030	...	0	1	0	0	0	0	0	1	0	0
1456	-0.873466	0.672275	0.381246	-0.063174	0.391170	0.220763	0.178782	0.093673	0.765076	0.670295	...	0	1	0	0	0	0	0	1	0	0
1457	0.302516	-0.141607	-0.142781	0.646073	3.086271	-1.000704	1.040259	-0.569893	-0.365275	-0.293030	...	0	1	0	0	0	0	0	1	0	0
1458	-0.873466	-0.055935	-0.057197	-0.772420	0.391170	-0.703591	0.561660	-0.569893	-0.861312	5.788329	...	0	1	0	0	0	0	0	1	0	0
1459	-0.873466	0.243916	-0.029303	-0.772420	0.391170	-0.208401	-0.921995	-0.569893	0.852870	1.420862	...	0	1	0	0	0	0	0	1	0	0
1460 rows × 331 columns

测试集数据的特征如下,自然也是没有标签列:

MSSubClass	LotFrontage	LotArea	OverallQual	OverallCond	YearBuilt	YearRemodAdd	MasVnrArea	BsmtFinSF1	BsmtFinSF2	...	SaleType_Oth	SaleType_WD	SaleType_nan	SaleCondition_Abnorml	SaleCondition_AdjLand	SaleCondition_Alloca	SaleCondition_Family	SaleCondition_Normal	SaleCondition_Partial	SaleCondition_nan
0	-0.873466	0.458096	0.184340	-0.772420	0.391170	-0.340452	-1.113434	-0.569893	0.058332	0.558006	...	0	1	0	0	0	0	0	1	0	0
1	-0.873466	0.500932	0.519702	-0.063174	0.391170	-0.439490	-1.257014	0.032335	1.056991	-0.293030	...	0	1	0	0	0	0	0	1	0	0
2	0.067320	0.201080	0.464294	-0.772420	-0.507197	0.848003	0.657380	-0.569893	0.767271	-0.293030	...	0	1	0	0	0	0	0	1	0	0
3	0.067320	0.372424	-0.024105	-0.063174	0.391170	0.881015	0.657380	-0.458369	0.352443	-0.293030	...	0	1	0	0	0	0	0	1	0	0
4	1.478499	-1.126832	-0.654636	1.355319	-0.507197	0.682939	0.370221	-0.569893	-0.391613	-0.293030	...	0	1	0	0	0	0	0	1	0	0
...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...	...
1454	2.419286	-2.069222	-1.043758	-1.481667	1.289537	-0.043338	-0.682695	-0.569893	-0.968860	-0.293030	...	0	1	0	0	0	0	0	1	0	0
1455	2.419286	-2.069222	-1.049083	-1.481667	-0.507197	-0.043338	-0.682695	-0.569893	-0.415757	-0.293030	...	0	1	0	1	0	0	0	0	0	0
1456	-0.873466	3.884968	1.246594	-0.772420	1.289537	-0.373465	0.561660	-0.569893	1.717643	-0.293030	...	0	1	0	1	0	0	0	0	0	0
1457	0.655311	-0.312950	0.034599	-0.772420	-0.507197	0.682939	0.370221	-0.569893	-0.229194	-0.293030	...	0	1	0	0	0	0	0	1	0	0
1458	0.067320	0.201080	-0.068608	0.646073	-0.507197	0.715952	0.465941	-0.045732	0.694840	-0.293030	...	0	1	0	0	0	0	0	1	0	0
1459 rows × 331 columns

训练数据的标签:

train_data.SalePrice
0       208500
1       181500
2       223500
3       140000
4       250000
         ...  
1455    175000
1456    210000
1457    266500
1458    142125
1459    147500
Name: SalePrice, Length: 1460, dtype: int64

三、训练模型

训练一个带有损失平方的线性模型。

loss = tf.keras.losses.MeanSquaredError()

def get_net():
    net = tf.keras.models.Sequential()
    net.add(tf.keras.layers.Dense(
        1, kernel_regularizer=tf.keras.regularizers.l2(weight_decay)))
    return net

房价就像股票价格一样,关心的是相对数量,而不是绝对数量。 因此,更关心相对误差(𝑦−𝑦̂)/ 𝑦,而不是绝对误差𝑦−𝑦̂ 。
解决这个问题的一种方法是用价格预测的对数来衡量差异, 事实上,这也是比赛中官方用来评价提交质量的误差指标。 即将 𝛿 for |log𝑦−log𝑦̂ |≤𝛿 转换为 𝑒(−𝛿)≤𝑦̂/𝑦≤𝑒(𝛿) 。 这使得预测价格的对数与真实标签价格的对数之间出现以下均方根误差:
在这里插入图片描述

def log_rmse(y_true, y_pred):
    # 为了在取对数时进一步稳定该值,将小于1的值设置为1
    clipped_preds = tf.clip_by_value(y_pred, 1, float('inf'))
    return tf.sqrt(tf.reduce_mean(loss(
        tf.math.log(y_true), tf.math.log(clipped_preds))))

训练函数将借助Adam优化器,Adam优化器的主要吸引力在于它对初始学习率不那么敏感。

def train(net, train_features, train_labels, test_features, test_labels,
          num_epochs, learning_rate, weight_decay, batch_size):
    train_ls, test_ls = [], []
    train_iter = d2l.load_array((train_features, train_labels), batch_size)
    # 这里使用的是Adam优化算法
    optimizer = tf.keras.optimizers.Adam(learning_rate)
    net.compile(loss=loss, optimizer=optimizer)
    for epoch in range(num_epochs):
        for X, y in train_iter:
            with tf.GradientTape() as tape:
                y_hat = net(X)
                l = loss(y, y_hat)
            params = net.trainable_variables
            grads = tape.gradient(l, params)
            optimizer.apply_gradients(zip(grads, params))
        train_ls.append(log_rmse(train_labels, net(train_features)))
        if test_labels is not None:
            test_ls.append(log_rmse(test_labels, net(test_features)))
    return train_ls, test_ls

四、K折交叉验证

K折交叉验证, 它有助于模型选择和超参数调整。首先需要定义一个函数,在 𝐾折交叉验证过程中返回第 𝑖 折的数据。具体地说,它选择第 𝑖 个切片作为验证数据,其余部分作为训练数据。注意,这并不是处理数据的最有效方法,如果数据集大得多,会有其他解决办法。

def get_k_fold_data(k, i, X, y):
    assert k > 1
    fold_size = X.shape[0] // k
    X_train, y_train = None, None
    for j in range(k):
        idx = slice(j * fold_size, (j + 1) * fold_size)
        X_part, y_part = X[idx, :], y[idx]
        if j == i:
            X_valid, y_valid = X_part, y_part
        elif X_train is None:
            X_train, y_train = X_part, y_part
        else:
            X_train = tf.concat([X_train, X_part], 0)
            y_train = tf.concat([y_train, y_part], 0)
    return X_train, y_train, X_valid, y_valid

在 𝐾 折交叉验证中训练 𝐾次后,返回训练和验证误差的平均值。

def k_fold(k, X_train, y_train, num_epochs, learning_rate, weight_decay,
           batch_size):
    train_l_sum, valid_l_sum = 0, 0
    for i in range(k):
        data = get_k_fold_data(k, i, X_train, y_train)
        net = get_net()
        train_ls, valid_ls = train(net, *data, num_epochs, learning_rate,
                                   weight_decay, batch_size)
        train_l_sum += train_ls[-1]
        valid_l_sum += valid_ls[-1]
        if i == 0:
            d2l.plot(list(range(1, num_epochs + 1)), [train_ls, valid_ls],
                     xlabel='epoch', ylabel='rmse', xlim=[1, num_epochs],
                     legend=['train', 'valid'], yscale='log')
        print(f'折{i + 1},训练log rmse{float(train_ls[-1]):f}, '
              f'验证log rmse{float(valid_ls[-1]):f}')
    return train_l_sum / k, valid_l_sum / k

五、模型选择

选择了一组未调优的超参数,并将其留给读者来改进模型。 找到一组调优的超参数可能需要时间,这取决于一个人优化了多少变量。

有了足够大的数据集和合理设置的超参数, 𝐾 折交叉验证往往对多次测试具有相当的稳定性。 然而,如果尝试了不合理的超参数,可能会发现验证效果不再代表真正的误差。

k, num_epochs, lr, weight_decay, batch_size = 5, 100, 5, 0, 64
train_l, valid_l = k_fold(k, train_features, train_labels, num_epochs, lr,
                          weight_decay, batch_size)
print(f'{k}-折验证: 平均训练log rmse: {float(train_l):f}, '
      f'平均验证log rmse: {float(valid_l):f}')
1,训练log rmse0.170193, 验证log rmse0.1573992,训练log rmse0.162299, 验证log rmse0.1895973,训练log rmse0.164040, 验证log rmse0.1678444,训练log rmse0.168239, 验证log rmse0.1548915,训练log rmse0.163944, 验证log rmse0.182911
5-折验证: 平均训练log rmse: 0.165743, 平均验证log rmse: 0.170528

在这里插入图片描述

六、保存预测数据

既然知道应该选择什么样的超参数, 不妨使用所有数据对其进行训练 (而不是仅使用交叉验证中使用的 1−1/𝐾的数据)。 然后,通过这种方式获得的模型可以应用于测试集。 将预测保存在CSV文件中。

def train_and_pred(train_features, test_features, train_labels, test_data,
                   num_epochs, lr, weight_decay, batch_size):
    net = get_net()
    #训练网络
    train_ls, _ = train(net, train_features, train_labels, None, None,
                        num_epochs, lr, weight_decay, batch_size)
    # 可视化损失变化
    d2l.plot(np.arange(1, num_epochs + 1), [train_ls], xlabel='epoch',
             ylabel='log rmse', xlim=[1, num_epochs], yscale='log')
    print(f'训练log rmse:{float(train_ls[-1]):f}')
    
    # 将网络应用于测试集。
    preds = net(test_features).numpy()
    # 将其重新格式化以导出到Kaggle
    test_data['SalePrice'] = pd.Series(preds.reshape(1, -1)[0])
    submission = pd.concat([test_data['Id'], test_data['SalePrice']], axis=1)
    submission.to_csv('submission.csv', index=False)
train_and_pred(train_features, test_features, train_labels, test_data,
               num_epochs, lr, weight_decay, batch_size)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1482254.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Maven【4】(继承)(命令行操作)

文章目录 【1】基础概念【2】继承的作用【3】创建父工程和子工程【4】在父工程中统一管理依赖 【1】基础概念 说到继承,我们很容易想到Java中的继承,有子类和父类,子类继承父类,那么我们maven中的继承是什么呢? Maven…

二分查找常用解题模板(带一道leetcode题目)

1.为了较为清晰的写出各种情况,接下来的代码中不会出现else,而是将每一个else if均给写出来!!! 2.为了防止每次的mid溢出,我们均写为mid left (right - left) 基本的二分查找模板(寻找一个数) 基本问题描述&#xff…

计算机网络物理层知识点总结

本篇博客是基于谢希仁编写的《计算机网络》和王道考研视频总结出来的知识点,本篇总结的主要知识点是第二章的物理层。上一章的传送门:计算机网络体系结构-CSDN博客 通信基础 物理层概念 物理层解决如何在连接各种计算机的传输媒体上传输数据比特流&am…

服务器上部署WEb服务方法

部署Web服务在服务器上是一个比较复杂的过程。这不仅仅涉及到配置环境、选择软件和设置端口,更有众多其它因素需要考虑。以下是在服务器上部署WEb服务的步骤: 1. 选择服务器:根据项目规模和预期访问量,选择合适的服务器类型和配置…

MySQL:函数

提醒: 设定下面的语句是在数据库名为 db_book里执行的。 创建user_info表 注意:pwd为密码字段,这里使用了VARCHAR(128)类型,为了后面方便对比,开发项目里一般使用char(32),SQL语句里只有MD5加密函数 USE db…

iOS卡顿原因与优化

iOS卡顿原因与优化 1. 卡顿简介 卡顿: 指用户在使用过程中出现了一段时间的阻塞,使得用户在这一段时间内无法进行操作,屏幕上的内容也没有任何的变化。 卡顿作为App的重要性能指标,不仅影响着用户体验,更关系到用户留…

XUbuntu22.04之解决:仓库xxx没有数字签名问题(二百一十七)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏:多媒…

buuctf_misc_荷兰宽带数据泄露+被偷走的文件

荷兰宽带数据泄露 题目: 没啥,工具给大家放这了,这个(相对来说)比较安全 https://routerpassview.en.lo4d.com/windows 打开后,.bin文件直接托进去 只是我想不到的是,flag这算是username&…

H.266参考软件VTM各版本的性能差异

VTM(VVC Test Model),是H.266视频编码标准的参考软件,即是VVC spec.的一种参考实现,代码里包括了H.266的软件编码器和软件解码器实现,代码地址如下: https://vcgit.hhi.fraunhofer.de/jvet/VVCS…

基于单片机的节能窗控制系统设计

摘 要:本文以单片机为基础,对节能窗控制系统进行了科学设计,在满足日常生活需求的同时更好地实现节能减排目标。此设计中的节能窗控制系统,实际操作要灵活,具备可靠且稳定的性能,同时具备节能功效。 关键词:单片机;节能窗控制系统;系统设计 在节能窗等概念推广的背景…

前端学习第二天-html提升

达标要求 了解列表的分类 熟练掌握列表的用法 熟练掌握表格的结构构成 合并单元格 表单的组成 熟练掌握表单控件分类的使用 1.列表 1.1 无序列表 <ul>&#xff1a;定义无序列表&#xff0c;并且只能包含<li>子元素。 <li>&#xff1a;定义列表项&a…

hippy 调试demo运行联调-mac环境准备篇

适用对于终端编译环境不熟悉的人看&#xff0c;仅mac端 hippy 调试文档官网地址 前提&#xff1a;请使用node16 联调预览效果图&#xff1a; 编译iOS Demo环境准备 未跑通&#xff0c;待补充 编译Android Demo环境准备 1、正常安装Android Studio 2、下载Android NDK&a…

10-Java装饰器模式 ( Decorator Pattern )

Java装饰器模式 摘要实现范例 装饰器模式&#xff08;Decorator Pattern&#xff09;允许向一个现有的对象添加新的功能&#xff0c;同时又不改变其结构 装饰器模式创建了一个装饰类&#xff0c;用来包装原有的类&#xff0c;并在保持类方法签名完整性的前提下&#xff0c;提供…

JProfiler 14 for Mac 14.0激活版:Java性能分析的终极工具

JProfiler是一款专业的Java应用程序性能分析工具&#xff0c;可帮助开发人员识别和解决Java应用程序中的性能问题。JProfiler支持Java SE、Java EE和Android平台&#xff0c;提供了多种分析选项&#xff0c;包括CPU分析、内存分析和线程分析等。 软件下载&#xff1a;JProfiler…

本地快速部署谷歌开放模型Gemma教程(基于WasmEdge)

本地快速部署谷歌开放模型Gemma教程&#xff08;基于WasmEdge&#xff09; 一、介绍 Gemma二、部署 Gemma2.1 部署工具2.1 部署步骤 三、构建超轻量级 AI 代理四、总结 一、介绍 Gemma Gemma是一系列轻量级、最先进的开放式模型&#xff0c;采用与创建Gemini模型相同的研究和技…

利用IP地址识别风险用户:保护网络安全的重要手段

随着互联网的发展和普及&#xff0c;网络安全问题日益突出&#xff0c;各种网络诈骗、恶意攻击等风险不断涌现&#xff0c;给个人和企业的财产安全和信息安全带来了严重威胁。在这样的背景下&#xff0c;利用IP地址识别风险用户成为了保护网络安全的重要手段之一。IP数据云探讨…

太阳能供电井盖-物联网智能井盖监测系统-旭华智能

在这个日新月异的科技时代&#xff0c;城市的每一个角落都在悄然发生变化。而在这场城市升级的浪潮中&#xff0c;智能井盖以其前瞻性的科技应用和卓越的安全性能&#xff0c;正悄然崭露头角&#xff0c;变身马路上的智能“眼睛”&#xff0c;守护城市安全。 传统的井盖监测系统…

Facebook直播网络需要满足什么条件

Facebook直播已经成为了企业、个人和组织开展在线活动、互动和营销的重要平台之一。然而&#xff0c;要确保Facebook直播的顺利进行和观众体验的良好&#xff0c;需要满足一系列关键条件。本文将探讨Facebook直播网络 需要满足的关键条件。 1、稳定的互联网连接&#xff1a; 稳…

【airtest】自动化入门教程(二)airtest操作

目录 一、touch 二、wait 三、swipe 四、exists 五、text 六、keyevent 七、snapshot 八、sleep 九、断言 9.1 assert_exists 9.2 assert_not_exists 9.3 assert_equal 9.4 assert_not_equal 前言&#xff1a;本文主要针对aritest部分的基础操作,aritest是一个跨平…

加密与安全_探索口令加密算法(PBE)

文章目录 概述疑问PBE 算法 &#xff08; Password Based Encryption&#xff09;CodePOM实现 小结 概述 加密与安全_探索对称加密算法中我们提到AES加密密钥长度是固定的128/192/256位&#xff0c;而不是我们用WinZip/WinRAR那样&#xff0c;随便输入几位都可以。 这是因为对…