逻辑回归模型预测

news2024/10/6 1:39:25

范例题目:
建立一个逻辑回归模型预测一个学生是否通过研究生入学考试。N表示训练集中学生个数,Score1、Score2、 Score3、 Score4是N维数组,分别表示N个学生研究生初试、专业课复试、英语面试成绩、专业课面试成绩。Admitted是N维{0,1}数组,1代表被录取,0代表未被录取。给出逻辑回归的参数结构、初始化过程、损失函数(经验风险)设置,基于随机梯度下降和梯度下降的参数学习过程。

数学推导

逻辑回归是一种二元分类算法,可以用于预测一个学生是否通过研究生入学考试。以下是逻辑回归模型的参数结构和学习过程。

参数结构:
逻辑回归模型的参数结构包括权重向量 w w w 和偏置 b b b。对于一个具有 n n n 个特征的样本 x = ( x 1 , x 2 , . . . , x n ) x=(x_1,x_2,...,x_n) x=(x1,x2,...,xn),模型的预测输出 y y y 可以表示为:
y = σ ( w 1 x 1 + w 2 x 2 + . . . + w n x n + b ) y = \sigma(w_1x_1 + w_2x_2 + ... + w_nx_n + b) y=σ(w1x1+w2x2+...+wnxn+b)
其中, σ ( x ) \sigma(x) σ(x) 是sigmoid函数,可以将任意实数映射到区间 ( 0 , 1 ) (0,1) (0,1) 上,定义为:
σ ( x ) = 1 1 + e − x \sigma(x) = \frac{1}{1 + e^{-x}} σ(x)=1+ex1
初始化过程:
为了学习逻辑回归模型的参数 w w w b b b,需要对它们进行初始化。常见的初始化方法是将 w w w 初始化为0向量, b b b 初始化为0标量。

损失函数(经验风险)设置:
逻辑回归模型的损失函数通常采用交叉熵损失函数。对于一个样本 ( x , y ) (x, y) (x,y),其交叉熵损失函数可以表示为:
J ( w , b ) = − ( y log ⁡ ( y ^ ) + ( 1 − y ) log ⁡ ( 1 − y ^ ) ) J(w,b)=-\left(y\log \left(\hat{y}\right)+(1-y)\log \left(1-\hat{y}\right)\right) J(w,b)=(ylog(y^)+(1y)log(1y^))
其中, y ^ = σ ( w 1 x 1 + w 2 x 2 + . . . + w n x n + b ) \hat{y}=\sigma(w_1x_1 + w_2x_2 + ... + w_nx_n + b) y^=σ(w1x1+w2x2+...+wnxn+b) 是模型的预测输出, y y y 是样本的真实标签。

基于随机梯度下降的参数学习过程:
在基于随机梯度下降的参数学习过程中,模型每次随机选取一个样本 ( x ( i ) , y ( i ) ) (x^{(i)},y^{(i)}) (x(i),y(i)) 进行训练。算法的更新过程如下:
w j ← w j − α ∂ J ∂ w j = w j + α ( y ( i ) − y ^ ( i ) ) x j ( i ) w_j \leftarrow w_j - \alpha \frac{\partial J}{\partial w_j} = w_j + \alpha(y^{(i)}-\hat{y}^{(i)})x_j^{(i)} wjwjαwjJ=wj+α(y(i)y^(i))xj(i) b ← b − α ∂ J ∂ b = b + α ( y ( i ) − y ^ ( i ) ) b \leftarrow b - \alpha \frac{\partial J}{\partial b} = b + \alpha(y^{(i)}-\hat{y}^{(i)}) bbαbJ=b+α(y(i)y^(i))
其中, α \alpha α 是学习率, y ^ ( i ) = σ ( w 1 x 1 ( i ) + w 2 x 2 ( i ) + . . . + w n x n ( i ) + b ) \hat{y}^{(i)}=\sigma(w_1x_1^{(i)} + w_2x_2^{(i)} + ... + w_nx_n^{(i)} + b) y^(i)=σ(w1x1(i)+w2x2(i)+...+wnxn(i)+b) 是模型对样本 x ( i ) x^{(i)} x(i) 的预测输出。

