【深度学习实战】利用Linear Regression预测房价

news2024/11/27 5:39:34

本文参考了李沐老师的b站深度学习课程 课程链接,使用了线性回归模型,特别适合深度学习初学者。通过阅读本文,你将学会如何用PyTorch训练模型,并掌握一些实用的训练技巧。希望这些内容能对你的深度学习学习有所帮助。

安装pytorch

在命令行输入下面这段指令

pip install pytorch torchvision -i https://pypi.tuna.tsinghua.edu.cn/simple

导入数据集

数据集来自kaggle一个比赛(比赛链接),可以通过链接自己下载数据集:

训练数据集 http://d2l-data.s3-accelerate.amazonaws.com/kaggle_house_pred_train.csv

测试数据集 http://d2l-data.s3-accelerate.amazonaws.com/kaggle_house_pred_test.csv

也可以通过一下这段代码下载

                                                                                    
import hashlib                                                                      
import os                                                                           
import tarfile                                                                      
import zipfile                                                                      
import requests                                                                     
                                                                                    
#@save                                                                              
DATA_HUB = dict()                                                                   
DATA_URL = 'http://d2l-data.s3-accelerate.amazonaws.com/'                           
def download(name, cache_dir=os.path.join('..', 'data')):  #@save                   
    """下载一个DATA_HUB中的文件,返回本地文件名"""                                                  
    assert name in DATA_HUB, f"{name} 不存在于 {DATA_HUB}"                              
    url, sha1_hash = DATA_HUB[name]                                                 
    os.makedirs(cache_dir, exist_ok=True)                                           
    fname = os.path.join(cache_dir, url.split('/')[-1])                             
    if os.path.exists(fname):                                                       
        sha1 = hashlib.sha1()                                                       
        with open(fname, 'rb') as f:                                                
            while True:                                                             
                data = f.read(1048576)                                              
                if not data:                                                        
                    break                                                           
                sha1.update(data)                                                   
        if sha1.hexdigest() == sha1_hash:                                           
            return fname  # 命中缓存                                                    
    print(f'正在从{url}下载{fname}...')                                                  
    r = requests.get(url, stream=True, verify=True)                                 
    with open(fname, 'wb') as f:                                                    
        f.write(r.content)                                                          
    return fname                                                                    
def download_extract(name, folder=None):  #@save                                    
    """下载并解压zip/tar文件"""                                                            
    fname = download(name)                                                          
    base_dir = os.path.dirname(fname)                                               
    data_dir, ext = os.path.splitext(fname)                                         
    if ext == '.zip':                                                               
        fp = zipfile.ZipFile(fname, 'r')                                            
    elif ext in ('.tar', '.gz'):                                                    
        fp = tarfile.open(fname, 'r')                                               
    else:                                                                           
        assert False, '只有zip/tar文件可以被解压缩'                                           
    fp.extractall(base_dir)                                                         
    return os.path.join(base_dir, folder) if folder else data_dir                   
                                                                                    
def download_all():  #@save                                                         
    """下载DATA_HUB中的所有文件"""                                                          
    for name in DATA_HUB:                                                           
        download(name)                                                              

通过pandas处理数据

import numpy as np # linear algebra                                             
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)          

train_data = pd.read_csv('/kaggle/input/kaggle_house_pred_train.csv')     
test_data = pd.read_csv('/kaggle/input/kaggle_house_pred_test.csv')       

train_data是训练集包含47439条数据,40个feature,还有一个是房价(label) 测试集只有也是47439条,40个feature。 

 在开始建模之前,我们需要对数据进行预处理。 首先,我们将所有缺失的值替换为相应特征的平均值。然后,为了将所有特征放在一个共同的尺度上, 我们通过将特征重新缩放到零均值和单位方差来标准化数据:

中μ和σ分别表示均值和标准差。 现在,这些特征具有零均值和单位方差,即  

直观地说,我们标准化数据有两个原因: 首先,它方便优化。 其次,因为我们不知道哪些特征是相关的, 所以我们不想让惩罚分配给一个特征的系数比分配给其他任何特征的系数更大。 

