深度学习-房价预测案例

news2024/11/20 13:24:18

1. 实现几个函数方便下载数据

import hashlib
import os
import tarfile
import zipfile
import requests
 
#@save
DATA_HUB = dict()
DATA_URL = 'http://d2l-data.s3-accelerate.amazonaws.com/'
 
def download(name, cache_dir=os.path.join('..', 'data')):  #@save
    """下载一个DATA_HUB中的文件,返回本地文件名"""
    assert name in DATA_HUB, f"{name} 不存在于 {DATA_HUB}"# 判断变量name是否存在于DATA_HUB中,不在则抛出异常
    url, sha1_hash = DATA_HUB[name]
    os.makedirs(cache_dir, exist_ok=True)# cache_dir目录不存在,则创建该目录,如果目录已经存在,则什么都不做
    fname = os.path.join(cache_dir, url.split('/')[-1])# 拼接成一个完整的路径
    if os.path.exists(fname): # 路径存在
        sha1 = hashlib.sha1() # 创建了一个哈希对象
        with open(fname, 'rb') as f:
            while True:
                data = f.read(1048576)
                if not data:
                    break
                sha1.update(data)
        if sha1.hexdigest() == sha1_hash:
            return fname  # 命中缓存
    print(f'正在从{url}下载{fname}...')
    r = requests.get(url, stream=True, verify=True)
    with open(fname, 'wb') as f:
        f.write(r.content)
    return fname
 
 
def download_extract(name, folder=None):  #@save
    """下载并解压zip/tar文件"""
    fname = download(name)
    base_dir = os.path.dirname(fname)
    data_dir, ext = os.path.splitext(fname)
    if ext == '.zip':
        fp = zipfile.ZipFile(fname, 'r')
    elif ext in ('.tar', '.gz'):
        fp = tarfile.open(fname, 'r')
    else:
        assert False, '只有zip/tar文件可以被解压缩'
    fp.extractall(base_dir)
    return os.path.join(base_dir, folder) if folder else data_dir
 
def download_all():  #@save
    """下载DATA_HUB中的所有文件"""
    for name in DATA_HUB:
        download(name)

2. 使用pandas读入并处理数据

%matplotlib inline
import numpy as np
import pandas as pd
import torch
from torch import nn
from d2l import torch as d2l

DATA_HUB['kaggle_house_train'] = (  # 将数据集的名称kaggle_house_train作为字典DATA_HUB的键
    DATA_URL + 'kaggle_house_pred_train.csv', # 数据集的下载链接
    '585e9cc93e70b39160e7921475f9bcd7d31219ce') # 哈希值用于验证数据的完整性

DATA_HUB['kaggle_house_test'] = (  
    DATA_URL + 'kaggle_house_pred_test.csv',
    'fa19780a7b011d9b009e8bff8e99922a8ee2eb90')

# 从指定的数据源下载名为'kaggle_house_train'的CSV文件,
# 并使用pd.read_csv()函数将其读取为一个DataFrame对象,并将该对象赋值
train_data = pd.read_csv(download('kaggle_house_train'))
test_data = pd.read_csv(download('kaggle_house_test'))

print(train_data.shape)
print(test_data.shape)

在这里插入图片描述

观察特征

打印出前4行,前4列和最后3列打印出来
【使用iloc属性对train_data这个DataFrame对象进行切片操作,选取了指定行和列的数据子集】

print(train_data.iloc[0:4, [0, 1, 2, 3, -3, -2, -1]])

在这里插入图片描述

在每个样本中,第一个特征ID不能参与训练,所以要将其删除

saleprice作为标签在训练数据中要进行删除

all_features = pd.concat((train_data.iloc[:, 1:-1], test_data.iloc[:, 1:]))# 将train_data去除第一列ID和最后一列标签,和去除id的test_data进行合并

将所有缺失的值替换为相应特征的平均值,通过将特征重新缩放到零均值和单位方差来标准化数据

