机器学习第一篇 线性回归

news2025/4/24 5:57:50

数据集:公开的World Happiness Report | Kaggle中的happiness dataset2017.

目标:基于GDP值预测幸福指数。(单特征预测)

代码:

文件一:prepare_for_traning.py

"""用于科学计算的一个库,提供了多维数组对象以及操作函数"""
from utils.features import prepare_for_training
"""数据预处理的一个私库"""

class LinearRegression:
    def __init__(self,data,labels,polynomial_degree = 0,sinusoid_degree = 0,normalize_data = True):
        """
        进行预处理操作
        :param data:
        :param labels:
        :param polynomial_degree:
        :param sinusoid_degree:
        :param normalize_data:
        """
        (data_processed,
            features_mean,
            features_deviation) = prepare_for_training(data,polynomial_degree = 0,sinusoid_degree = 0,normalize_data = True)
        self.data = data_processed
        self.labels = labels
        self.features_mean = features_mean
        self.features_deviation = features_deviation
        self.polynomial_degree = polynomial_degree
        self.sinusoid_degree = sinusoid_degree
        self.normalize_data = normalize_data
        num_features = self.data.shape[1]
        self.theta = np.zeros((num_features,1))
    """"数据,学习率,训练次数"""
    def train(self,alpha,num_iterations = 500):
        """训练模块:梯度下降"""
        cost_history = self.gradient_descent(alpha,num_iterations)
        return self.theta,cost_history

    def gradient_descent(self,alpha,num_iterations):
        """迭代模块"""
        cost_history = []
        for _ in range(num_iterations):
            self.gradient_step(alpha)
            cost_history.append(self.cost_function(self.data,self.labels))
        return cost_history

    def gradient_step(self,alpha):
        """
        梯度下降参数更新算法,矩阵计算,使用小批量梯度下降算法
        :param self:
        :param alpha:
        :return:
        """
        num_examples = self.data.shape[0]
        prediction = LinearRegression.hypothesis(self.data,self.theta)
        delta = prediction - self.labels
        theta = self.theta
        theta = theta - alpha*(1/num_examples)*(np.dot(delta.T,self.data)).T
        self.theta = theta

    def cost_function(self,data,labels):
        """
        损失计算模块
        :param self:
        :param data:
        :param labels:
        :return:
        """
        num_examples = data.shape[0]
        delta = LinearRegression.hypothesis(self.data,self.theta) - labels
        cost = (1/2)*np.dot(delta.T,delta)/num_examples
        """print(cost.shape)"""
        return cost[0][0]

    """装饰器"""
    @staticmethod
    def hypothesis(data,theta):
        prediction = np.dot(data,theta)
        return prediction

    def get_cost(self,data,labels):
        data_processed = prepare_for_training(data,
         self.polynomial_degree,
         self.sinusoid_degree,
         self.normalize_data
         )[0]

        return self.cost_function(data_processed,labels)

    def predict(self,data):
        data_processed = prepare_for_training(data,
        self.polynomial_degree,
        self.sinusoid_degree,
        self.normalize_data
        )[0]
        predictions = LinearRegression.hypothesis(data_processed,self.theta)

        return predictions

文件2:Linear_regression.py 

import numpy as np
"""用于科学计算的一个库,提供了多维数组对象以及操作函数"""
import pandas as pd
"""一个用于数据导入、导出、清洗和分析的库,本文中导入csv格式数据等等"""
import matplotlib.pyplot as plt
"""pyplot提供了绘图接口"""
import matplotlib
"""一个强大的绘图库"""

# 设置matplotlib正常显示中文和负号
matplotlib.rcParams['font.family'] = 'SimHei'  # 指定默认字体为黑体
matplotlib.rcParams['axes.unicode_minus'] = False  # 正确显示负号


from prepare_for_training import LinearRegression

data = pd.read_csv("D:/machine_learning/archive/2017.csv")
train_data = data.sample(frac = 0.8)
test_data = data.drop(train_data.index)

input_param_name = 'Economy..GDP.per.Capita.'
output_param_name = 'Happiness.Score'

x_train = train_data[[input_param_name]].values
y_train = train_data[[output_param_name]].values

x_test = test_data[[input_param_name]].values
y_test = test_data[[output_param_name]].values

