机器学习与深度学习——利用随机梯度下降算法SGD对波士顿房价数据进行线性回归

news2024/11/17 20:36:14

机器学习与深度学习——利用随机梯度下降算法SGD对波士顿房价数据进行线性回归

我们这次使用随机梯度下降(SGD)算法对波士顿房价数据进行线性回归的训练,给出每次迭代的权重、损失和梯度,并且绘制损失loss随着epoch变化的曲线图。

步骤

1、导入必要的库和模块:numpy,pandas,matplotlib,load_boston和StandardScaler。其中,load_boston用于加载波士顿房价数据集,StandardScaler用于对数据进行标准化处理。
2、加载数据集并对数据进行标准化处理。同时,为数据添加一列1作为截距项,并将y转换为列向量。
3、定义SGD函数来进行训练。在每个epoch中,我们会随机地从样本中抽取一个batch的数据来计算梯度和损失,并更新权重。
4、使用SGD训练模型,并输出每次迭代的结果:权重w,损失loss和梯度grad。同时,将每个epoch的平均损失存储到列表losses中,并返回最终的权重和损失。
5、绘制随机梯度下降训练过程中,损失函数值随着epoch变化的曲线图

程序代码

1.使用随机梯度下降(SGD)算法对波士顿房价数据进行线性回归的训练,给出每次迭代的权重、损失和梯度,并且绘制了损失loss随着epoch变化的曲线图。

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
#加载数据集并对数据进行标准化处理。同时,为数据添加一列1作为截距项,并将y转换为列向量。
# 加载数据
boston = load_boston()
X, y = boston.data, boston.target

# 标准化数据
scaler = StandardScaler()
X = scaler.fit_transform(X)

# 在数据中添加一列1
X = np.hstack((np.ones((X.shape[0], 1)), X))