【.fillna(0)对选择的数值型特征进行了填充操作,将缺失值(NaN值)填充为0。fillna()是一个DataFrame对象的方法,用于填充缺失值】

numeric_features = all_features.dtypes[all_features.dtypes != 'object'].index # all_features.dtypes != 'object'-》数值型数据
"""-》将数值特征均值设为0,方差设为1"""
all_features[numeric_features] = all_features[numeric_features].apply(
    lambda x: (x - x.mean()) / (x.std())) # 将(数值特征 - 均值)/方差 

all_features[numeric_features] = all_features[numeric_features].fillna(0) # 对选择的数值型特征进行了填充操作,将缺失值(NaN值)填充为0

处理离散值,用一次独热编码替换它们

all_features = pd.get_dummies(all_features, dummy_na=True)
all_features.shape

在这里插入图片描述

从pandas格式中提取NumPy格式,并将其转化为张量表示

【.values将该列数据转换为一个Numpy数组。

.reshape(-1, 1)改变数组的形状,将其变为一个列向量(具有一列)。】

n_train = train_data.shape[0]
all_features = all_features.astype(float) # 进行强制类型转化否则会报错
train_features = torch.tensor(all_features[:n_train].values, # 之前将train_data和 test_data结合,现在进行下标分开
                              dtype=torch.float32)
test_features = torch.tensor(all_features[n_train:].values, 
                             dtype=torch.float32)
train_labels = torch.tensor(train_data.SalePrice.values.reshape(-1, 1),# SalePrice列数据提取出来,并将其转换为一个列向量(具有一列)
                            dtype=torch.float32)

训练

loss = nn.MSELoss()
in_features = train_features.shape[1]

def get_net():
    net = nn.Sequential(nn.Linear(in_features, 1)) # 使用单层线性回归,输入特征数:in_features,输出特征数:1
    return net

为解决误差的影响,可以使用相对误差 (真实房价-预测房价/真实房价),其中一种方法是用价格预测的对数来衡量差异

【torch.clamp()函数会将输出结果中小于下界的值替换为下界,将大于上界的值替换为上界,因此它可以用来对输出结果进行范围限制】

def log_rmse(net, features, labels): # log可以将除法转化为减法
    clipped_preds = torch.clamp(net(features), 1, float('inf'))# 对输出进行截断,将小于1的值设置为1,大于float('inf')的值保持不变
    rmse = torch.sqrt(loss(torch.log(clipped_preds), torch.log(labels))) # 对预测和实际标签进行log,然后传入损失函数后取根号
    return rmse.item()# 返回 张量rmse中的值提取为一个标量

训练函数将借助Adam优化器

def train(net, train_features, train_labels, test_features, test_labels,
          num_epochs, learning_rate, weight_decay, batch_size):
    train_ls, test_ls = [], []
    train_iter = d2l.load_array((train_features, train_labels), batch_size)
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate,# 使用Adam【对学习率不太敏感】进行优化
                                 weight_decay=weight_decay)  # 权重衰减(weight decay)参数【lamdb】,用于控制模型参数的正则化
    """训练"""
    for epoch in range(num_epochs):
        for X, y in train_iter:
            optimizer.zero_grad() # 优化器梯度清0
            l = loss(net(X), y) # 计算损失
            l.backward() # 反向传播计算梯度
            optimizer.step() # 更新优化器参数
        train_ls.append(log_rmse(net, train_features, train_labels)) # 更新数据
        if test_labels is not None:
            test_ls.append(log_rmse(net, test_features, test_labels))
    return train_ls, test_ls

K则交叉验证

