【吴恩达机器学习-week2】可选实验:使用 Scikit-Learn 进行线性回归

news2024/11/20 7:04:20

支持我的工作 🎉

📃亲爱的朋友们,感谢你们一直以来对我的关注和支持!
💪🏻 为了提供更优质的内容和更有趣的创作,我付出了大量的时间和精力。如果你觉得我的内容对你有帮助或带来了欢乐,欢迎你通过打赏支持我的工作!

🫰🏻你的一份打赏不仅是对我工作的认可,更是对我持续创作的巨大动力。无论金额多少,每一份支持都让我倍感鼓舞和感激。

📝有关此篇文章的更多详情请见:2022吴恩达机器学习Deeplearning.ai课程作业,给我一杯咖啡的支持吧!☕️

🔥再次感谢你们的支持和陪伴!
在这里插入图片描述

可选实验:使用 Scikit-Learn 进行线性回归

目标

在本实验中,您将:

  • 利用 scikit-learn 通过梯度下降实现线性回归
import numpy as np
np.set_printoptions(precision=2) # 使其输出的浮点数精度为小数点后两位
# 这两个类分别用于实现**普通线性回归**和**使用随机梯度下降法的线性回归**
from sklearn.linear_model import LinearRegression, SGDRegressor
# 这个类用于**标准化数据**,使其均值为0,标准差为1。
from sklearn.preprocessing import StandardScaler
from lab_utils_multi import  load_house_data
import matplotlib.pyplot as plt
dlblue = '#0096ff'; dlorange = '#FF9300'; dldarkred='#C00000'; dlmagenta='#FF40FF'; dlpurple='#7030A0'; 
plt.style.use('./deeplearning.mplstyle')

梯度下降

Scikit-learn 有一个梯度下降回归模型 sklearn.linear_model.SGDRegressor。与您之前的梯度下降实现类似,该模型在标准化输入下表现最佳。sklearn.preprocessing.StandardScaler 将执行 z-score 标准化,如之前的实验中所示。在这里,它被称为“标准分数”。


加载数据集

X_train, y_train = load_house_data()
X_features = ['size(sqft)','bedrooms','floors','age']

标准化/归一化训练数据

scaler = StandardScaler()
X_norm = scaler.fit_transform(X_train)
print(f"Peak to Peak range by column in Raw        X:{np.ptp(X_train,axis=0)}")   
print(f"Peak to Peak range by column in Normalized X:{np.ptp(X_norm,axis=0)}")

## print
Peak to Peak range by column in Raw        X:[2.41e+03 4.00e+00 1.00e+00 9.50e+01]
Peak to Peak range by column in Normalized X:[5.85 6.14 2.06 3.69]
  • scaler = StandardScaler()
    • 创建一个 StandardScaler 实例
    • StandardScalerscikit-learn 提供的一个类,用于数据标准化。标准化的目的是使数据具有零均值和单位方差
  • X_norm = scaler.fit_transform(X_train)
    • 使用 StandardScaler 实例对 X_train 数据进行标准化
    • fit_transform 方法首先计算数据的均值和标准差然后对数据进行标准化。
    • 返回的 X_norm 是标准化后的数据。
  • print(f"Peak to Peak range by column in Raw X:{np.ptp(X_train,axis=0)}")
    • 计算并打印原始数据 X_train 每一列的峰值范围(最大值减去最小值)。
    • np.ptp 函数用于计算沿指定轴的峰值范围,这里使用 axis=0 表示按列计算。
  • print(f"Peak to Peak range by column in Normalized X:{np.ptp(X_norm,axis=0)}")
    • 计算并打印标准化后数据 X_norm 每一列的峰值范围
    • 标准化后的数据通常会有较小且相似的范围,因为它们被缩放到相同的尺度。

创建并拟合回归模型

## 创建 SGDRegressor 实例
sgdr = SGDRegressor(max_iter=1000)

## 训练模型
sgdr.fit(X_norm, y_train)

print(sgdr). # 打印 SGDRegressor 模型的**概述信息**,显示**模型的主要参数和设置**

