【大数据】机器学习-----线性模型

news2025/1/15 7:09:23

一、线性模型基本形式

线性模型旨在通过线性组合输入特征来预测输出。其一般形式为:

在这里插入图片描述

其中:

  • x = ( x 1 , x 2 , ⋯   , x d ) \mathbf{x}=(x_1,x_2,\cdots,x_d) x=(x1,x2,,xd) 是输入特征向量,包含 d d d 个特征。
  • w = ( w 1 , w 2 , ⋯   , w d ) \mathbf{w}=(w_1,w_2,\cdots,w_d) w=(w1,w2,,wd) 是权重向量,每个元素 w i w_i wi 表示对应特征的重要性。
  • w 0 = b w_0 = b w0=b 是偏置项,允许模型在没有任何输入特征时也能进行预测。

二、线性回归

线性回归用于预测连续值,其目标是找到最佳的 w \mathbf{w} w b b b 以最小化预测值 y ^ \hat{y} y^ 与真实值 y y y 之间的均方误差(MSE)。给定一组包含 m m m 个样本的数据集 { ( x 1 , y 1 ) , ( x 2 , y 2 ) , ⋯   , ( x m , y m ) } \{(\mathbf{x}_1,y_1),(\mathbf{x}_2,y_2),\cdots,(\mathbf{x}_m,y_m)\} {(x1,y1),(x2,y2),,(xm,ym)},均方误差的计算公式如下:

在这里插入图片描述

通常使用梯度下降法来优化这个目标函数,其更新规则如下:

对于权重 w j w_j wj j = 1 , 2 , ⋯   , d j = 1,2,\cdots,d j=1,2,,d):
在这里插入图片描述

对于偏置项 b b b
在这里插入图片描述

其中 α \alpha α 是学习率,控制每次更新的步长。

三、对数几率回归(逻辑回归)

逻辑回归用于二分类问题,将线性函数的输出通过逻辑函数(sigmoid 函数)转换为概率。逻辑函数定义为:
在这里插入图片描述

其目标是最大化似然函数,等价于最小化对数似然损失函数:

在这里插入图片描述

四、多分类学习

对于多分类问题,常用 softmax 函数将线性函数的结果转化为概率分布。假设类别数为 K K K,对于样本 i i i,首先计算线性函数的输出 z i k = w k T x i + b k z_{ik}=\mathbf{w}_k^T\mathbf{x}_i + b_k zik=wkTxi+bk,然后使用 softmax 函数:
在这里插入图片描述

其交叉熵损失函数为:

在这里插入图片描述

其中 y i k y_{ik} yik 是一个 one-hot 编码向量,如果样本 i i i 属于类别 k k k,则 y i k = 1 y_{ik}=1 yik=1,否则 y i k = 0 y_{ik}=0 yik=0

五、类别不平衡问题

类别不平衡问题发生在不同类别样本数量差异较大时,这可能导致模型偏向于多数类。常见的解决方法包括:

1. 重采样

  • 过采样:复制少数类样本以增加其数量。
  • 欠采样:删除多数类样本以减少其数量。

2. 代价敏感学习

  • 在损失函数中为不同类别赋予不同的权重,使得少数类的错误分类代价更高。

代码示例

线性回归示例

import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt


# 生成线性回归数据
np.random.seed(42)
X = 2 * np.random.rand(100, 1)
y = 4 + 3 * X + np.random.randn(100, 1)

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# 初始化线性回归模型
model = LinearRegression()

# 训练模型
model.fit(X_train, y_train)

# 预测
y_pred = model.predict(X_test)

# 打印模型参数
print(f"Intercept: {model.intercept_}")
print(f"Coefficients: {model.coef_}")

# 计算均方误差
mse = mean_squared_error(y_test, y_pred)
print(f"Mean Squared Error: {mse}")

# 可视化结果
plt.scatter(X_test, y_test)
plt.plot(X_test, y_pred, color='red')
plt.xlabel('X')
plt.ylabel('y')
plt.title('Linear Regression')
plt.show()

