吴恩达机器学习C1W2Lab05-使用Scikit-Learn进行线性回归

news2024/9/17 8:56:08

前言

有一个开源的、商业上可用的机器学习工具包,叫做scikit-learn。这个工具包包含了你将在本课程中使用的许多算法的实现。

目标

在本实验中,你将:

  • 利用scikit-learn实现使用梯度下降的线性回归

工具

您将使用scikit-learn中的函数以及matplotlib和NumPy。

import numpy as np
np.set_printoptions(precision=2)
from sklearn.linear_model import LinearRegression, SGDRegressor
from sklearn.preprocessing import StandardScaler
from lab_utils_multi import  load_house_data
import matplotlib.pyplot as plt
dlblue = '#0096ff'; dlorange = '#FF9300'; dldarkred='#C00000'; dlmagenta='#FF40FF'; dlpurple='#7030A0'; 
plt.style.use('./deeplearning.mplstyle')

注意点

可能会出现报错 No module named ‘sklearn’
这是因为当前环境下未安装scikit-learn
在这里插入图片描述【解决办法】:在cmd中输入

pip install scikit-learn

梯度下降

Scikit-learn有一个梯度下降回归模型sklearn.linear_model.SGDRegressor。与之前的梯度下降实现一样,该模型在规范化输入时表现最好。standardscaler将像之前的实验一样执行z-score归一化。这里它被称为“标准分数”。

加载数据集

X_train, y_train = load_house_data()
X_features = ['size(sqft)','bedrooms','floors','age']

缩放/规范化训练数据

scaler = StandardScaler()
X_norm = scaler.fit_transform(X_train)
print(f"Peak to Peak range by column in Raw        X:{np.ptp(X_train,axis=0)}")   
print(f"Peak to Peak range by column in Normalized X:{np.ptp(X_norm,axis=0)}")

创建并拟合回归模型

sgdr = SGDRegressor(max_iter=1000)
sgdr.fit(X_norm, y_train)
print(sgdr)
print(f"number of iterations completed: {sgdr.n_iter_}, number of weight updates: {sgdr.t_}")

视图参数

注意,参数与规范化的输入数据相关联。拟合参数与之前使用该数据的实验室中发现的非常接近。

b_norm = sgdr.intercept_
w_norm = sgdr.coef_
print(f"model parameters:                   w: {w_norm}, b:{b_norm}")
print(f"model parameters from previous lab: w: [110.56 -21.27 -32.71 -37.97], b: 363.16")

做出预测

预测训练数据的目标。使用’ predict '例程并使用 w w w b b b进行计算。

# make a prediction using sgdr.predict()
y_pred_sgd = sgdr.predict(X_norm)
# make a prediction using w,b. 
y_pred = np.dot(X_norm, w_norm) + b_norm  
print(f"prediction using np.dot() and sgdr.predict match: {(y_pred == y_pred_sgd).all()}")

print(f"Prediction on training set:\n{y_pred[:4]}" )
print(f"Target values \n{y_train[:4]}")

绘制结果

让我们绘制预测值与目标值的对比图。

# plot predictions and targets vs original features    
fig,ax=plt.subplots(1,4,figsize=(12,3),sharey=True)
for i in range(len(ax)):
    ax[i].scatter(X_train[:,i],y_train, label = 'target')
    ax[i].set_xlabel(X_features[i])
    ax[i].scatter(X_train[:,i],y_pred,color=dlorange, label = 'predict')
ax[0].set_ylabel("Price"); ax[0].legend();
fig.suptitle("target versus prediction using z-score normalized model")
plt.show()

祝贺

在这个实验中,你:
-使用开源机器学习工具包scikit-learn
-实现线性回归使用梯度下降和特征归一化的工具包

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1962229.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

付费进群系统源码原版最新修复全开源版

付费进群,和平时所见到的别人拉你进群是不一样的,付费进群需要先缴费以后,才会看到群的二维码,扫码进群或者是长按二维码图片识别进群,付费进群这个功能广泛应用于拼多多的砍价群,活动的助力群,…

Chapter 21 深入理解JSON

欢迎大家订阅【Python从入门到精通】专栏,一起探索Python的无限可能! 文章目录 前言一、JSON数据格式1. 什么是JSON?2. JSON数据的格式 二、JSON格式数据转化三、格式化JSON数据的在线工具 前言 在当今数据驱动的世界中,JSON&…

【大模型系列篇】Vanna-ai基于检索增强(RAG)的sql生成框架

简介 Vanna是基于检索增强(RAG)的sql生成框架 Vanna 使用一种称为 LLM(大型语言模型)的生成式人工智能。简而言之,这些模型是在大量数据(包括一堆在线可用的 SQL 查询)上进行训练的,并通过预测响应提示中最…

中间件安全:Nginx 解析漏洞测试.

中间件安全:Nginx 解析漏洞测试. Nginx 是一个高性能的 HTTP和 反向代理服务器,Nginx 解析漏洞是一个由于配置不当导致的安全问题,它不依赖于Nginx或PHP的特定版本,而是由于用户配置错误造成的。这个漏洞允许攻击者上传看似无害的…

通俗易懂,车载显示屏简单介绍!