def get_k_fold_data(k, i, X, y):
    assert k > 1
    fold_size = X.shape[0] // k # 每一折的大小是样本数/k
    X_train, y_train = None, None
    for j in range(k):
        idx = slice(j * fold_size, (j + 1) * fold_size) # 计算每个切片的起始和终止位置,根据切片索引idx取出相应位置上的数。
        X_part, y_part = X[idx, :], y[idx] # 取出相应位置
        if j == i: # 如果此时j==i,当前迭代的fold为验证集,则将切片X_part和y_part赋值给X_valid和y_valid。
            X_valid, y_valid = X_part, y_part
            
        elif X_train is None: # 如果训练集为空,则将切片X_part和y_part赋值给X_train和y_train
            X_train, y_train = X_part, y_part
            
        else: # 否则,将切片X_part和y_part与之前的训练集进行拼接,使用torch.cat()函数进行行拼接,将结果重新赋值给X_train和y_train。
            X_train = torch.cat([X_train, X_part], 0)
            y_train = torch.cat([y_train, y_part], 0)
    # 返回训练集和验证集
    return X_train, y_train, X_valid, y_valid

返回训练和验证误差的平均值

def k_fold(k, X_train, y_train, num_epochs, learning_rate, weight_decay,
           batch_size):
    train_l_sum, valid_l_sum = 0, 0
    for i in range(k): # 做k次交叉验证
        data = get_k_fold_data(k, i, X_train, y_train)
        net = get_net()
        train_ls, valid_ls = train(net, *data, num_epochs, learning_rate,
                                   weight_decay, batch_size)
        train_l_sum += train_ls[-1]
        valid_l_sum += valid_ls[-1]
        if i == 0:
            d2l.plot(list(range(1, num_epochs + 1)), [train_ls, valid_ls],
                     xlabel='epoch', ylabel='rmse', xlim=[1, num_epochs],
                     legend=['train', 'valid'], yscale='log')
        print(f'fold {i + 1}, train log rmse {float(train_ls[-1]):f}, '
              f'valid log rmse {float(valid_ls[-1]):f}')
    return train_l_sum / k, valid_l_sum / k # 返回平均测试集和验证集的损失

模型选择

k, num_epochs, lr, weight_decay, batch_size = 5, 100, 5, 0, 64
train_l, valid_l = k_fold(k, train_features, train_labels, num_epochs, lr,
                          weight_decay, batch_size)
print(f'{k}-折验证: 平均训练log rmse: {float(train_l):f}, '
      f'平均验证log rmse: {float(valid_l):f}')

在这里插入图片描述
需要关注valid验证集的损失,需要不断的调整参数实现最小的损失

提交Kaggle预测

def train_and_pred(train_features, test_feature, train_labels, test_data,
                   num_epochs, lr, weight_decay, batch_size):
    net = get_net()
    train_ls, _ = train(net, train_features, train_labels, None, None,
                        num_epochs, lr, weight_decay, batch_size) # 返回训练过程中的训练误差列表train_ls和验证误差列表valid_ls,但在这个函数调用中用下划线 _ 代替了后者
    # 绘制并显示训练误差的变化情况
    d2l.plot(np.arange(1, num_epochs + 1), [train_ls], xlabel='epoch',
             ylabel='log rmse', xlim=[1, num_epochs], yscale='log')
    print(f'train log rmse {float(train_ls[-1]):f}')
    # 使用训练好的模型net对测试特征进行预测,得到预测结果preds
    preds = net(test_features).detach().numpy()
    # 预测结果转换为Numpy数组,并将其赋值给测试数据集test_data的'SalePrice'列。
    test_data['SalePrice'] = pd.Series(preds.reshape(1, -1)[0])
    # 将预测结果和对应的'Id'列组合成一个DataFrame submission
    submission = pd.concat([test_data['Id'], test_data['SalePrice']], axis=1)
    # 将submission保存为CSV文件submission.csv
    submission.to_csv('submission.csv', index=False)
# 调用了train_and_pred()函数,传入相应的参数,执行整个训练和预测的过程
train_and_pred(train_features, test_features, train_labels, test_data,
               num_epochs, lr, weight_decay, batch_size)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1088536.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

独立站活动怎么复盘,做独立站需要掌握哪些?-站斧浏览器