将缺失值替换为0,因为已经在上一步中进行了标准化,0在这个上下文中相当于均值


# train_data.loc[:, train_data.columns != 'Sold Price'] # 这行代码用于提取除'Sold Price'外的其他列
# 合并训练数据和测试数据,排除“Sold Price”列,因为这是我们预测的目标变量
all_features = pd.concat((train_data.loc[:, train_data.columns != 'Sold Price'], test_data.iloc[:, 1:]))

# 查看合并后的数据信息,以了解数据的整体情况
all_features.info()

# 将所有缺失的值替换为相应特征的平均值。通过将特征重新缩放到零均值和单位方差来标准化数据
# 首先,确定哪些特征是数值型的,因为我们将对这些特征进行标准化
numeric_features = all_features.dtypes[all_features.dtypes != 'object'].index

# 对数值型特征进行标准化处理:减去均值并除以标准差
all_features[numeric_features] = all_features[numeric_features].apply(
    lambda x: (x - x.mean()) / (x.std()))

# 将缺失值替换为0,因为已经在上一步中进行了标准化,0在这个上下文中相当于均值
all_features[numeric_features] = all_features[numeric_features].fillna(0)

all_features = all_features[numeric_features[1:]] # 原本第一列是Id,去掉
all_features.info()
<class 'pandas.core.frame.DataFrame'>
Index: 79065 entries, 0 to 31625
Data columns (total 18 columns):
 #   Column                       Non-Null Count  Dtype  
---  ------                       --------------  -----  
 0   Year built                   79065 non-null  float64
 1   Lot                          79065 non-null  float64
 2   Bathrooms                    79065 non-null  float64
 3   Full bathrooms               79065 non-null  float64
 4   Total interior livable area  79065 non-null  float64
 5   Total spaces                 79065 non-null  float64
 6   Garage spaces                79065 non-null  float64
 7   Elementary School Score      79065 non-null  float64
 8   Elementary School Distance   79065 non-null  float64
 9   Middle School Score          79065 non-null  float64
 10  Middle School Distance       79065 non-null  float64
 11  High School Score            79065 non-null  float64
 12  High School Distance         79065 non-null  float64
 13  Tax assessed value           79065 non-null  float64
 14  Annual tax amount            79065 non-null  float64
 15  Listed Price                 79065 non-null  float64
 16  Last Sold Price              79065 non-null  float64
 17  Zip                          79065 non-null  float64
dtypes: float64(18)
memory usage: 11.5 MB

将numpy 转换成tensor


# 从pandas格式中提取NumPy格式,并将其转换为张量表示
n_train = train_data.shape[0]#shape获取行、列数,只取行数————获取训练集行数
train_features = torch.tensor(all_features[:n_train].values,
                              dtype=torch.float32)
test_features = torch.tensor(all_features[n_train:].values,
                             dtype=torch.float32)
train_labels = torch.tensor(train_data['Sold Price'].values.reshape(-1, 1),
                            dtype=torch.float32)

 设计模型


from torch import nn
from torch.utils.data import dataset, DataLoader, TensorDataset
from torch import optim
# 定义一个继承自nn.Module的模型类
class model(nn.Module):
    def __init__(self, in_features):
        super(model, self).__init__()
        ### 定义模型 [b,40] ==> [b,1]
        self.net = nn.Sequential(nn.Linear(in_features, 1))

    def forward(self, x):
        ##只有一层所以不需要激活函数
        return self.net(x)

获取数据集,实例化模型


# 将训练数据和标签封装为数据集
train_datasets = TensorDataset(train_features, train_labels)
# 创建数据加载器,用于迭代加载数据集中的数据
train_data = DataLoader(train_datasets, batch_size=64, shuffle=True)
# 同样的操作应用于测试数据
test_datasets = TensorDataset(test_features)
test_data = DataLoader(test_datasets, batch_size=64, shuffle=True)

