机器学习实战-系列教程3:手撕线性回归2之单特征线性回归(项目实战、原理解读、源码解读)

news2025/1/10 2:25:14

🌈🌈🌈机器学习 实战系列 总目录

本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传

手撕线性回归1之线性回归类的实现
手撕线性回归2之单特征线性回归
手撕线性回归3之多特征线性回归
手撕线性回归4之非线性回归# 5、数据预处理

5.1 数据读入

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from linear_regression import LinearRegression

首先是导包numpy、pandas、matplotlib素质三连,从文件中linear_regression导入类

data = pd.read_csv('../data/world-happiness-report-2017.csv')
train_data = data.sample(frac = 0.8)
test_data = data.drop(train_data.index)
  1. 读csv文件
  2. 按80比例分配训练数据
  3. 按20比例分配训练数据
input_param_name = 'Economy..GDP.per.Capita.'
output_param_name = 'Happiness.Score'
x_train = train_data[[input_param_name]].values
y_train = train_data[[output_param_name]].values
x_test = test_data[input_param_name].values
y_test = test_data[output_param_name].values
  1. 原始数据中取出一列"索引"作为输入
  2. 原始数据中取出一列"索引"作为标签
  3. 按照索引取出训练集数据
  4. 按照索引取出训练集标签
  5. 按照索引取出测试集数据
  6. 按照索引取出测试集标签
plt.scatter(x_train,y_train,label='Train data')
plt.scatter(x_test,y_test,label='test data')
plt.xlabel(input_param_name)
plt.ylabel(output_param_name)
plt.title('Happy')
plt.legend()
plt.show()
  1. 训练数据散点图
  2. 测试数据散点图

打印结果:
在这里插入图片描述

5.2 训练

num_iterations = 500
learning_rate = 0.01
  1. 迭代次数,即整个数据集训练次数
  2. 学习率
linear_regression = LinearRegression(x_train,y_train)
(theta,cost_history) = linear_regression.train(learning_rate,num_iterations)
print ('开始时的损失:',cost_history[0])
print ('训练后的损失:',cost_history[-1])
  1. 数据传入类中,实例化类得到linear_regression 对象
  2. linear_regression 对象调用train方法,得到参数和损失
  3. 打印开始损失
  4. 打印结束损失

打印结果:

开始时的损失: 14.633306098916812
训练后的损失: 0.2275173194286417

plt.plot(range(num_iterations),cost_history)
plt.xlabel('Iter')
plt.ylabel('cost')
plt.title('GD')
plt.show()

打印结果:
在这里插入图片描述

5.3 测试

predictions_num = 100
x_predictions = np.linspace(x_train.min(),x_train.max(),predictions_num).reshape(predictions_num,1)
y_predictions = linear_regression.predict(x_predictions)
  1. 选用100个数据
  2. x_predictions:
    1. x_train.min(),x_train.max(),前面的训练数据中的最小值和最大值
    2. np.linspace(x_train.min(),x_train.max(),predictions_num),最小值和最大值为范围均匀分成100个数据
    3. 维度调整为(100,1)
  3. 使用定义的线性回归将x_predictions预测成y_predictions
plt.scatter(x_train,y_train,label='Train data')
plt.scatter(x_test,y_test,label='test data')
plt.plot(x_predictions,y_predictions,'r',label = 'Prediction')
plt.xlabel(input_param_name)
plt.ylabel(output_param_name)
plt.title('Happy')
plt.legend()
plt.show()
  1. 训练数据和训练标签的散点图
  2. 测试数据和测试标签的散点图
  3. x_predictions和y_predictions 对应的一条直线
  4. 画图

打印结果:
在这里插入图片描述

6、数据预处理

机器学习开发流程中一定有一个数据预处理的重要流程,在很多实际的任务中,数据预处理甚至比网络设计更复杂更重要。

6.1 归一化函数

这部分函数主要为了将原始数据放入到一个合适的范围内,一般是[0,1]的范围或者[-1,1]的范围,人能识别数据,计算机只识别数字,机器学习只能认识特征

def normalize(features):
    features_normalized = np.copy(features).astype(float)
    features_mean = np.mean(features, 0)
    features_deviation = np.std(features, 0)
    if features.shape[0] > 1:
        features_normalized -= features_mean
    features_deviation[features_deviation == 0] = 1
    features_normalized /= features_deviation
    return features_normalized, features_mean, features_deviation
  1. 深度复制传进来的原始数据features,转换为float格式
  2. 返回原始数据的均值
  3. 返回原始数据的标准差
  4. 判断features是否只有一个数字
  5. 原始数据减去均值
  6. 判断标准差是否为0,如果为0 则改为1(防止分母出现为0的情况)
  7. 原始数据减去均值的结果再除以标准差
  8. 返回处理结果、均值、标准差

6.2 数据预处理函数

在此次的数据预处理中只用到了归一化操作