基于梯度下降的参数学习过程:
在基于梯度下降的参数学习过程中,模型在每一轮迭代中使用整个训练集进行训练。算法的更新过程如下:
b ← b − α 1 N ∑ i = 1 N ( y ( i ) − y ^ ( i ) ) b \leftarrow b - \alpha \frac{1}{N} \sum_{i=1}^N (y^{(i)}-\hat{y}^{(i)}) bbαN1i=1N(y(i)y^(i))
其中, α \alpha α 是学习率, N N N 是训练集中样本的个数, y ^ ( i ) = σ ( w 1 x 1 ( i ) + w 2 x 2 ( i ) + . . . + w n x n ( i ) + b ) \hat{y}^{(i)}=\sigma(w_1x_1^{(i)} + w_2x_2^{(i)} + ... + w_nx_n^{(i)} + b) y^(i)=σ(w1x1(i)+w2x2(i)+...+wnxn(i)+b) 是模型对样本 x ( i ) x^{(i)} x(i) 的预测输出。

以上是逻辑回归模型的参数结构、初始化过程、损失函数设置,以及基于随机梯度下降和梯度下降的参数学习过程。在实际应用中,可以根据具体情况调整学习率和训练轮数等超参数,以获得更好的模型性能。

代码实现

代码基于torch实现,其中注释详尽,自行查阅
以下代码首先定义了一个逻辑回归模型类 LogisticRegression,包含一个全连接层和一个 Sigmoid 激活函数。训练函数 train 使用随机梯度下降优化器和二元交叉熵损失函数对模型进行训练,最后测试模型的准确率,并进行了简易的可视化。
在这里插入图片描述

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# 定义逻辑回归模型类
class LogisticRegression(nn.Module):
    def __init__(self, num_features):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(num_features, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.linear(x)
        out = self.sigmoid(out)
        return out

# 训练函数
def train(X, y, model, learning_rate, num_epochs, batch_size):
    criterion = nn.BCELoss() # 二元交叉熵损失函数
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) # 随机梯度下降优化器

    dataset = torch.utils.data.TensorDataset(X, y)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    for epoch in range(num_epochs):
        for inputs, targets in dataloader:
            # 前向传播
            outputs = model(inputs)

            # 计算损失函数值并反向传播
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if (epoch+1) % 100 == 0:
            print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

# 数据准备
np.random.seed(42)
N = 1000
X = np.random.rand(N, 2) * 4 - 2
y = np.zeros((N, 1))
y[np.sum(X ** 2, axis=1) <= 1.5] = 1
X = torch.from_numpy(X).float()
y = torch.from_numpy(y).float()

# 定义模型和超参数
num_features = 2
learning_rate = 1e-4
num_epochs = 1000
batch_size = 10
model = LogisticRegression(num_features)

# 模型训练
train(X, y, model, learning_rate, num_epochs, batch_size)

# 绘制决策边界和数据点
with torch.no_grad():
    X_grid = np.meshgrid(np.linspace(-2, 2, 100), np.linspace(-2, 2, 100))
    X_test = torch.from_numpy(np.array([X_grid[0].ravel(), X_grid[1].ravel()]).T).float()
    y_pred = model(X_test).detach().numpy().reshape(X_grid[0].shape)
    y_pred = np.where(y_pred >= 0.5, 1, 0)

    outputs = model(X)
    predicted = (outputs >= 0.5).float()
    accuracy = (predicted == y).float().mean()
    print('Accuracy: {:.2f}%'.format(accuracy.item() * 100))

plt.contourf(X_grid[0], X_grid[1], y_pred, alpha=0.5)
plt.scatter(X[:, 0], X[:, 1], c=y[:, 0], cmap='bwr')
plt.title('Logistic Regression')
plt.xlabel('X1')
plt.ylabel('X2')
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/443768.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

RabbitMQ编程模型