plt.scatter(x_train,y_train,label ='Train data')
plt.scatter(x_test,y_test,label ='Test data')
plt.xlabel(input_param_name)
plt.ylabel(output_param_name)
plt.title('Happy')
plt.legend()
plt.show()

"""训练次数,学习率"""
num_iterations = 500
learning_rate = 0.01

linear_regression = LinearRegression(x_train,y_train)
(theta,cost_history) = linear_regression.train(learning_rate,num_iterations)
print('开始时的损失',cost_history[0])
print('训练后的损失',cost_history[-1])

plt.plot(range(num_iterations),cost_history)
plt.xlabel('Iter')
plt.ylabel('cost')
plt.title('损失值')
plt.show()

predictions_num = 100
x_predictions = np.linspace(x_train.min(),x_train.max(),predictions_num).reshape(predictions_num,1)
y_predictions = linear_regression.predict(x_predictions)

plt.scatter(x_train,y_train,label ='Train data')
plt.scatter(x_test,y_test,label ='Test data')
plt.plot(x_predictions,y_predictions,'r',label = 'Prediction')
plt.xlabel(input_param_name)
plt.ylabel(output_param_name)
plt.title('Happy')
plt.legend()
plt.show()

效果图:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2341241.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CS144 Lab1实战记录:实现TCP重组器

文章目录 1 实验背景与要求1.1 TCP的数据分片与重组问题1.2 实验具体任务 2 重组器的设计架构2.1 整体架构2.2 数据结构设计 3 重组器处理的关键场景分析3.1 按序到达的子串(直接写入)3.2 乱序到达的子串(需要存储)3.3 与已处理区…

Linux安装mysql_exporter

mysqld_exporter 是一个用于监控 MySQL 数据库的 Prometheus exporter。可以从 MySQL 数据库的 metrics_schema 收集指标,相关指标主要包括: MySQL 服务器指标:例如 uptime、version 等数据库指标:例如 schema_name、table_rows 等表指标:例如 table_name、engine、…

BeautifulSoup 库的使用——python爬虫

文章目录 写在前面python 爬虫BeautifulSoup库是什么BeautifulSoup的安装解析器对比BeautifulSoup的使用BeautifulSoup 库中的4种类获取标签获取指定标签获取标签的的子标签获取标签的的父标签(上行遍历)获取标签的兄弟标签(平行遍历)获取注释根据条件查找标签根据CSS选择器查找…

HTTP的Header

