机器学习基础--基于线性回归房价预测

news2024/11/25 8:29:02
        经典的线性回归模型主要用来预测一些存在着线性关系的数据集。回归模型可以理解为:存在一个点集,用一条曲线去拟合它分布的过程。如果拟合曲线是一条直线,则称为线性回归。如果是一条二次曲线,则被称为二次回归。线性回归是回归模型中最简单的一种。
在线性回归中:
(1)假设函数是指,用数学的方法描述自变量和因变量之间的关系,它们之间可以是一个线性函数或非线性函数。 在本次线性回顾模型中,我们的假设函数为 Y’= wX+b ,其中,Y’表示模型的预测结果(预测房价),用来和真实的Y区分。模型要学习的参数即:w,b。
(2)损失函数是指,用数学的方法衡量假设函数预测结果与真实值之间的误差。这个差距越小预测越准确,而算法的任务就是使这个差距越来越小。 建立模型后,我们需要给模型一个优化目标,使得学到的参数能够让预测值Y’尽可能地接近真实值Y。这个实值通常用来反映模型误差的大小。不同问题场景下采用不同的损失函数。 对于线性模型来讲,最常用的损失函数就是均方误差(Mean Squared Error, MSE)。
(3)优化算法:神经网络的训练就是调整权重(参数)使得损失函数值尽可能得小,在训练过程中,将损失函数值逐渐收敛,得到一组使得神经网络拟合真实模型的权重(参数)。所以,优化算法的最终目标是找到损失函数的最小值。而这个寻找过程就是不断地微调变量w和b的值,一步一步地试出这个最小值。 常见的优化算法有随机梯度下降法(SGD)、Adam算法等等
import paddle
import numpy as np
import os
import matplotlib.pyplot as plt

(1)uci-housing数据集介绍

数据集共506行,每行14列。前13列用来描述房屋的各种信息,最后一列为该类房屋价格中位数。

PaddlePaddle提供了读取uci_housing训练集和测试集的接口,分别为paddle.dataset.uci_housing.train()和paddle.dataset.uci_housing.test()。

#设置默认的全局dtype为float64
paddle.set_default_dtype("float64")
#下载数据
print('下载并加载训练数据')
train_dataset = paddle.text.datasets.UCIHousing(mode='train')
eval_dataset = paddle.text.datasets.UCIHousing(mode='test')
train_loader = paddle.io.DataLoader(train_dataset, batch_size=32, shuffle=True)
eval_loader = paddle.io.DataLoader(eval_dataset, batch_size = 8, shuffle=False)

print('\n加载完成')

dataloder中的参数设置,首先选择要加载的数据集,batch_size表示每次加载数据的多少,shuffle表示加载时是否打乱顺序

对于线性回归来讲,它就是一个从输入到输出的简单的全连接层。

# 定义全连接网络
class Regressor(paddle.nn.Layer):
    def __init__(self):
        super(Regressor, self).__init__()
        # 定义一层全连接层,输出维度是1,激活函数为None,即不使用激活函数
        self.linear = paddle.nn.Linear(13, 1, None)
    
    # 网络的前向计算函数
    def forward(self, inputs):
        x = self.linear(inputs)
        return x

使用前13个来预测最后一位,故连接层输入为13,输出为1且不使用激活函数

model=Regressor() # 模型实例化
model.train() # 训练模式
mse_loss = paddle.nn.MSELoss()
opt=paddle.optimizer.SGD(learning_rate=0.0005, parameters=model.parameters())

模型实例化,使用SGD作为优化器,rate表示学习率,后面的传入的是模型参数


epochs_num=200 #迭代次数
for pass_num in range(epochs_num):
    for batch_id,data in enumerate(train_loader()):
        image = data[0]
        label = data[1]
        predict=model(image) #数据传入model
        loss=mse_loss(predict,label)
        
        if batch_id!=0 and batch_id%10==0:
            Batch = Batch+10
            Batchs.append(Batch)
            all_train_loss.append(loss.numpy()[0]) 
            print("epoch:{},step:{},train_loss:{}".format(pass_num,batch_id,loss.numpy()[0])  )      
        loss.backward()  #反向传播计算模型梯度     
        opt.step()    #更新模型参数
        opt.clear_grad()   #opt.clear_grad()来重置梯度
paddle.save(model.state_dict(),'Regressor')#保存模型

设置迭代轮次,将数据传入模型进行预测,然后计算损失值(mse代表使用均方误差来计算损失值),训练模型并保存