独立站的活动形式多种多样,可以通过推出抽奖活动、举办线下活动或者利用社交媒体平台来增加用户互动和参与度。但是要做好一个独立站,除了活动形式,还需要掌握设计能力、编程技术、SEO知识和内容创作能力。 独立站活动怎么复盘? …

学习编程-先改变心态

编程失败的天才 林一和我很久以前就认识了——我从五年级就认识他了。他是班上最聪明的孩子。如果每个人在家庭作业或考试准备方面需要帮助,他们都会去那里。 有趣的是,林一不是那种连续学习几个小时的孩子。 他的聪明才智似乎与生俱来,几乎毫…

Selenium八大定位策略实战,你会了么?

Selenium是一款非常强大的自动化测试工具,支持多种编程语言,如Java、Python等。在使用Selenium进行自动化测试时,定位元素是非常重要的一步,只有正确定位到元素才能进行后续的操作,如输入数据、点击按钮等。在Selenium…

HarmonyOS/OpenHarmony原生应用-ArkTS万能卡片组件Radio

单选框,提供相应的用户交互选择项。该组件从API Version 8开始支持。无子组件。 一、接口 Radio(options: {value: string, group: string}) 从API version 9开始,该接口支持在ArkTS卡片中使用。 参数: 二、属性 除支持通用属性外,还支持以…

springBoot组件注册

springBoot组件注册 前言1、创建组件文件2、写属性3、生成get和set方法4、以前注册的方法5、现在注册的方法6、在启动文件查看7、多实例Scope("prototype")8、注册第三方包导入对应的场景启动器注册组件查看是否存在也可以通过Import(FastsqlException.class)导入但是…

文字雨特效

效果展示 CSS 知识点 简易实现云朵技巧text-shadow 属性的灵活运用filter 属性实现元素自动变色 实现页面布局 <div class"container"><div class"cloud"><h2>Data Clouds Rain</h2></div> </div>实现云朵 实现云…

什么是API网关?——驱动数字化转型的“隐形冠军”

什么是API网关 API网关&#xff08;API Gateway&#xff09;是一个服务器&#xff0c;位于应用程序和后端服务之间&#xff0c;提供了一种集中式的方式来管理API的访问。它是系统的入口点&#xff0c;负责接收并处理来自客户端的请求&#xff0c;然后将请求路由到相应的后端服…

代码随想录第14天 | ● 300.最长递增子序列 ● 674. 最长连续递增序列 ● 718. 最长重复子数组