# 获取输入特征的数量
in_features = train_features.shape[1]
# 实例化模型
model = model(in_features)
# 定义均方误差损失函数
loss = nn.MSELoss()
# 定义优化器,使用Adam算法
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01)
# 再次获取输入特征的数量,用于下面的函数
in_features = train_features.shape[1]

 定义相关函数

1. 定义获取K折交叉验证数据的函数:它有助于模型选择和超参数调整。 我们首先需要定义一个函数,在K折交叉验证过程中返回第i折的数据(将一个数据集分成k折,k-1折作为训练级 ,1折作为验证集)。 具体地说,它选择第i个切片作为验证数据,其余部分作为训练数据。 注意,这并不是处理数据的最有效方法,如果我们的数据集大得多,会有其他解决办法。

# 定义获取K折交叉验证数据的函数
def get_k_fold_data(k, i, X, y):
    assert k > 1
    fold_size = X.shape[0] // k
    X_train, y_train = None, None
    for j in range(k):
        idx = slice(j * fold_size, (j + 1) * fold_size)
        X_part, y_part = X[idx, :], y[idx]
        if j == i:
            X_valid, y_valid = X_part, y_part
        elif X_train is None:
            X_train, y_train = X_part, y_part
        else:
            X_train = torch.cat([X_train, X_part], 0)
            y_train = torch.cat([y_train, y_part], 0)
    return X_train, y_train, X_valid, y_valid

2.房价就像股票价格一样,我们关心的是相对数量,而不是绝对数量。 因此,我们更关心相对误差y - \hat{y} / y\hat{}, 而不是绝对误差|y - y\hat{}|。 例如,如果我们在估计一栋房子的价格时, 假设我们的预测偏差了10万美元, 然而那里一栋典型的房子的价值是12.5万美元, 那么模型可能做得很糟糕。 另一方面,如果我们在加州豪宅区的预测出现同样的10万美元的偏差, (在那里,房价中位数超过400万美元) 这可能是一个不错的预测。

解决这个问题的一种方法是用价格预测的对数来衡量差异。 事实上,这也是比赛中官方用来评价提交质量的误差指标。 即将δ for |log⁡y−log⁡y^|≤δ 转换为e−δ≤y^y≤eδ。 这使得预测价格的对数与真实标签价格的对数之间出现以下均方根误差:

# 定义对数均方根误差函数
def log_rmse(preds, labels):
    clipped_preds = torch.clamp(preds, 1, float('inf'))
    rmse = torch.sqrt(loss(torch.log(clipped_preds), torch.log(labels)))
    return rmse.item()

训练模型


# 训练模型
for epochs in range(10):
    for batch_id, (x, y) in enumerate(train_data):
        # 获取K折交叉验证数据
        x_train, y_train, x_test, y_test = get_k_fold_data(5, 0, x, y)
        # 前向传播得到预测值
        pred = model(x_train)
        # 对预测值进行裁剪,确保其值在1到正无穷之间
        clipped_preds = torch.clamp(pred, 1, float('inf'))
        # 计算损失
        l = loss(torch.log(clipped_preds), torch.log(y_train))
        # 清零梯度
        optimizer.zero_grad()
        # 反向传播
        l.backward()
        # 更新参数
        optimizer.step()
        # 在测试数据上进行预测
        pred = model(x_test)
        # 打印当前批次的训练情况
        print(f'epoch {epochs + 1}, batch {batch_id}, valid log rmse {log_rmse(pred,y_test):f}')

# 在测试集上进行最终预测
pred = model(test_features)

结语

非常感谢您的阅读!我衷心希望这篇关于使用PyTorch进行线性回归模型训练的博客文章能够对您有所帮助。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2045012.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【书生大模型实战营(暑假场)】基础任务四 XTuner微调个人小助手认知

基础任务四 XTuner微调个人小助手认知 任务文档视频XTuner微调前置基础 文章目录 基础任务四 XTuner微调个人小助手认知0 认识微调0.1 Fine-tune 的两种范式0.2 常见微调技术 1 微调工具 XTuner1.1 认识高效微调框架 XTuner1.2 XTuner 具有出色的优化效果1.3 XTuner 零显存浪费…