RabbitMQ基础概念 RabbitMQ是基于AMQP协议开发的一个MQ产品。 虚拟主机 virtual host RabbitMQ出于服务器复用的想法&#xff0c;可以在一个RabbitMQ集群中划分出多个虚拟主机&#xff0c;每一个虚拟主机都有AMQP的全套基础组件&#xff0c;并且可以针对每个虚拟主机进行权…

面向对象(高级)-Annotation注解、单元测试的使用

注解&#xff08;Annotation&#xff09; 注解大纲 注解的使用1.Annotation的理解 - 注解&#xff08;Annotation&#xff09;是从JDK5.0开始引入&#xff0c;以注解名在代码中存在。 - Annotation可以像修饰符一样被使用&#xff0c;可用于修饰包、类、构造器、方法、成员变…

LeetCode:59. 螺旋矩阵 II

&#x1f34e;道阻且长&#xff0c;行则将至。&#x1f353; &#x1f33b;算法&#xff0c;不如说它是一种思考方式&#x1f340; 算法专栏&#xff1a; &#x1f449;&#x1f3fb;123 一、&#x1f331;59. 螺旋矩阵 II 题目描述&#xff1a;给你一个正整数 n &#xff0c…

Python中类属性和类方法

1. 类的结构 1.1 术语 —— 实例 使用面相对象开发&#xff0c;第 1 步 是设计 类使用 类名() 创建对象&#xff0c;创建对象 的动作有两步&#xff1a; (1) 在内存中为对象 分配空间 (2) 调用初始化方法 __init__ 为 对象初始化对象创建后&#xff0c;内存 中就有了一个对象…

【敲敲云】零代码实战,主子表汇总统计—免费的零代码产品

近来很多朋友在使用敲敲云时&#xff0c;不清楚如何使用主子表&#xff0c;及如何在主表中统计子表数据&#xff1b;下面我们就以《订单》表及《订单明细》表来设计一下吧&#xff0c;用到的组件有“设计子表”、“公式”、“汇总”等。 《订单》表展示 总金额 订单明细中“小…

C++ Linux Web Server 面试基础篇-操作系统(三、进程通信)

⭐️我叫忆_恒心,一名喜欢书写博客的在读研究生👨‍🎓。 如果觉得本文能帮到您,麻烦点个赞👍呗! 近期会不断在专栏里进行更新讲解博客~~~ 有什么问题的小伙伴 欢迎留言提问欧,喜欢的小伙伴给个三连支持一下呗。👍⭐️❤️ Qt5.9专栏定期更新Qt的一些项目Demo 项目与…

恢复调度平台mysql主从同步

修复问题 调度平台两台MySQL从节点存在Slave_SQL_Running异常&#xff0c;需要恢复。 部署步骤 一、先停止调度平台core服务与web服务&#xff0c;否则无法正常锁表 1.1停止调度平台core服务 2.1停止web服务 3.确认MySQL所有执行线程是否都已经停止 show processlist; 如…

小红书流量规则是什么,推荐机制解读

当今的互联网自媒体世界&#xff0c;说到底还是一个流量时代&#xff0c;一个流量为王的时代。不管你在小红书也好&#xff0c;还是其他自媒体平台都需要知晓平台的流量规则。今天和大家分享下小红书流量规则是什么&#xff0c;让我们一起通过流量规则分析小红书机制和算法。 一…

人工智能如何助力建筑设计自动化?

ChatGPT和DALL-E等工具使用大规模机器学习(ML)模型&#xff0c;并访问大量有标记和有意义的数据&#xff0c;以对文本和图像中的查询提供有见解的响应。但是&#xff0c;一些行业对训练ML模型的数据集的访问有限&#xff0c;这使得使用生成式AI来解决现实世界问题的好处很难获得…

书写我的人生回忆录-这应该是给父母最好的礼物

作为一个业余的软件开发爱好者&#xff0c;我又捣鼓了一个有意思的小东西 &#xff0c;使用完全免费哈 《书写我的人生回忆录》是一款软件&#xff0c;其中包含70个问题&#xff0c;涵盖了父母的个人喜好、家庭、工作、人生经历和态度等方面。通过回答这些问题&#xff0c;您的…

爬虫请求头Content-Length的计算方法