#模型评估
para_state_dict = paddle.load("Regressor") 
model = Regressor()
model.set_state_dict(para_state_dict) #加载模型参数
model.eval() #验证模式

losses = []
infer_results=[]
groud_truths=[]
for batch_id,data in enumerate(eval_loader()):#测试集
    image=data[0]
    label=data[1] 
    groud_truths.extend(label.numpy())    
    predict=model(image) 
    infer_results.extend(predict.numpy())      
    loss=mse_loss(predict,label)
    losses.append(loss.numpy()[0])
    avg_loss = np.mean(losses)
print("当前模型在验证集上的损失值为:",avg_loss)

进行模型评估并计算损失值

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2247169.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

可视化建模与UML《协作图实验报告》

有些鸟儿毕竟是关不住的。 一、实验目的: 1、熟悉协作图的构件事物。 2、掌握协作图的绘制方法。 二、实验环境: window7 | 10 | 11 EA15 三、实验内容: 下面列出了打印文件时的工作流: 用户通过计算机指定要打印的文件。(2)打…

docker搭建私有的仓库

docker搭建私有仓库 一、为什么要搭建私有的仓库? 因为在国内,访问:https://hub.docker.com/ 会出现无法访问页面。。。。(已经使用了魔法) 当然现在也有一些国内的镜像管理网站,比如网易云镜像服务、Dao…

微信小程序条件渲染与列表渲染的全面教程

微信小程序条件渲染与列表渲染的全面教程 引言 在微信小程序的开发中,条件渲染和列表渲染是构建动态用户界面的重要技术。通过条件渲染,我们可以根据不同的状态展示不同的内容,而列表渲染则使得我们能够高效地展示一组数据。本文将详细讲解这两种渲染方式的用法,结合实例…

订单日记为“惠采科技”提供全方位的进销存管理支持

感谢温州惠采科技有限责任公司选择使用订单日记! 温州惠采科技有限责任公司,成立于2024年,位于浙江省温州市,是一家以从事销售电气辅材为主的企业。 在业务不断壮大的过程中,想使用一种既能提升运营效率又能节省成本…

mysql-分析并解决可重复读隔离级别发生的删除幻读问题

在 MySQL 的 InnoDB 存储引擎中,快照读和当前读的行为会影响事务的一致性。让我们详细分析一下隔离级别味可重复读的情况下如何解决删除带来的幻读。 场景描述 假设有一个表 orders,其中包含以下数据: 事务 A 执行快照读 START TRANSACTION…

使用itextpdf进行pdf模版填充中文文本时部分字不显示问题

在网上找了很多种办法 都解决不了; 最后发现是文本域字体设置出了问题; 在这不展示其他的代码 只展示重要代码; 1 引入扩展包 <dependency><groupId>com.itextpdf</groupId><artifactId>itext-asian</artifactId><version>5.2.0</v…

链表刷题|判断回文结构

题目来自于牛客网&#xff0c;本文章仅记录学习过程的做题理解&#xff0c;便于梳理思路和复习 我做题喜欢先把时间复杂度和空间复杂度放一边&#xff0c;先得有大概的解决方案&#xff0c;最后如果时间或者空间超了再去优化即可。 思路一&#xff1a;要判断是否为回文结构则…

0基础跟德姆(dom)一起学AI NLP自然语言处理01-自然语言处理入门

1 什么是自然语言处理 自然语言处理&#xff08;Natural Language Processing, 简称NLP&#xff09;是计算机科学与语言学中关注于计算机与人类语言间转换的领域. 2 自然语言处理的发展简史 3 自然语言处理的应用场景 语音助手机器翻译搜索引擎智能问答...

Linux系统使用valgrind分析C++程序内存资源使用情况

内存占用是我们开发的时候需要重点关注的一个问题&#xff0c;我们可以人工根据代码推理出一个消耗内存较大的函数&#xff0c;也可以推理出大概会消耗多少内存&#xff0c;但是这种方法不仅麻烦&#xff0c;而且得到的只是推理的数据&#xff0c;而不是实际的数据。 我们可以…

跨平台开发_RTC程序设计:实时音视频权威指南 2

1.2.1 一切皆bit 将8 bit分为一组&#xff0c;我们定义了字节(Byte)。 1956年6月&#xff0c;使用了Byte这个术语&#xff0c;用来表示数字信息的基本单元。 最早的字节并非8 bit。《计算机程序设计的艺术》一书中的MIX机器采用6bit作为1Byte。8 bit的Byte约定&#xff0c;和I…