# 将y转换为列向量
y = y.reshape(-1, 1)
# 3、定义SGD函数来进行训练。在每个epoch中,我们会随机地从样本中抽取一个batch的数据来计算梯度和损失,并更新权重。
def sgd(X, y, lr=0.01, epochs=100, batch_size=32):
    n_samples, n_features = X.shape
    w = np.zeros((n_features, 1))
    losses = []
    
    for epoch in range(epochs):
        epoch_loss = 0
 
        # 随机排列样本
        permutation = np.random.permutation(n_samples)
        
        for i in range(0, n_samples, batch_size):
            # 获取一个batch的样本
            indices = permutation[i:i+batch_size]
            X_batch = X[indices]
            y_batch = y[indices]
            
            # 计算梯度和损失
            grad = X_batch.T.dot(X_batch.dot(w) - y_batch) / batch_size
            loss = np.mean((X_batch.dot(w) - y_batch) ** 2)
            epoch_loss += loss
            
            # 更新权重
            w -= lr * grad
        
        losses.append(epoch_loss / (n_samples // batch_size))
        
    return w, losses
# 使用SGD训练模型,并输出每次迭代的结果:权重w,损失loss和梯度grad。同时,将每个epoch的平均损失存储到列表losses中,并返回最终的权重和损失。
w, losses = sgd(X, y, lr=0.01, epochs=100, batch_size=32)

def sgd(X, y, lr=0.01, epochs=100, batch_size=32):
    n_samples, n_features = X.shape
    w = np.zeros((n_features, 1))
    losses = []
    
    for epoch in range(epochs):
        epoch_loss = 0
        
        # 随机排列样本
        permutation = np.random.permutation(n_samples)
        
        for i in range(0, n_samples, batch_size):
            # 获取一个batch的样本
            indices = permutation[i:i+batch_size]
            X_batch = X[indices]
            y_batch = y[indices]
            
            # 计算梯度和损失
            grad = X_batch.T.dot(X_batch.dot(w) - y_batch) / batch_size
            loss = np.mean((X_batch.dot(w) - y_batch) ** 2)
            epoch_loss += loss
            
            # 更新权重
            w -= lr * grad
            
            # 输出w值和grad值和loss值
            print('w:', w.flatten())
            print('grad:', grad.flatten())
            print('loss:', loss)
        
        losses.append(epoch_loss / (n_samples // batch_size))
        
    return w, losses

w, losses = sgd(X, y, lr=0.01, epochs=100, batch_size=32)



#绘制随机梯度下降训练过程中,损失函数值随着epoch变化的曲线图
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Cost')
plt.show()

效果图

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
随机梯度下降(SGD)是一种简单但非常有效的方法,多用用于支持向量机、逻辑回归(LR)等凸损失函数下的线性分类器的学习。并且SGD已成功应用于文本分类和自然语言处理中经常遇到的大规模和稀疏机器学习问题。SGD既可以用于分类计算,也可以用于回归计算。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/734270.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

集群 第四章

目录 1. nginx、lvs、haproxy 的区别 2. 实验 3. ssh 升级 4.总结 1. nginx、lvs、haproxy 的区别 2. 实验 Haproxy 服务器:192.168.83.101 Nginx 服务器1:192.168.83.102 Nginx 服务器2:192.168.83.103 …

Mysql之视图,索引及数据的备份与恢复

目录 一、视图 1.视图是什么 2.视图与数据表的区别 3.视图的优缺点 优点: 缺点: 4.视图的应用场景 5.语法运用 二、索引 1.什么是索引 2.为什么要使用索引 3.使用索引的优缺点 4.何时不使用索引 5.索引何时失效 6.索引分类 三、数据的备份…

头结点到底方便了啥?

头结点到底方便了啥? 链表增加头结点的作用如下: (1)便于首元结点的处理 (2)便于空表和非空表的统一处理 (参考:《数据结构 C语言(第2版)》P31) 其实这两句话很抽象,你说方便就方便,你倒是举个粟子或者画个图什么的啊&…

Linux开发工具之vim工具的使用介绍

目录 前言 1.vim的基本概念 命令模式(Normal mode) 插入模式(Insert mode) 末行模式(last line mode) 2.vim的基本操作 命令模式的命令集 移动光标 ​编辑 删除文字 复制 替换 撤销操作 更改 vim末行模式命令集 简单vim配置 总结 前言 大家好呀,许久…

Java动态规划LeetCode1137. 第 N 个泰波那契数

方法1:通过动态规划解题,这道题也是动态规划的一道很好的入门题,因为比较简单和容易理解。 代码如下: public int tribonacci(int n) {//处理特殊情况if(n0){return 0;}if(n1||n2){return 1;}//定义数组int[]dpnew int[n1];//初…

浏览器通过js打开文件,新建文件,静默实时保存文件

资源&#xff0c;点击下载 在线访问Txt Markdown &#x1f61d;&#x1f61d;&#x1f61d;&#x1f61d;&#x1f61d;&#x1f61d; 新建文件后&#xff0c;可以直接保存文件&#xff0c;不需要再次下载文件&#xff0c;也只有第一次保存时候才会出现确认弹窗 html <!D…

尚硅谷React学习笔记(上)

目录 一、React入门 1.1、React简介 为什么要学&#xff1f; React的特点 1.2、React的基本使用 Hello React案例 创建虚拟DOM的两种方式 虚拟DOM与真实DOM 1.3、React JSX 语法规则 JSX小练习 1.4、模块与组件化的理解 模块 组件 模块化 组件化 二、React面向…

E. Scuza - 二分+前缀和

分析&#xff1a; 暴力会超时&#xff0c;可以用二分&#xff0c;构建两个数组&#xff0c;一个是a[i]&#xff0c;作为前缀和数组&#xff0c;一个是f[i]表示第i个台阶之前的最大高度的台阶&#xff0c;然后每次二分来查找k&#xff0c;因为尽可能地走的多&#xff0c;所以查找…

VTK STL 体积 表面积测量 最短路径 读取中文路径

目录 开发环境&#xff1a; vtkMassProperties 三、中文路径 数据读取 开发环境&#xff1a; 系统&#xff1a;Win10 VTK&#xff1a;8.2.0 Qt&#xff1a;5.12.4 一、结构化对象 体积 面积 vtkMassProperties VTK 计算体积和面积的主要类 vtkMassProperties vtkSm…

C语言进阶之指针的进阶

指针的进阶 1. 字符指针2. 指针数组3. 数组指针3.1 数组指针的定义3.2 &数组名VS数组名3.3 数组指针的使用 4. 数组参数、指针参数4.1 一维数组传参4.2 二维数组传参4.3 一级指针传参4.4 二级指针传参 5. 函数指针6. 函数指针数组7. 指向函数指针数组的指针8. 回调函数9. 指…

【程序员必须掌握的算法】【Matlab】GRNN神经网络遗传算法(GRNN-GA)函数极值寻优——非线性函数求极值

上一篇博客介绍了BP神经网络遗传算法(BP-GA)函数极值寻优——非线性函数求极值&#xff0c;神经网络用的是BP神经网络&#xff0c;本篇博客将BP神经网络替换成GRNN神经网络&#xff0c;希望能帮助大家快速入门GRNN网络。 1.背景条件 要求&#xff1a;对于未知模型&#xff08;…

使用trtexec工具多batch推理tensorrt模型(trt模型)

文章目录 零、pt转onnx模型一、onnx转trt模型二、推理trt模型 零、pt转onnx模型 参考&#xff1a;https://github.com/ultralytics/yolov5 用根目录下的export.py可以转pt为onnx模型&#xff0c;命令如下可以转换成动态batch的onnx模型 python3 export.py --weights./yolov5s…

一款强大易用的截图控件:跨平台,界面简洁,功能丰富,易于集成

当我们在日常工作中沟通交流&#xff0c;或是在开发过程中跟踪反馈问题时&#xff0c;截图无疑是一种最直观有效的方式。然而&#xff0c;传统的截图工具在功能上的局限性&#xff0c;往往无法满足我们日益增长的需求。这时&#xff0c;一款功能强大&#xff0c;易于集成&#…

垃圾收集算法和CMS详解

一、垃圾收集算法 1、分带收集理论 基于新生代和老年代选择不同垃圾回收算法&#xff0c;比如新生代&#xff0c;都是一些暂存对象&#xff0c;而且内存分区域的&#xff0c;可以采用标记复制算法。而老年代只有一块内存区域&#xff0c;使用复制算法比较占用内存空间&#x…

DEVICENET转ETHERCAT网关连接ethercat通讯协议详细解析

你有没有遇到过生产管理系统中&#xff0c;设备之间的通讯问题&#xff1f;两个不同协议的设备进行通讯&#xff0c;是不是很麻烦&#xff1f;今天&#xff0c;我们为大家介绍一款神奇的产品&#xff0c;能够将不同协议的设备进行连接&#xff0c;让现场的数据交换不再困扰&…

MySQL数据库 - 库的操作

目录​​​​​​​ 一、创建数据库 二、创建数据库案例 三、字符集和校验规则 四、校验规则对数据库的影响 五、操纵数据库 1、查看数据库 2、显示创建语句 3、修改数据库 4、删除数据库 六、数据库的备份与恢复 1、数据库的备份 2、数据库的恢复 3、表的备份 4…

【网络系统集成】Pfsense防火墙实验

1.实验名称 Pfsense防火墙实验 2.实验目的 通过动手实践配置pfsense对加深对防火墙的原理与应用的理解。 3.实验内容 (1)安装并完成pfsense防火墙软件的基本配置(WAN, LAN,局域网

刘积仁:东软不太喜欢风口,更看重长期主义

作为数字和软件服务产业一年一度的行业盛宴&#xff0c;2003年&#xff0c;中国国际软件和信息服务交易会&#xff08;简称“软交会”&#xff09;正式诞生。2019年&#xff0c;大会更名为中国国际数字和软件服务交易会&#xff08;简称“数交会”&#xff09;&#xff0c;至今…

【C++修炼之路】string 概述

&#x1f451;作者主页&#xff1a;安 度 因 &#x1f3e0;学习社区&#xff1a;StackFrame &#x1f4d6;专栏链接&#xff1a;C修炼之路 文章目录 一、string 为何使用模板二、string 类认识1、构造/析构/赋值运算符重载2、容量操作3、增删查改4、遍历5、迭代器6、非成员函数…

[NSSRound#13 Basic]flask?jwt?解题思路过程

过程 打开题目链接&#xff0c;是一个登录框&#xff0c;不加验证码&#xff0c;且在注册用户名admin时提示该用户名已被注册&#xff0c;因此爆破也是一种思路。不过根据题目名字中的提示&#xff0c;jwt&#xff0c;且拥有注册入口&#xff0c;注册一个用户先。 注册完用户…