【已成功EI检索】第五届新材料与清洁能源国际学术会议(ICAMCE 2024)

重要信息 会议官网&#xff1a;2024.icceam.com 接受/拒稿通知&#xff1a;投稿后1周内 收录检索&#xff1a;EI, Scopus 会议召开视频 见刊封面 EI检索页面 Scopus 检索页面 相关会议 第六届新材料与清洁能源国际学术会议&#xff08;ICAMCE 2025&#xff09; 大会官网&…

【Android】不同系统版本获取设备MAC地址

【Android】不同系统版本获取设备MAC地址 尝试实现 尝试 在开发过程中&#xff0c;想要获取MAC地址&#xff0c;最开始想到的就是WifiManager&#xff0c;但结果始终返回02:00:00:00:00:00&#xff0c;由于用得是wifi &#xff0c;考虑是不是因为用得网线的原因&#xff0c;但…

【海思SS626 | VB】关于 视频缓存池 的理解

&#x1f601;博客主页&#x1f601;&#xff1a;&#x1f680;https://blog.csdn.net/wkd_007&#x1f680; &#x1f911;博客内容&#x1f911;&#xff1a;&#x1f36d;嵌入式开发、Linux、C语言、C、数据结构、音视频&#x1f36d; &#x1f923;本文内容&#x1f923;&a…

【C#】explicit、implicit与operator

字面解释 explicit&#xff1a;清楚明白的;易于理解的;(说话)清晰的&#xff0c;明确的;直言的;坦率的;直截了当的;不隐晦的;不含糊的。 implicit&#xff1a;含蓄的;不直接言明的;成为一部分的;内含的;完全的;无疑问的。 operator&#xff1a;操作人员;技工;电话员;接线员;…

OSL 冠名赞助Web3峰会 “FORESIGHT2024”圆满收官

OSL 望为香港数字资产市场发展建设添砖加瓦 &#xff08;香港&#xff0c;2024 年 8 月 13 日&#xff09;- 8 月 11 日至 12 日&#xff0c; 由 香港唯一专注数字资产的上市公司 OSL 集团&#xff08;863.HK&#xff09;冠名赞助&#xff0c;Foresight News、 Foresight Ventu…

C++ 11相关新特性(lambda表达式与function包装器)

目录 lambda表达式 引入 lambda表达式介绍 lambda表达式捕捉列表的传递形式 lambda表达式的原理 包装器 包装器的基本使用 包装器与重载函数 包装器的使用 绑定 C 11 新特性 lambda表达式 引入 在C 98中&#xff0c;对于sort函数来说&#xff0c;如果需要根据不同的比较方式实现…

Springboot日志监听功能

目录 1. 概述1.1. 需求1.2. 思路 2. 功能实现2.1 依赖选取2.2 编写logBack.xml2.3 日志拦截2.4 封装请求为HttpServletRequestWrapper2.5 AOP2.6 日志监听 3. 后记 1. 概述 1.1. 需求 背景&#xff1a;拆分支付系统的日志&#xff0c;把每笔单子的日志单独拎出来存库。每笔单…

如何将高清图片修复?3个方法一键还原图片

如何将高清图片修复&#xff1f;高清图片修复是一个涉及图像处理技术的复杂过程&#xff0c;是对图片进行简单的调整或优化。这个过程旨在最大程度地恢复和提升图片的清晰度、细节和整体视觉效果&#xff0c;使其更加逼真、生动。通过高清图片的修复&#xff0c;我们可以让老旧…

稀疏注意力:时间序列预测的局部性和Transformer的存储瓶颈

时间序列预测是许多领域的重要问题&#xff0c;包括对太阳能发电厂发电量、电力消耗和交通拥堵情况的预测。在本文中&#xff0c;提出用Transformer来解决这类预测问题。虽然在我们的初步研究中对其性能印象深刻&#xff0c;但发现了它的两个主要缺点:(1)位置不可知性:规范Tran…