WIFI:长GI与短GI有什么区别和影响

1、GI的作用 Short GI(Guard Interval)是802.11n针对802.11a/g所做的改进。射频芯片在使用OFDM调制方式发送数据时&#xff0c;整个帧是被划分成不同的数据块进行发送的&#xff0c;为了数据传输的可靠性&#xff0c;数据块之间会有GI&#xff0c;用以保证接收侧能够正确的解析…

ssm实战项目──哈米音乐(二)

目录 1、流派搜索与分页 2、流派的添加 3、流派的修改 4、流派的删除 接上篇&#xff1a;ssm实战项目──哈米音乐&#xff08;一&#xff09;&#xff0c;我们完成了项目的整体搭建&#xff0c;接下来进行后台模块的开发。 首先是流派模块&#xff1a; 在该模块中采用分…

C++使用minio-cpp(minio官方C++ SDK)与minio服务器交互简介

目录 minio简介minio-cpp简介minio-cpp使用 minio简介 minio是一个开源的高性能对象存储解决方案&#xff0c;完全兼容Amazon S3 API&#xff0c;支持分布式存储&#xff0c;适用于大规模数据架构&#xff0c;容易集成&#xff0c;而且可以方便的部署在集群中。 如果你已经部…

细说敏捷:敏捷四会之standup meeting

上一篇文章中&#xff0c;我们讨论了 敏捷四会 中 冲刺计划会 的实施要点&#xff0c;本篇我们继续分享敏捷四会中实施最频繁&#xff0c;团队最容易实施但往往也最容易走形的第二个会议&#xff1a;每日站会 关于每日站会的误区 站会是一个比较有标志性的仪式活动&#xff0…

二分法(折半法)查找【有动图】

二分法&#xff0c;也叫做折半法&#xff0c;就是一种通过有序表的中间元素与目标元素进行对比&#xff0c;根据大小关系排除一半元素&#xff0c;然后继续在剩余的一半中进行查找&#xff0c;重复这个过程直至找到目标值或者确定目标值不存在。 我们从结论往回推&#xff0c;…

FreeRTOS——低功耗管理

目录 一、概念及其应用 1.1应用 1.2STM32电源管理系统 2.3STM32低功耗模式 2.3.1睡眠模式 2.3.2停止模式 2.3.3待机模式 三、Tickless低功耗模式 3.1低功耗模式配置 3.2低功耗模式应用 3.3低功耗电路分析 3.4低功耗处理相关接口 四、实现原理 4.1任务等待删除的检查…

【STM32】MPU6050初始化常用寄存器说明及示例代码

一、MPU6050常用配置寄存器 1、电源管理寄存器1&#xff08; PWR_MGMT_1 &#xff09; 此寄存器允许用户配置电源模式和时钟源。 DEVICE_RESET &#xff1a;用于控制复位的比特位。设置为1时复位 MPU6050&#xff0c;内部寄存器恢复为默认值&#xff0c;复位结束…

2024年亚太地区数学建模大赛A题-复杂场景下水下图像增强技术的研究

复杂场景下水下图像增强技术的研究 对于海洋勘探来说&#xff0c;清晰、高质量的水下图像是深海地形测量和海底资源调查的关键。然而&#xff0c;在复杂的水下环境中&#xff0c;由于光在水中传播过程中的吸收、散射等现象&#xff0c;导致图像质量下降&#xff0c;导致模糊、…

自动驾驶3D目标检测综述(三)

前两篇综述阅读理解放在这啦&#xff0c;有需要自行前往观看&#xff1a; 第一篇&#xff1a;自动驾驶3D目标检测综述&#xff08;一&#xff09;_3d 目标检测-CSDN博客 第二篇&#xff1a;自动驾驶3D目标检测综述&#xff08;二&#xff09;_子流行稀疏卷积 gpu实现-CSDN博客…

【Linux | 计网】TCP协议详解:从定义到连接管理机制

目录 1.TCP协议的定义&#xff1a; 2.TCP 协议段格式 3.TCP两种通信方式 4.确认应答(ACK)机制 解决“后发先至”问题 5.超时重传机制 那么, 超时的时间如何确定? 6.连接管理机制&#xff1a; 6.1.三次握手&#xff1a; 为什么需要3次握手&#xff0c;一次两次不行吗…