300.最长递增子序列 /*** param {number[]} nums* return {number}*/ var lengthOfLIS function(nums) {let dp Array(nums.length).fill(1);let result 1;for(let i 1; i < nums.length; i) {for(let j 0; j < i; j) {if(nums[i] > nums[j]) {dp[i] Math.max…

最新开源ThinkPHP6框架云梦卡社区系统源码/亲测可用(全新开发)

源码简介&#xff1a; 最新开源ThinkPHP6云梦卡社区系统源码&#xff0c;它是一款基于ThinkPHP 6框架开发的开源社区系统源码。该系统源码具有强大而稳定的后端架构&#xff0c;和简洁易操作的前端界面&#xff0c;能够给人们提供完整的社区功能和更具体的服务。 全新云梦卡社…

fiddler如何抓模拟器中APP的包

第一步&#xff1a;fiddler配置 1、打开fiddler&#xff0c;依次点击工具&#xff08;tools&#xff09;》选项&#xff08;options&#xff09; 2、进入HTTPS选项&#xff0c;先选中DecryptHTTPStraffic&#xff0c;再选中ignore server certificate errors (unsafe) 3、点击…

【战略合作】新的游戏合作伙伴来袭,CARV 助力 Aavegotchi 发展!

想象这样的一个世界&#xff0c;你的游戏成就不仅仅是徽章&#xff0c;而是你链上声誉的一部分&#xff01;我们的最新游戏合作伙伴 CARV 便遵循这样的愿景。CARV 与 Aavegotchi 达成合作&#xff0c;旨在将下一代游戏玩家引入 Web3 世界。 CARV 正在构建一个以游戏为核心的身…

【云计算】相关解决方案介绍

文章目录 1.1 云服务环境 Eucalyptus1.1.1 介绍1.1.2 开源协议及语言1.1.3 官方网站 1.2 开源云计算平台 abiCloud1.2.1 开源协议及语言1.2.2 官方网站 1.3 分布式文件系统 Hadoop1.3.1 开源协议及语言1.3.2 官方网站 1.4 JBoss云计算项目集 StormGrind1.4.1 开源协议及语言1.4…

C# 图像灰化处理方法及速度对比

图像处理过程中&#xff0c;比较常见的灰化处理&#xff0c;将彩色图像处理为黑白图像&#xff0c;以便后续的其他处理工作。 在面对大量的图片或者像素尺寸比较大的图片的时候&#xff0c;处理速度和性能就显得非常重要&#xff0c;下面分别用3种方式来处理图像数据&#xff0…

紫光同创FPGA实现UDP协议栈网络视频传输,基于YT8511和RTL8211,提供4套PDS工程源码和技术支持

目录 1、前言免责声明 2、相关方案推荐我这里已有的以太网方案紫光同创FPGA精简版UDP方案紫光同创FPGA带ping功能UDP方案 3、设计思路框架OV7725摄像头配置及采集OV5640摄像头配置及采集UDP发送控制视频数据组包数据缓冲FIFOUDP协议栈详解RGMII转GMII动态ARPUDP协议IP地址、端口…

gradle版本是7.1.3加载arr包踩坑

第一次尝试&#xff1a; 将arr包放入到libs中&#xff0c; 在build.gradle中添加 implementation(name:**, ext:aar) Make project报错&#xff1a; Could not find :jdsmart-common-b3593f1-1.2.04:. Required by:project :launcherserver Search in build.gradle files根据…

工业网关都是什么?具体怎么应用?

随着工业自动化的不断发展&#xff0c;各种协议和标准在行业中变得越来越重要。其中&#xff0c;工业网关是实现不同设备之间通信和数据传输的关键设备。本文将以HiWoo Box为例&#xff0c;介绍工业网关的概念、应用场景以及具体的应用方式。 一、工业网关的概念 工业网关是一…

【使用教程】在Ubuntu下PMM60系列一体化伺服电机通过SDO跑循环同步位置模式详解

本教程将指导您在Ubuntu操作系统下使用SDO&#xff08;Service Data Object&#xff09;来配置和控制PMM60系列一体化伺服电机以实现循环同步位置模式。我们将介绍必要的步骤和命令&#xff0c;以确保您能够成功地配置和控制PMM系列一体化伺服电机。 01.准备工作 在正式介绍之…

3分钟彻底搞懂Web UI自动化测试之【POM设计模式】

为什么要用POM设计模式 前期&#xff0c;我们学会了使用PythonSelenium编写Web UI自动化测试线性脚本 线性脚本&#xff08;以快递100网站登录举例&#xff09;&#xff1a; import timefrom selenium import webdriver from selenium.webdriver.common.by import Bydriver …

低代码:避免重复造轮子的高效工具

一、前言 在软件开发和其他工程领域&#xff0c;“重复造轮子”被广泛认为是一种低效的做法&#xff0c;因为它浪费了大量的时间和资源去重新创作已经存在的东西&#xff0c;而不是利用现有的技术和经验去解决问题。 因此&#xff0c;为了避免“重复造轮子”&#xff0c;开发人…

数据库安全-RedisHadoopMysql未授权访问RCE

目录 数据库安全-&Redis&Hadoop&Mysql&未授权访问&RCE定义漏洞复现Mysql-CVE-2012-2122 漏洞Hadoop-配置不当未授权三重奏&RCE 漏洞 Redis-未授权访问-Webshell&任务&密匙&RCE 等漏洞定义&#xff1a;漏洞成因漏洞危害漏洞复现Redis-未授权…