重点&#xff1a;使用node.js 环境计算&#xff0c;同时要让计算的数据通过JSON.stringify从对象变成string。 1. Blob size var str 中国 new Blob([str]).size // 6 2、Buffer.byteLength # node > var str 中国 undefined > Buffer.byteLength(str, utf8) 6 原文…

Spring开启事务流程和事务相关配置

文章目录 Spring事务Spring快速入门事务相关配置 Spring事务 Spring快速入门 事务作用&#xff1a;在数据层保障一系列的数据库操作同成功同失败 Spring事务作用&#xff1a;在数据层或业务层保障一系列的数据库操作同成功同失败 Spring提供了一个接口PlatformTransactionMa…

Vue可视化项目搭建

安装Nodejs 全局下载Vue项目脚手架 创建项目 运行项目 项目初始化 安装Nodejs 下载地址&#xff1a;https://nodejs.org/zh-cn/ 下载完成之后一路点击下一个安装 全局下载Vue项目脚手架 进入开始菜单以管理员身份运行命令提示符 输入更换镜像源为淘宝源 npm config s…

java线程屏障CyclicBarrier

CyclicBarrier允许一组线程在达到一个公共的屏障点时相互等待。它在涉及固定大小的线程组、并且这些线程必须相互等待的程序中非常有用&#xff0c;CyclicBarrier可以在等待的线程被释放后被重用。 构造方法 CyclicBarrier(int parties) 创建一个新的屏障并设置将要访问这个…

问卷调查样本量的确定方法

我们在进行问卷调查的时候&#xff0c;问卷的收集数量是重要的流程之一。问卷数量取决于几个因素&#xff0c;包括研究的目的和研究的类型。接下来&#xff0c;我们就聊一聊怎么确定所需的调查问卷数量。 1、确定研究目标。 确定所需问卷数量的第一步是明确研究目标。这一步是…

jar包依赖冲突该怎么解决(IT枫斗者)

jar包依赖冲突该怎么解决&#xff08;IT枫斗者&#xff09; maven jar包依赖规则 间接依赖路径最短优先一个项目依赖了a和b两个jar包&#xff0c;其中a-b-c1.0&#xff0c;d-e-c1.0,由于c1.0路径最短&#xff0c;所以项目最后使用的jar包是c1.0pom文件中申明顺序优先有人就问…

使用三轴XYZ平台绘制空心字

1. 功能说明 本文示例将实现R312三轴XYZ平台绘制“机器时代”空心字的功能。 2. 电子硬件 在这个示例中&#xff0c;采用了以下硬件&#xff0c;请大家参考&#xff1a; 主控板 Basra主控板&#xff08;兼容Arduino Uno&#xff09; 扩展板 Bigfish2.1扩展板 SH-ST步进电机扩展…

2023年最系统的自动化测试,测试开发面试题,10k以下不建议看

鉴于现在严峻的就业形势&#xff0c;千万大学生即将出新手村&#xff0c;今天给大家打包好了2023最能避免薪资倒挂的《面试圣经》。不经一番寒彻骨,怎得梅花扑鼻香。这份面试题&#xff0c;与君共勉&#xff01; 一、开场白 Q&#xff1a;简单自我介绍一下吧 Q&#xff1a;项…

Bots攻击威胁石油石化企业 瑞数动态安全实现从“人防”到“技防”

近日&#xff0c;中国石油石化企业信息技术交流大会暨油气产业数字化转型高峰论坛在京召开。本届大会由中国石油学会、中国石油、中国石化、中国海油、国家管网、国家能源、中国中化、中国航油、延长石油、中国地质调查局等单位共同主办。 作为我国石油石化行业的盛会&#xf…

论坛现场回顾:维视教育的新工科人才培养 「最佳实践 」

全国高校电子信息类专业教学论坛隆重召开 由教育部高等学校电子信息类专业教学指导委员会主办&#xff0c;苏州大学、清华大学出版社承办的“全国高校电子信息类专业教学论坛”于2023年4月14日-16日在江苏省苏州市隆重开幕&#xff0c;维视教育作为电子信息类教学指导委员会战略…