在这里插入图片描述

逻辑回归示例

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt


# 生成二分类数据
np.random.seed(42)
X = np.random.randn(100, 2)
y = (X[:, 0] + X[:, 1] > 0).astype(int)

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# 初始化逻辑回归模型
model = LogisticRegression()

# 训练模型
model.fit(X_train, y_train)

# 预测
y_pred = model.predict(X_test)

# 计算准确率
acc = accuracy_score(y_test, y_pred)
print(f"Accuracy: {acc}")

# 可视化决策边界
h = 0.02
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
plt.contourf(xx, yy, Z, cmap=plt.cm.coolwarm, alpha=0.8)
plt.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.title('Logistic Regression')
plt.show()

在这里插入图片描述

多分类逻辑回归示例

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt


# 生成多分类数据
# 调整 n_clusters_per_class 为 1 或调整 n_classes 为 2 或调整 n_informative 为 3 等
x, y = make_classification(n_samples=200, n_features=2, n_informative=2, n_redundant=0, n_classes=2, random_state=42)


# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.2)


# 初始化多分类逻辑回归模型
model = LogisticRegression(multi_class='multinomial', solver='lbfgs')


# 训练模型
model.fit(X_train, y_train)


# 预测
y_pred = model.predict(X_test)


# 计算准确率
acc = accuracy_score(y_test, y_pred)
print(f"Accuracy: {acc}")


# 可视化决策边界
h = 0.02
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
plt.contourf(xx, yy, Z, cmap=plt.cm.coolwarm, alpha=0.8)
plt.scatter(x[:, 0], x[:, 1], c=y, edgecolors='k')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.title('Multiclass Logistic Regression')
plt.show()

在这里插入图片描述

类别不平衡问题示例(过采样)

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from sklearn.utils import resample


# 生成类别不平衡数据
X, y = make_classification(n_samples=1000, n_features=2, n_informative=2, n_redundant=0, weights=[0.9, 0.1], random_state=42)

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# 原始模型
model = LogisticRegression()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
print(f"Original Accuracy: {accuracy_score(y_test, y_pred)}")
print(f"Original F1-score: {f1_score(y_test, y_pred)}")

# 过采样少数类
X_minority = X_train[y_train == 1]
y_minority = y_train[y_train == 1]
X_minority_upsampled, y_minority_upsampled = resample(X_minority, y_minority, replace=True, n_samples=X_train[y_train == 0].shape[0], random_state=42)
X_train_upsampled = np.vstack((X_train[y_train == 0], X_minority_upsampled))
y_train_upsampled = np.hstack((y_train[y_train == 0], y_minority_upsampled))

# 过采样后的模型
model_upsampled = LogisticRegression()
model_upsampled.fit(X_train_upsampled, y_train_upsampled)
y_pred_upsampled = model_upsampled.predict(X_test)
print(f"Upsampled Accuracy: {accuracy_score(y_test, y_pred_upsampled)}")
print(f"Upsampled F1-score: {f1_score(y_test, y_pred_upsampled)}")

在这里插入图片描述

代码解释

线性回归代码

  • np.random.rand(100, 1) 生成 100 个样本的特征数据。
  • LinearRegression() 创建线性回归模型。
  • model.fit(X_train, y_train) 训练模型。
  • model.predict(X_test) 进行预测。
  • mean_squared_error(y_test, y_pred) 计算均方误差。

逻辑回归代码

  • np.random.randn(100, 2) 生成二分类数据。
  • LogisticRegression() 创建逻辑回归模型。
  • model.fit(X_train, y_train) 训练模型。
  • accuracy_score(y_test, y_pred) 计算准确率。
  • 使用 meshgridcontourf 绘制决策边界。

多分类逻辑回归代码

  • make_classification 生成多分类数据。
  • LogisticRegression(multi_class='multinomial', solver='lbfgs') 创建多分类逻辑回归模型。
  • model.fit(X_train, y_train) 训练模型。
  • accuracy_score(y_test, y_pred) 计算准确率。