C++_2_ inline内联函数 宏函数(2/3)

C推出了inline关键字&#xff0c;其目的是为了替代C语言中的宏函数。 我们先来回顾宏函数&#xff1a; 宏函数 现有个需求&#xff1a;要求你写一个Add(x,y)的宏函数。 正确的写法有一种&#xff0c;错误的写法倒是五花八门&#xff0c;我们先来“见不贤而自省也。” // …

windows下部署redis3.2

一、下载redis3.2的包 6.2.6的包也有&#xff0c;但无法安装为Windows服务&#xff0c;暂时舍弃。 直接运行&#xff1a; redis-server redis.windows.conf 修改密码, 对应 redis.windows.conf 中的 requirepass 节点&#xff0c;注意去掉前面的# 修改端口&#xff0c;对应…

缺陷检测AI 重要参数解释

一、参数介绍 基本参数 True Positives (TP) True Positives (TP) 是一个用于评估模型性能的术语。它指的是模型正确预测为正例&#xff08;Positive&#xff09;的样本数量&#xff0c;即实际为正例且被正确分类为正例的样本数量。 False Positives (FP) FP (False Posit…

Python 文件目录操作,以及json.dump() 和 json.load()

import os 是用来引入 Python 标准库中的 os 模块的&#xff0c;这个模块提供了与操作系统交互的功能。这个模块常用于文件和目录操作&#xff0c;比如获取文件的目录路径、创建目录等。 如果你在代码中需要使用与操作系统相关的功能&#xff08;例如获取目录名、检查文件是否…

qt-11基本对话框(消息框)

基本对话框--消息框 msgboxdlg.hmsgboxdlg.cppmain.cpp运行图QustionMsgInFormationMsgWarningMsgCriticalMsgAboutMsgAboutAtMsg自定义 msgboxdlg.h #ifndef MSGBOXDLG_H #define MSGBOXDLG_H#include <QDialog> #include <QLabel> #include <QPushButton>…

Cesium模型制作,解决Cesium加载glb/GLTF显示太黑不在中心等问题

Cesium模型制作&#xff0c;解决Cesium加载glb/GLTF显示太黑不在中心等问题 QQ可以联系这里&#xff0c;谢谢

电商搜索新纪元:大模型引领购物体验革新

随着电商行业的蓬勃发展&#xff0c;搜索技术作为连接用户与商品的桥梁&#xff0c;其重要性日益凸显。在技术不断革新的今天&#xff0c;电商搜索技术经历了哪些阶段&#xff1f;面对大模型的飞速发展&#xff0c;企业又将如何把握趋势&#xff0c;应对挑战&#xff1f;为了深…

openinstall支持抖音游戏小手柄监测,助力游戏联运生态高效增长

近来&#xff0c;抖音“小手柄”功能风靡游戏广告生态&#xff0c;通过新颖的联运形式成功将游戏广告触达到抖音整个流量池&#xff0c;由于受众较广&#xff0c;小手柄也是目前直播场数、点赞数最高的形式。 为了帮助广告主快速捕捉流量红利&#xff0c;打通抖音小手柄的数据…

【选型指南】大流量停车场和高端停车场如何选择停车场管理系统?

在当今快节奏的城市生活中&#xff0c;大型停车场和高端车场的运营者面临着一系列挑战&#xff0c;尤其是在车辆流量大和客户期望高的情况下。选择一个合适的停车场管理系统&#xff0c;不仅关系到日常运营的效率&#xff0c;更关系到客户的满意度和车场的整体形象。 捷顺科技认…

螺纹钢生产线中测径仪对基圆和负公差的测量和影响

螺纹钢生产线中测径仪的作用 在螺纹钢生产线中&#xff0c;测径仪是一种关键的检测设备&#xff0c;它负责对螺纹钢的基圆直径、横肋和纵肋等尺寸进行实时测量。测径仪的数据对于监控和控制螺纹钢的生产质量至关重要&#xff0c;尤其是在进行负公差轧制时&#xff0c;它能够提供…