## 打印迭代次数和权重更新次数
print(f"number of iterations completed: {sgdr.n_iter_}, number of weight updates: {sgdr.t_}")

## print
SGDRegressor()
number of iterations completed: 117, number of weight updates: 11584.0
  • sgdr = SGDRegressor(max_iter=1000)

    创建一个 SGDRegressor 实例,并将最大迭代次数设置为 1000。SGDRegressorscikit-learn 提供的一种使用随机梯度下降法训练线性模型的回归器

  • sgdr.fit(X_norm, y_train)

    • 使用标准化后的训练数据 X_norm 和目标变量 y_train 来训练 SGDRegressor 模型。
    • fit 方法用于拟合模型。

查看参数

请注意,这些参数与标准化输入数据相关。拟合参数与之前实验中使用该数据找到的参数非常接近。

# 获取 SGDRegressor 模型的截距(偏置项)。intercept_ 属性包含模型的截距
b_norm = sgdr.intercept_ 

# 获取 SGDRegressor 模型的系数(权重)。coef_ 属性包含模型的系数
w_norm = sgdr.coef_
print(f"model parameters:                   w: {w_norm}, b:{b_norm}")
print(f"model parameters from previous lab: w: [110.56 -21.27 -32.71 -37.97], b: 363.16")

## print
model parameters:                   w: [110.08 -21.05 -32.46 -38.04], b:[363.15]
model parameters from previous lab: w: [110.56 -21.27 -32.71 -37.97], b: 363.16

进行预测

预测训练数据的目标值。使用 predict 例程,并使用 w w w b b b 进行计算。

# 使用 sgdr.predict() 进行预测
y_pred_sgd = sgdr.predict(X_norm)
# 使用权重和截距进行预测 
y_pred = np.dot(X_norm, w_norm) + b_norm  
# 检查所有预测值是否都相同,如果相同则返回 True,否则返回 False。
print(f"prediction using np.dot() and sgdr.predict match: {(y_pred == y_pred_sgd).all()}")

# 打印前四个样本的预测值和目标值进行对比
print(f"Prediction on training set:\n{y_pred[:4]}" )
print(f"Target values \n{y_train[:4]}")

## print
prediction using np.dot() and sgdr.predict match: True
Prediction on training set:[295.2  485.82 389.56 491.98]
Target values [300.  509.8 394.  540. ]
  • y_pred_sgd = sgdr.predict(X_norm)
    • 使用训练好的 SGDRegressor 模型对标准化后的数据 X_norm 进行预测。
    • predict 方法会根据模型的系数和截距计算预测值。
  • y_pred = np.dot(X_norm, w_norm) + b_norm
    • 直接使用之前获取的权重 w_norm 和截距 b_norm 对标- 准化后的数据 X_norm 进行预测。
    • np.dot(X_norm, w_norm) 计算每个样本的线性组合,加上截距 b_norm 得到预测值。

绘制结果

让我们绘制预测值与目标值的对比图。

# plot predictions and targets vs original features    
fig,ax=plt.subplots(1,4,figsize=(12,3),sharey=True)
for i in range(len(ax)):
    ax[i].scatter(X_train[:,i],y_train, label = 'target') # 绘制实际值的散点图
    ax[i].set_xlabel(X_features[i])
    ax[i].scatter(X_train[:,i],y_pred,color=dlorange, label = 'predict') # 绘制预测值的散点图
ax[0].set_ylabel("Price"); 
ax[0].legend();
fig.suptitle("target versus prediction using z-score normalized model")
plt.show()

在这里插入图片描述

小结

  1. 使用开源机器学习工具包 scikit-learn
    • scikit-learn 是一个使用的机器学习库,提供了各种算法和工具,用于数据预处理、模型训练和评估。
  2. 实现了线性回归模型
    • 通过使用梯度下降算法(SGDRegressor),我们训练了一个线性回归模型来预测房价。
    • 我们还使用了**标准化技术(StandardScaler)**对特征数据进行了归一化处理,从而加快了模型的收敛速度并提高了模型的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1887454.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