类别不平衡代码

  • make_classification 生成类别不平衡数据,通过 weights 参数控制类别比例。
  • resample 函数用于过采样少数类。
  • 比较原始模型和过采样后模型的准确率和 F1-score。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2276879.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

装备制造行业(复杂机械制造)数字化顶层规划 - 汇报会

行业业务特点: 尊敬的各位管理层: 大家好!今天我将向大家汇报装备制造企业数字化战略的顶层规划设计。在当今数字化浪潮下,装备制造企业面临着转型升级的迫切需求,数字化战略的制定与实施对于提升企业竞争力、实现可持…

深度探索C++20协程机制

#include <iostream> #include <coroutine>class CoroTaskSub { public://编译器在处理协程函数时是通过其返回类型【即协程接口类型】&#xff0c;确定协程的承诺类型和协程句柄类型struct promise_type;using CoroHdl std::coroutine_handle<promise_type>…

linux手动安装mysql5.7

一、下载mysql5.7 1、可以去官方网站下载mysql-5.7.24-linux-glibc2.12-x86_64.tar压缩包&#xff1a; https://downloads.mysql.com/archives/community/ 2、在线下载&#xff0c;使用wget命令&#xff0c;直接从官网下载到linux服务器上 wget https://downloads.mysql.co…

Java Stream实现【Int / Long / Double / Bigdecimal】累计求和

文章目录 背景实现方案案例素材Int类型求和Long 类型求和Double 类型求和BigDecimal 类型求和 背景 在项目开发中经常会使用到数据统计&#xff0c;Java中有求和的方法&#xff0c;可使用Java的Stream工作流实现&#xff0c;记录下来&#xff0c;方便备查。 实现方案 可使用…

OFD文件纯前端查看解决方案

文章目录 ofd.js原有bug修复1、ofd格式文档打开报错2、签章信息不显示 效果展示源码下载 使用前请查看免责声明 ofd.js原有bug修复 1、ofd格式文档打开报错 原因分析&#xff1a; 文档打开时会解析所用到的字体信息&#xff0c;如果字体不在ofd.js预设字体时&#xff0c;会触…

使用 Docker 部署 Java 项目(通俗易懂)

目录 1、下载与配置 Docker 1.1 docker下载&#xff08;这里使用的是Ubuntu&#xff0c;Centos命令可能有不同&#xff09; 1.2 配置 Docker 代理对象 2、打包当前 Java 项目 3、进行编写 DockerFile&#xff0c;并将对应文件传输到 Linux 中 3.1 编写 dockerfile 文件 …

二手车交易系统的设计与实现(代码+数据库+LW)

摘 要 如今社会上各行各业&#xff0c;都喜欢用自己行业的专属软件工作&#xff0c;互联网发展到这个时候&#xff0c;人们已经发现离不开了互联网。新技术的产生&#xff0c;往往能解决一些老技术的弊端问题。因为传统二手车交易信息管理难度大&#xff0c;容错率低&#xf…

抖音ip属地没有手机卡会显示吗

在数字时代&#xff0c;社交媒体平台如抖音已成为人们日常生活的重要组成部分。随着抖音等应用对用户体验和隐私保护的不断优化&#xff0c;IP属地显示功能逐渐走进大众视野。这一功能旨在提高网络环境的透明度&#xff0c;打击虚假信息和恶意行为。然而&#xff0c;对于没有手…

springMVC---resultful风格

目录 一、创建项目 pom.xml 二、配置文件 1.web.xml 2.spring-mvc.xml 三、图解 四、controller 一、创建项目 pom.xml <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi…

[Git] 深入理解 Git 的客户端与服务器角色

Git 的一个核心设计理念是 分布式&#xff0c;每个 Git 仓库都可以既是 客户端&#xff0c;也可以是 服务器。为了更好地理解这一特性&#xff0c;我们通过一个实际的 GitHub 远程仓库和本地仓库的场景来详细说明 Git 如何在客户端和服务器之间协作&#xff0c;如何独立地进行版…