def prepare_for_training(data, polynomial_degree=0, sinusoid_degree=0, normalize_data=True):
    num_examples = data.shape[0]
    data_processed = np.copy(data)
    features_mean = 0
    features_deviation = 0
    data_normalized = data_processed
    if normalize_data:
        (data_normalized, features_mean, features_deviation ) = normalize(data_processed)
        data_processed = data_normalized
    if sinusoid_degree > 0:
        sinusoids = generate_sinusoids(data_normalized, sinusoid_degree)
        data_processed = np.concatenate((data_processed, sinusoids), axis=1)
    if polynomial_degree > 0:
        polynomials = generate_polynomials(data_normalized, polynomial_degree, normalize_data)
        data_processed = np.concatenate((data_processed, polynomials), axis=1)
    data_processed = np.hstack((np.ones((num_examples, 1)), data_processed))
    return data_processed, features_mean, features_deviation
  1. 计算有多少个数
  2. 深度复制原始数据
  3. 初始均值0(避免提示报错而已)
  4. 初始标准差0(避免提示报错而已)
  5. 定义初始归一化数据(避免提示报错而已)
  6. 将数据传入初始化函数
  7. 特征变换sinusoidal
  8. 特征变换polynomial
  9. 原始数据拼接了一列1
  10. 返回数据

7、整体流程解读

单特征线性回归整体流程,从Non-linearRegression.py文件的这行代码开始:
data = pd.read_csv(‘…/data/non-linear-regression-x-y.csv’)

  1. 读数据
  2. 选择特征
  3. 画一下原始数据的散点图(训练数据、测试数据)
  4. 进入线性回归类
  5. 在线性回归类进入初始化函数
  6. 在初始化函数进入数据预处理函数
  7. 在数据预处理函数中进入归一化操函数后,返回处理结果、均值、标准差,返回初始化函数
  8. 初始化函数系列赋值操作
  9. 退出线性回归类,返回线性回归实例化对象
  10. 线性回归对象调用trian函数
  11. 在trian函数中调用梯度下降函数
  12. 在梯度下降函数中多次调用参数更新函数以及损失计算函数
  13. 线性回归对象的trian函数返回损失,返回最后的参数
  14. 打印损失
  15. 画出损失下降过程
  16. 进行预测

手撕线性回归1之线性回归类的实现
手撕线性回归2之单特征线性回归
手撕线性回归3之多特征线性回归
手撕线性回归4之非线性回归

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/993769.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

0-1背包-动态规划

一、01背包 描述:有 N 件物品和一个容量为 V 的背包,每件物品只能使用一次 第 i 件物品的体积是 Ci,价值是 Wi 求解将哪些物品装入背包,能够在不超过背包容量的情况下使总价值最大 求解:动态规划 使用dp[i][j]表示从…

zabbix监控H3C设备

背景 常见的服务和主机已经使用Prometheus进行监控了,但是网络设备还未配置监控。使用基于SNMP对网络设备进行监控。 设备概览 主要类型为H3C的路由器和交换机。H3CS5560交换机 路由器MER5200 er8300 步骤 配置网络设备开启telnet远程; 配置启用sn…

nodejs采集淘宝、天猫网商品详情数据以及解决_m_h5_tk令牌及sign签名验证(2023-09-09)

一、淘宝、天猫sign加密算法 淘宝、天猫对于h5的访问采用了和APP客户端不同的方式,由于在h5的js代码中保存appsercret具有较高的风险,mtop采用了随机分配令牌的方式,为每个访问端分配一个token,保存在用户的cookie中,通…

SAP-MM-销售订单库存转移到普通库存

业务需求: 特殊库存-销售订单库存 有产成品物料1个,现在需要在集团下的两个公司间调拨,需要把特殊库存E调拨到普通库存里,再从H020普通库存调拨到另一个工厂1000. 注意事项:库存地点需要扩充,否则调拨会报…

iOS 17新功能:教你轻松掌握锁定屏幕快捷方式

通过iOS 17,苹果为iPhone用户提供了使用快捷方式锁定手机屏幕的能力。 为什么你需要学习如何使用iOS锁定屏幕快捷方式?按下iPhone上的电源按钮激活这个屏幕肯定是最简单的吗?嗯,这并不总是正确的。如果你在按下物理按钮时遇到困难…

【2023知乎爬虫】批量获取问题的全部回答

一.需求 爬取任意问题下的所有回答,如下图: 1.根据问题,批量获取问题下的所有回答、与对应问题的关系到answer.csv文件; 2.保存当前问题基本信息到quesiton_info.csv文件; 二.展示爬取结果 三.讲解步骤 3.1 新建项…

个人开发者看过来,我搭了一个监控系统免费用

最近在做一个自己的项目,平时就在自己电脑上跑着,有一天回去突然就挂了,查了半天也没搞清楚原因,想看个监控都没有,什么时候挂的,为啥挂了,统统都不知道。平时做公司项目多了,监控用…

C/C++操作加密与不加密的zip文件