看完这篇文章你就知道什么是未来软件开发的方向了!即生成式AI在软件开发领域的革新=CodeFlying

从最早的UGC(用户生成内容)到PGC(专业生成内容)再到AIGC(人工智能生成内容)体现了web1.0→web2.0→web3.0的发展历程。 毫无疑问UGC已经成为了当前拥有群体数量最大的内容生产方式。 同时随着人工智能技术…

SAP 表字段调整,表维护生成器调整

表维护生成器->已生成的对象->更改->专家模式

Linux下的wifi开发

了解什么是wifi 可参考: 什么是Wi-Fi?Wi-Fi和WLAN的区别是什么? - 华为 (huawei.com) WLAN的基本元素 工作站STA(Station):支持802.11标准的终端设备。例如带无线网卡的电脑、支持WLAN的手机等。 接入点AP&…

OpenSSH RCE (CVE-2024-6387) | 附poc | 小试

Ⅰ 漏洞描述 OpenSSH 远程代码执行漏洞(CVE-2024-6387)&#xff0c;该漏洞是由于OpenSSH服务器 (sshd) 中的信号处理程序竞争问题&#xff0c;未经身份验证的攻击者可以利用此漏洞在Linux系统上以root身份执行任意代码。 Ⅱ 影响范围 8.5p1 < OpenSSH < 9.8p1 但OpenSS…

数学建模--层次分析法~~深入解读

目录 1.基本概念 &#xff08;1&#xff09;研究案例 &#xff08;2&#xff09;模型框架 &#xff08;3&#xff09;阐述说明 &#xff08;4&#xff09;注意事项 2.模型的建立和求解 &#xff08;1&#xff09;数量级的统一 &#xff08;2&#xff09;归一化处理 &am…

用Vue3和Rough.js绘制一个粗糙的3D条形图

本文由ScriptEcho平台提供技术支持 项目地址&#xff1a;传送门 使用 Rough.js 和 D3.js 绘制粗糙手写风格条形图 应用场景 该代码适用于需要在 Web 应用程序中创建具有粗糙手写风格的条形图的情况。它可以用于数据可视化、信息图表或任何需要以独特和有吸引力的方式呈现数…

Java StringBuffer类和StringBuilder类

在使用 StringBuffer 类时&#xff0c;每次都会对 StringBuffer 对象本身进行操作&#xff0c;而不是生成新的对象&#xff0c;所以如果需要对字符串进行修改推荐使用 StringBuffer。 StringBuilder 类在 Java 5 中被提出&#xff0c;它和 StringBuffer 之间的最大不同在于 St…

【PYG】Cora数据集分类任务计算损失,cross_entropy为什么不能直接替换成mse_loss

cross_entropy计算误差方式&#xff0c;输入向量z为[1,2,3]&#xff0c;预测y为[1]&#xff0c;选择数为2&#xff0c;计算出一大坨e的式子为3.405&#xff0c;再用-23.405计算得到1.405MSE计算误差方式&#xff0c;输入z为[1,2,3]&#xff0c;预测向量应该是[1,0,0]&#xff0…

IAR工程目录移动报错(改变文件目录结构)

刚开始用IAR&#xff0c;记录一下。 工作中使用华大单片机&#xff0c;例程的文件目录结构太复杂了想精简一点。 1.如果原本的C文件相对工程文件&#xff08;.eww文件&#xff09;路径变化了&#xff0c;需要先打开工程&#xff0c;再将所有的.c文件右键Add添加进工程&#xf…

【Godot4.2】Godot中的贝塞尔曲线

概述 通过指定平面上的多个点&#xff0c;然后顺次连接&#xff0c;我们可以得到折线段&#xff0c;如果闭合图形&#xff0c;就可以获得多边形。通过向量旋转我们可以获得圆等特殊图形。 但是对于任意曲线&#xff0c;我们无法使用简单的方式来获取其顶点&#xff0c;好在计…

X-ObjectMount: 对象存储访问接入的新选择