网络安全-RSA非对称加密算法、数字签名

数字签名非常普遍&#xff1a; 了解数字签名前先了解一下SHA-1摘要&#xff0c;RSA非对称加密算法。然后再了解数字签名。 SHA-1 SHA-1&#xff08;secure hash Algorithm &#xff09;是一种 数据加密算法。该算法的思想是接收一段明文&#xff0c;然后以一种不可逆的方式将…

了解 ASP.NET Core 中的中间件

在 .NET Core 中&#xff0c;中间件&#xff08;Middleware&#xff09; 是处理 HTTP 请求和响应的核心组件。它们被组织成一个请求处理管道&#xff0c;每个中间件都可以在请求到达最终处理程序之前或之后执行操作。中间件可以用于实现各种功能&#xff0c;如身份验证、路由、…

【三维数域】三维数据调度-负载均衡和资源优化

在处理大规模三维数据时&#xff0c;负载均衡和资源优化是确保系统高效运行、提供流畅用户体验的关键。这两者不仅影响到系统的性能和稳定性&#xff0c;还直接决定了用户交互的质量。以下是关于如何在三维数据调度中实现有效的负载均衡和资源优化的详细探讨。 一、负载均衡 负…

AI大模型开发—1、百度的千帆大模型调用(文心一言的底层模型,ENRIE等系列)、API文档目的地

文章目录 前言一、千帆大模型平台简介二、百度平台官网初使用1、平台注册和使用2、应用注册 并 申请密钥3、开启千帆大模型 API调用a、API文档b、 前言 本章旨在为读者奉献一份实用的操作指南&#xff0c;深入探索如何高效利用百度千帆大模型平台的卓越功能。我们将从账号注册…

Java Stream流操作List全攻略:Filter、Sort、GroupBy、Average、Sum实践

在Java 8及更高版本中&#xff0c;Stream API为集合处理带来了革命性的改变。本文将深入解析如何运用Stream对List进行高效的操作&#xff0c;包括筛选&#xff08;Filter&#xff09;、排序&#xff08;Sort&#xff09;、分组&#xff08;GroupBy&#xff09;、求平均值&…

《视听导报》是什么类型的报纸?报纸上发文章要交版面费吗?

作为个人成果发表的重要场所&#xff0c;报纸目前正得到越来越多单位的认可。不过在投稿时&#xff0c;我们既要考虑投稿的报纸是否符合评审标准&#xff0c;也要考虑发表文章的成本是否在我们的承受范围之内。 下面就让我们以《视听导报》为例&#xff0c;了解下如何查看报纸的…

candb++ windows11运行报错,找不到mfc140.dll

解决问题记录 mfc140.dll下载 注意&#xff1a;放置位置别搞错了

服务器引导异常,Grub报错: error: ../../grub-core/fs/fshelp.c:258:file xxxx.img not found.

服务器引导异常,Grub报错: error: ../../grub-core/fs/fshelp.c:258:file xxxx.img not found. 1. 故障现象2. 解决思路3. 故障分析4. 案件回溯5. 解决问题 1. 故障现象 有一台服务器业务报无法连接. 尝试用Ping命令发现无法ping通. 通过控制台查看发现有以下报错: error: ..…

LeetCode第432场周赛 (前3题|多语言)

比赛链接&#xff1a;第432场周赛 文章目录 3417. 跳过交替单元格的之字形遍历思路代码CJavaPython 3418. 机器人可以获得的最大金币数思路代码CJavaPython 3419. 图的最大边权的最小值思路代码CJavaPython 总结 3417. 跳过交替单元格的之字形遍历 思路 没啥好说的就是模拟 按…

下载导出Tomcat上的excle文档,浏览器上显示下载

目录 1.前端2.Tomcat服务器内配置3.在Tomcat映射的文件内放置文件4.重启Tomcat&#xff0c;下载测试 1.前端 function downloadFile() {let pictureSourceServer "http://192.168.1.1:8080/downFile/";let fileName "测试文档.xlsx";let fileURL pictu…