2024年后,小汽车产量的年增长率预计将在1%至3%之间 2023年在COVID完全解封后,全球汽车销售数量提升至8千8百万台车。2024预估微幅增加到 9000万辆, 自2024起,年成长率预期将放缓至3%以下。全球汽车主要销售前三大市场 (比较2022年…

为什么阿里开发手册不建议使用Date类?

在日常编码中,基本上99%的项目都会有一个DateUtil工具类,而时间工具类里用的最多的就是java.util.Date。 大家都这么写,这还能有问题?? 当你的“默认常识”出现问题,这个打击,就是毁灭性的。 …

BUG解决(vue3+echart报错):Cannot read properties of undefined (reading ‘type‘)

这是 vue3echart5 遇到的报错:Cannot read properties of undefined (reading ‘type‘) 这个问题需要搞清楚两个关键方法: toRaw: 作用:将一个由reactive生成的响应式对象转为普通对象。 使用场景: 用于读取响应式…

idea2023 总报Low memory

idea2023 总报Low memory 问题背景问题处理 问题背景 在日常开发中,使用idea2023开发工具,开发过程中总会遇到idea提示Low memory的情况,并且每当提示出现的时候,整个idea页面便什么也不能操作了,如何处理这个情况呢&…

AI测试:人工智能模型的核心测试指标,分类判别、目标检测、图像分割、定量计算分别有哪些指标?

在前面的人工智能测试技术系列文章中,我们详细介绍了人工智能测试的技术方法和实践流程。在了解人工智能测试方法后,我们需要进一步学习和研究如何衡量这些方法的有效性,即人工智能模型测试指标的选择。测试指标的选择主要取决于模型的类型和…

借助大语言模型快速升级你的 Java 应用程序

大家都知道我爱小 Q。在我“转码”的征程中,它就像上帝之手,在我本该枯燥漫长的学习进程中拉满快进条。 不仅是我,最近 Amazon Q Developer 还帮助 Amazon 一个由 5 人组成的团队在短短两天内将 1,000 多个生产应用程序从 Java 8 升级到 Jav…

Spring Cloud 组件

1.eureka注册中心原理简述 1.服务注册: Eureka Client 会通过发送rest请求的方式向eureka服务端注册自身元数据:ip地址,端口,运行状况等信息,服务端会把注册信息存储在一个双层map中。 Eureka 的数据存储分了两层:数据存储层和缓存层。 Eureka Client 在拉取服务信息…

【STM32嵌入式系统设计与开发拓展】——13_PWM脉宽

目录 1、什么是PWM?用来做什么的?PWM(Pulse Width Modulation)脉冲宽度调制常见用到 PWM 的情况: 2、什么是输出比较?输出比较模式![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/42434920ca0940b1b1083215…

vue el-input 输入框下拉显示匹配数据

1、效果图: 2、需求&实现: 输入条件 下面匹配查询到的数据有多少个 需要调用后端接口展示,后端查询到之后返回条数 前端展示 3、具体代码实现: html: 图片需要自己根据实际情况增加 // 查询 重置 筛选 本文章…

【git】git常用命令提交规范

Git 是程序员工作中不可或缺的版本控制工具,以下是一些优化后的常用 Git 命令列表,旨在帮助你更高效地使用 Git 进行版本控制。 基础操作 拉取代码 git clone xxx.git创建分支 git branch dev切换分支 git checkout dev # 或者 git switch dev创建并切换…

Python酷库之旅-第三方库Pandas(056)

目录 一、用法精讲 211、pandas.Series.truncate方法 211-1、语法 211-2、参数 211-3、功能 211-4、返回值 211-5、说明 211-6、用法 211-6-1、数据准备 211-6-2、代码示例 211-6-3、结果输出 212、pandas.Series.where方法 212-1、语法 212-2、参数 212-3、功能…

论报文加密加签场景下如何高效的进行渗透测试

前言 最新的测试中,经常遇到HTTP报文加密/加签传输的情况,这导致想要查看和修改明文报文很不方便。 之前应对这种情况我们有几种常见的办法解决,比如使用burpy插件、在Burp上下游使用mitmproxy进行代理等,但这些使用起来不太方便…

LSTM详解总结

LSTM(Long Short-Term Memory)是一种用于处理和预测时间序列数据的递归神经网络(RNN)的改进版本。其设计初衷是为了解决普通RNN在长序列训练中出现的梯度消失和梯度爆炸问题。以下是对LSTM的详细解释,包括原理、公式、…

面向非结构化数据的知迟抽取

文章目录 实体抽取关系抽取事件抽取大量的数据以非结构化数据(即自由文本)的形式存在,如新闻报道、科技文献和政府文件等,面向文本数据的知识抽取一直是广受关注的问题。在前文介绍的知识抽取领域的评测竞赛中,评测数据大多属于非结构化文本数据。本节将对这一类知识抽取技…

Prometheus-部署

Prometheus-部署 Server端安装配置部署Node Exporters监控系统指标监控MySQL数据库监控nginx安装grafana Server端安装配置 1、上传安装包,并解压 cd /opt/ tar xf prometheus-2.30.3.linux-amd64.tar.gz mv prometheus-2.30.3.linux-amd64 /usr/local/prometheus…

【音频识别】十大数据集合集,宝藏合集,不容错过!

本文将为您介绍10个经典、热门的数据集,希望对您在选择适合的数据集时有所帮助。 1 RenderMe-360 发布方: 上海人工智能实验室 发布时间: 2023-05-24 简介: RenFace是一个大规模多视角人脸高清视频数据集,包含多样的…