XEOS 自 2017 年发布面世以来&#xff0c;历经 7 年的研发迭代&#xff0c;上个月正式发布了 XSKY SDS 6.4 版本&#xff0c;包含了最新的多站点统一命名空间能力&#xff0c;也标志了 XEOS 在对象存储领域的全方面优势和领先市场地位。 在 XSKY 过去对象存储服务历程里&#…

mysql 命令 —— 查看表信息(show table status)

查询表信息&#xff0c;如整个表的数据量大小、表的索引占用空间大小等 1、查询某个库下面的所有表信息&#xff1a; SHOW TABLE STATUS FROM your_database_name;2、查询指定的表信息&#xff1a; SHOW TABLE STATUS LIKE your_table_name;如&#xff1a;Data_length 显示表…

openGauss真的比PostgreSQL差了10年?

前不久写了MogDB针对PostgreSQL的兼容性文章&#xff0c;我在文中提到针对PostgreSQL而言&#xff0c;MogDB兼容性还是不错的&#xff0c;其中也给出了其中一个能源客户之前POC的迁移报告数据。 But很快我发现总有人回留言喷我&#xff0c;而且我发现每次喷的这帮人是根本不看文…

Python基础003

Python流程控制基础 1.条件语句 内置函数input a input("请输入一段内容&#xff1a;") print(a) print(type(a))代码执行的时候遇到input函数&#xff0c;就会等键盘输入结果&#xff0c;已回车为结束标志&#xff0c;也就时说输入回车后代码才会执行 2.顺序执行…

【问题记录】如何在xftp上查看隐藏文件。

显示隐藏的文件夹 用xftp连接到服务器后&#xff0c;发现有些隐藏的文件夹并未显示出来&#xff0c;通过以下配置&#xff0c;即可使隐藏的文件夹给显示出来。 1.点击菜单栏的"小齿轮"按钮&#xff1a; 2.勾选显示隐藏的文件夹&#xff1a; 3.点击确定即可。

古韵流光:探秘五代耀州窑青瓷提梁倒灌壶的奇妙设计

在陕西历史博物馆的静谧展厅中&#xff0c;一件千年前的瓷器静静陈列&#xff0c;它不仅承载着历史的沉淀&#xff0c;更凝聚了古代匠人的非凡智慧。这便是五代时期的耀州窑青瓷提梁倒灌壶&#xff0c;一件巧夺天工的艺术品&#xff0c;其独特的设计至今仍让人叹为观止。 一、倒…

算法mq 交互通用校验模块设计

背景 当前与算法交互均通过rocketMQ异步交互&#xff0c;绝大部分场景一条请求mq消息应对应一条返回mq&#xff0c;但由于各种原因&#xff08;消息积压、程序bug&#xff09;&#xff0c;可能会导致返回mq超时未返回或者消息丢失。工程侧针对一些重要场景 case by case的通过…

【web3】分享一个web入门学习平台-HackQuest

前言 一直想进入web3行业&#xff0c;但是没有什么途径&#xff0c;偶然在电鸭平台看到HackQuest的共学营&#xff0c;发现真的不错&#xff0c;并且还接触到了黑客松这种形式。 链接地址&#xff1a;HackQuest 平台功能 学习路径&#xff1a;平台有完整的学习路径&#xff…

VS2022+Qt+OpenCV Debug模式下,循环中格式转换引起的内存异常问题 debug_heap.cpp

文章目录 前言一、问题二、报错1.提示图片2.提示堆栈3.反汇编位置 三、解决办法总结 前言 最近在使用VS2022&#xff0c;C&#xff0c;OpenCV&#xff0c;Qt开发时&#xff0c;遇到了一个疑难杂症-在循环中执行字符串格式转换会触发内存异常&#xff0c;经过痛苦的排查过程&am…

Ubuntu下反弹shell的思考

目录 Ubuntu的命令执行环境 bash (Bourne Again SHell): sh (Bourne SHell): dash (Debian Almquist SHell): 它们之间的关系&#xff1a; 可能遇到的问题 一、脚本权限问题 二、命令执行环境(shell解释器)问题 如何解决&#xff1f; 1.修改/bin/sh软连接的指向为bas…