为了后续的方便操作zip文件, 将所有的操作封装成了一个动态库了。 /*** \description 从压缩包文件中解压出指定的文件到指定的目录.* \author sunsz* \date 2023/09/09**/ LIBZIP_API int UnpackFile(const char* password, char zipfilename[], char filename_…

rt-thread------任务调度

rt-thread------任务调度 1. 线程初始化 在rt-thread中线程主要包括以下一些内容,线程控制块、线程栈、函数入口。 1.1线程创建函数 RTOS基本都包括两种线程方式:动态创建rt_thread_create()和静态创建rt_thread_init()。 因为有些系统设计时对安全…

硬件学习件Cadence day13 PCB设计中一些设置, 铜皮到钻孔的距离设置, 差分线的设置,板层信息表

1. 设置铺铜中铜皮到钻口,连线的距离。 1. 打开设置界面 2. 设计界面 调整到 铜皮设置界面 2. 高速线的设置 (差分对传输线的设置) 1. 打开设置界面 2. 来到 差分线设置界面 3. 把界面往右看, 设置差分线的之间距离,…

Python之并发编程介绍

一、并发编程介绍 1.1、串行、并行与并发的区别 串行(serial):一个CPU上,按顺序完成多个任务并行(parallelism):指的是任务数小于等于cpu核数,即任务真的是一起执行的并发(concurrency):一个CPU采用时间片管理方式&am…

TrOCR – 基于 Transformer 的 OCR 入门指南

多年来,光学字符识别 (OCR) 出现了多项创新。它对零售、医疗保健、银行和许多其他行业的影响是巨大的。尽管有着悠久的历史和多种最先进的模型,研究人员仍在不断创新。与深度学习的许多其他领域一样,OCR 也看到了变压器神经网络的重要性和影响。如今,我们拥有像TrOCR(Tran…

franka_ros中的一些子包的使用

franka_visualization包 该软件包包含连接到机器人并发布机器人和夹爪关节状态以在 RViz 中进行可视化的发布者。要运行此包启动&#xff1a; roslaunch franka_visualization franka_visualization.launch robot_ip:<fci-ip> \load_gripper:<true|false> 比如&a…

UI自动化测试工具详解

常用工具 1、QTP&#xff1a;商业化的功能测试工具&#xff0c;收费&#xff0c;可用于web自动化测试 2、Robot Framework&#xff1a;基于Python可扩展的关键字驱动的测试自动化框架 3、Selenium &#xff1a;开源的web自动化测试工具&#xff0c;免费&#xff0c;主要用于功…

SpringCloud-微服务CAP原则

接上文 SpringCloud-Config配置中心 到此部分即微服务的入门。 总的来说&#xff0c;数据存放的节点数越多&#xff0c;分区容忍性就越高&#xff0c;但要复制更新的次数就越多&#xff0c;一致性就越难保证。同时为了保证一致性&#xff0c;更新所有节点数据所需要的时间就…

Python教程33:关于在使用zipfile模块,出现中文乱码的解决办法

zipfile是Python标准库中的一个模块&#xff0c;zipfile里有两个class, 分别是ZipFile和ZipInfo&#xff0c;用来创建和读取zip文件&#xff0c;而ZipInfo是存储的zip文件的每个文件的信息的。ZIP文件是一种常见的存档文件格式&#xff0c;它可以将多个文件和目录压缩为一个文件…

帝国cms后台访问链接提示“非法来源”解决方法

提示“非法来源”的原因 帝国CMS更新升级7.2后,新增了后台安全模式,后台推出了金刚模式来验证链接来源。后台所有链接都需要登录后才能访问,直接强制访问后台页面链接都会提示“非法来源”。不是正常登录后台的用户无法直接访问到内容,保证了后台数据安全。 那么我们在日常…

Table of Laplace Transforms

https://www.math.uh.edu/~etgen/LaplaceT.pdf http://web.mit.edu/2.737/www/handouts/LaplaceTransforms.pdf https://www.integral-table.com/downloads/LaplaceTable.pdf https://www.math.purdue.edu/~caiz/MA527-cai/lectures/Table%20of%20Laplace%20Transforms.pdf

阅读源码工具Sourcetrail

收费工具Source Insight、Understand Sourcetrail开源工具 一、下载安装 接下来就是download&#xff0c;在GitHub的release页面选择自己系统对应的发布版本下载安装&#xff1a; 安装好后&#xff0c;运行程序&#xff0c;会出现这样的界面&#xff1a; 二、应用 选择“New…

2023年最佳研发管理平台评选:哪家表现出色?

“研发管理平台哪家好&#xff1f;以下是一些知名的研发管理软件品牌&#xff1a;Zoho Projects、JIRA、Trello、Microsoft Teams、GitLab。’” 企业需要不断创新以保持竞争力。研发是企业创新的核心&#xff0c;而研发管理平台则为企业提供了一个有效的工具来支持和管理其研发…