一、HTTP Header 是什么? HTTP Header 是 HTTP 协议中的头部信息部分,位于请求或响应的起始行之后,用来在客户端(浏览器等)与服务器之间传递元信息(meta-data)(简单理解为传递信息的…

linux虚拟机网络问题处理

yum install -y yum-utils \ > device-mapper-persistent-data \ > lvm2 --skip-broken 已加载插件:fastestmirror, langpacks Loading mirror speeds from cached hostfile Could not retrieve mirrorlist http://mirrorlist.centos.org/?release7&arch…

AI-Sphere-Butler之如何使用Llama factory LoRA微调Qwen2-1.5B/3B专属管家大模型

环境: AI-Sphere-Butler WSL2 英伟达4070ti 12G Win10 Ubuntu22.04 Qwen2.-1.5B/3B Llama factory llama.cpp 问题描述: AI-Sphere-Butler之如何使用Llama factory LoRA微调Qwen2-1.5B/3B管家大模型 解决方案: 一、准备数据集我这…

协同推荐算法实现的智能商品推荐系统 - [基于springboot +vue]

🛍️ 智能商品推荐系统 - 基于springboot vue 🚀 项目亮点 欢迎来到未来的购物体验!我们的智能商品推荐系统就像您的私人购物顾问,它能读懂您的心思,了解您的喜好,为您精心挑选最适合的商品。想象一下&am…

Jenkins的地位和作用

所处位置 Jenkins 是一款开源的自动化服务器,广泛应用于软件开发和测试流程中,主要用于实现持续集成(CI)和持续部署(CD)。它在开发和测试中的位置和作用可以从以下几个方面来理解: 1. 在开发和测…

【集合】底层原理实现及各集合之间的区别

文章目录 集合2.1 介绍一下集合2.2 集合遍历的方法2.3 线程安全的集合2.4 数组和集合的区别2.5 ArrayList和LinkedList的区别2.6 ArrayList底层原理2.7 LinkedList底层原理2.8 CopyOnWriteArrayList底层原理2.9 HashSet底层原理2.10 HashMap底层原理2.11 HashTable底层原理2.12…

srp batch

参考网址: Unity MaterialPropertyBlock 正确用法(解决无法合批等问题)_unity_define_instanced_prop的变量无法srp合批-CSDN博客 URP | 基础CG和HLSL区别 - 哔哩哔哩 (bilibili.com) 【直播回放】Unity 批处理/GPU Instancing/SRP Batche…

【Linux运维涉及的基础命令与排查方法大全】

文章目录 前言1、计算机网络常用端口2、Kali Linux中常用的命令3、Kali Linux工具的介绍4、Ubuntu没有网络连接解决方法5、获取路由6、数据库端口 前言 以下介绍计算机常见的端口已经对应的网络协议,Linux中常用命令,以及平时运维中使用的排查网络故障的…

Webview+Python:用HTML打造跨平台桌面应用的创新方案

目录 一、技术原理与优势分析 1.1 架构原理 1.2 核心优势 二、开发环境搭建 2.1 安装依赖 2.2 验证安装 三、核心功能开发 3.1 基础窗口管理 3.2 HTML↔Python通信 JavaScript调用Python Python调用JavaScript 四、高级功能实现 4.1 系统级集成 4.2 多窗口管理 五…

克服储能领域的数据处理瓶颈及AI拓展

对于储能研究人员来说,日常工作中经常围绕着一项核心但有时令人沮丧的任务:处理实验数据。从电池循环仪的嗡嗡声到包含电压和电流读数的大量电子表格,研究人员的大量时间都花在了提取有意义的见解上。长期以来,该领域一直受到对专…

包含物体obj与相机camera的 代数几何代码解释

反余弦函数的值域在 [0, pi] 斜体样式 cam_pose self._cameras[hand_realsense].camera.get_model_matrix() # cam2world# 物体到相机的向量 obj_tcp_vec cam_pose[:3, 3] - self.obj_pose.p dist np.linalg.norm(obj_tcp_vec) # 物体位姿的旋转矩阵 obj_rot_mat self.ob…

mybatis实现增删改查1

文章目录 19.MyBatis查询单行数据MapperScan 结果映射配置核心文件Results自定义映射到实体的关系 多行数据查询-完整过程插入数据配置mybatis 控制台日志 更新数据删除数据小结通过id复用结果映射模板xml处理结果映射 19.MyBatis 数据库访问 MyBatis,MyBatis-Plus…

Git,本地上传项目到github

一、Git的安装和下载 https://git-scm.com/ 进入官网,选择合适的版本下载 二、Github仓库创建 点击右上角New新建一个即可 三、本地项目上传 1、进入 要上传的项目目录,右键,选择Git Bash Here,进入终端Git 2、初始化临时仓库…

基于flask+vue框架的灯饰安装维修系统u49cf(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。

系统程序文件列表 项目功能:用户,工单人员,服务项目,订单记录,服务记录,评价记录 开题报告内容 基于 FlaskVue 框架的灯饰安装维修系统开题报告 一、选题背景与意义 (一)选题背景 随着城市化进程的加速与居民生活品质的显著提升&#xf…

【算法】BFS-解决FloodFill问题

目录 FloodFill问题 图像渲染 岛屿数量 岛屿的最大面积 被围绕的区域 FloodFill问题 FloodFill就是洪水灌溉的意思,假设有下面的一块田地,负数代表是凹地,正数代表是凸地,数字的大小表示凹或者凸的程度。现在下一场大雨&…

GIS开发笔记(10)基于osgearth实现二三维地图的一键指北功能

一、实现效果 二、实现原理 获取视图及地图操作器,通过地图操作器来重新设置视点,以俯仰角 (0.0)和偏航角 (-90.0)来设置。 osgEarth::Util::Viewpoint(…) 这里创建了一个新的 Viewpoint 对象,表示一个特定的视角。构造函数的参数是: 第一个参数:是视角名称。 后面的 6 个…

window上 elasticsearch v9.0 与 jmeter5.6.3版本 冲突,造成es 启动失败

[2025-04-22T11:00:22,508][ERROR][o.e.b.Elasticsearch ] [AIRUY] fatal exception while booting Elasticsearchjava.nio.file.NoSuchFileException: D:\Program Files\apache-jmeter-5.6.3\lib\logkit-2.0.jar 解决方案: 降低 es安装版本 ,选择…