基于Python的机器学习系列(7):多元逻辑回归

news2024/9/21 21:15:05

        在本篇博文中,我们将探讨多元逻辑回归,它是一种扩展的逻辑回归方法,适用于分类数量超过两个的场景。与二元逻辑回归不同,多元逻辑回归使用Softmax函数将多个类别的概率输出映射到[0, 1]范围内,并确保所有类别的概率和为1。本文将通过具体的代码实现详细介绍多元逻辑回归的工作机制。

1. Softmax函数

        在多元逻辑回归中,我们使用Softmax函数来处理多类别的概率分布。Softmax函数可以将模型的线性输出转化为各类别的概率值。公式如下:

        其中c表示类别,k表示类别的总数。

Softmax函数的作用如下:

  • 它将每个类别的得分转化为非负的概率值。
  • Softmax函数通过分母的求和操作确保所有类别的概率和为1。
  • Softmax函数的分子和分母都包含指数函数$e$,这使得其导数易于计算,并能与交叉熵损失函数配合良好。
def softmax(x):
    return np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True)

X = np.array([[1, 2, 3],
             [2, 4, 5]])

Y = np.array([[0, 0, 1, 0],
              [1, 0, 0, 0]])  # one-hot encoded classes

W = np.array([[1, 2, 3, 4],
              [2, 3, 1, 0],
              [1, 2, 5, 1]])

softmax_output = softmax(X @ W)
print("Softmax Output:\n", softmax_output)
print("Check sum of probabilities:", softmax_output.sum(axis=1))
2. 交叉熵损失函数

        在多元逻辑回归中,我们继续使用交叉熵作为损失函数。对于每个样本,交叉熵损失可以表示为:

        与二元逻辑回归的交叉熵损失不同,损失函数现在对所有类别进行求和,从而扩展到多类场景。

def cross_entropy_loss(Y, h):
    return - np.sum(Y * np.log(h))

loss = cross_entropy_loss(Y, softmax_output)
print("Cross Entropy Loss:", loss)
3. 梯度计算

        对于每个参数theta,损失函数J的梯度计算公式为:

        其中H和Y分别为预测值矩阵和真实值的one-hot编码矩阵。

        通过链式法则,我们可以详细推导出每个参数的梯度,并将其应用于梯度下降算法中。

4. 多元逻辑回归的实现

        下面是多元逻辑回归的完整实现步骤:

  1. 准备数据,包括添加截距项、one-hot编码标签、标准化特征等。
  2. 使用Softmax函数进行预测,并计算交叉熵损失。
  3. 基于损失函数计算梯度,并更新参数。
  4. 迭代上述步骤,直至达到收敛条件。
class LogisticRegression:
    
    def __init__(self, k, n, method, alpha=0.001, max_iter=5000):
        self.k = k
        self.n = n
        self.alpha = alpha
        self.max_iter = max_iter
        self.method = method
    
    def fit(self, X, Y):
        self.W = np.random.rand(self.n, self.k)
        self.losses = []
        
        if self.method == "batch":
            for i in range(self.max_iter):
                loss, grad = self.gradient(X, Y)
                self.losses.append(loss)
                self.W -= self.alpha * grad
                if i % 500 == 0:
                    print(f"Loss at iteration {i}: {loss}")
        
        elif self.method == "minibatch":
            batch_size = int(0.3 * X.shape[0])
            for i in range(self.max_iter):
                ix = np.random.randint(0, X.shape[0] - batch_size)
                batch_X = X[ix:ix + batch_size]
                batch_Y = Y[ix:ix + batch_size]
                loss, grad = self.gradient(batch_X, batch_Y)
                self.losses.append(loss)
                self.W -= self.alpha * grad
                if i % 500 == 0:
                    print(f"Loss at iteration {i}: {loss}")
        
        else:
            raise ValueError('Method must be "batch" or "minibatch".')
        
    def gradient(self, X, Y):
        m = X.shape[0]
        h = self.h_theta(X)
        loss = - np.sum(Y * np.log(h)) / m
        error = h - Y
        grad = X.T @ error
        return loss, grad

    def h_theta(self, X):
        return self.softmax(X @ self.W)

    def softmax(self, x):
        return np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True)
    
    def predict(self, X):
        return np.argmax(self.h_theta(X), axis=1)
    
    def plot_losses(self):
        plt.plot(self.losses)
        plt.xlabel('Iteration')
        plt.ylabel('Loss')
        plt.title('Loss over time')
        plt.show()

# 使用模型进行训练
model = LogisticRegression(k=3, n=X_train.shape[1], method="minibatch")
model.fit(X_train, Y_train_encoded)
model.plot_losses()

# 预测并评估
y_pred = model.predict(X_test)
print(classification_report(y_test, y_pred))
结语

        在本文中,我们探讨了多元逻辑回归,扩展了二元逻辑回归的概念,以适应多个类别的分类任务。通过软最大化函数(softmax)和交叉熵损失函数,我们能够有效地训练模型,并根据多个类别预测结果的概率分布来进行分类。这些技术在多类别分类问题中具有广泛的应用。

        通过理解多元逻辑回归的数学推导和代码实现,我们可以更好地应对实际的分类任务,尤其是在数据维度较高、类别较多的情况下。

        敬请期待下一篇博文:基于Python的机器学习系列(8):Newton Raphson逻辑回归,我们将探索一种更为高级的优化方法,用于进一步提高逻辑回归的效率和性能。

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2068852.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

利用漏洞实现 Outlook 的 RCE:第一部分

概述 2023 年 3月补丁星期二解决的漏洞中,有一个是Outlook 的一个严重漏洞,编号为CVE-2023-23397,该漏洞被 Forest Blizzard 在野利用,微软已将其确定为俄罗斯国家支持的威胁行为者。2023 年 12 月,微软与波兰网络司令部 (DKWOC) 联合发布消息称,他们发现同一威胁行为者…

Debug-023-Document.createElement()的使用

Document.createElement() document.createElement()是在对象中创建一个对象,要与appendChild() 或 insertBefore()方法联合使用。 appendChild() 方法在节点的子节点列表末添加新的子节点。 insertBefore() 方法在节点的子节点列表任意位置插入新的节点。 用途举…

Linux -- git

1 啥是git git是一个代码的历史版本管理工具,通过用树形结构管理一个代码版本可以快速实现回滚等操作 1.1 git基本概念 工作区(Working Directory/Working Tree): 这是你当前正在处理项目文件的地方。你可以在工作区中创建、修改…

非关系型数据库MongoDB(文档型数据库)介绍与使用实例

MongoDB介绍 MongoDB是一种开源的文档型数据库管理系统,它使用类似于JSON的BSON格式(Binary JSON)来存储数据。与传统关系型数据库不同,MongoDB不使用表和行的结构,而是采用集合(Collection)(My…

漏洞发现——漏洞扫描工具的对比

本帖字的实验环境是来自学校的靶机 文章目录 Xray介绍安装教程使用教程主动扫描单个url扫描批量扫描 被动扫描联合游览器联合burpsuite Awvs介绍安装教程使用教程联合xary三者联合bp和xray Goby介绍安装教程使用教程 Afrog介绍安装教程使用教程 Vulmap介绍安装教程使用教程 Poc…

SpringMVC核心机制环境搭建

文章目录 1.SpringMVC执行流程1.基础流程图2.详细流程图 2.安装Tomcat1.下载2.解压到任意目录即可3.IDEA配置Tomcat1.配置Deloyment2.配置Server 3.创建maven项目1.创建sun-springmvc模块(webapp)2.查看是否被父模块管理3.pom.xml引入依赖4.目录5.SunDis…

电子电气架构--- 智能汽车电子架构的核心诉求

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 屏蔽力是信息过载时代一个人的特殊竞争力,任何消耗你的人和事,多看一眼都是你的不…

Android点击和触摸音量小的问题(问题追踪)

有客户反馈:A14触摸声音没有 于是乎,追踪setting打开触摸声音的代码: Overridepublic boolean onPreferenceTreeClick(PreferenceScreen preferenceScreen, Preference preference) {if (preference mVibrateWhenRinging) {Settings.System…

Linux | 进程优先级进程的环境变量

文章目录 进程概念4、进程优先级4.1基本概念4.2查看系统进程4.2.1 ps -l4.2.2 PRI & NI 4.3用top命令更改已存在进程的nice: 5、环境变量5.1常见环境变量5.2查看环境变量5.3测试PATH配置环境变量 5.4代码中获取环境变量5.4代码中获取环境变量 进程概念 4、进程…

RFID 智慧城市书房:开启智慧阅读新时代

在当今数字化、智能化的时代,人们对于阅读的需求和体验也在不断升级。RFID 智慧城市书房的出现,为满足人们对高品质阅读环境的追求提供了全新的途径。 一、RFID 技术:智慧城市书房的核心支撑 RFID,即射频识别技术,是一…

DDS IP实现啁啾信号

简介 DDS(Direct Digital Synthesizer)即数字合成器,是一种新型的频率合成技术,具有低成本、低功耗、高分辨率、频率转换时间短、相位连续性好等优点,对数字信号处理及其硬件实现有着很重要的作用。DDS 的基本…

18945 小团的配送团队

### 思路 1. **建图**:将订单视为图的节点,已知关系视为图的边,构建无向图。 2. **连通分量**:使用深度优先搜索(DFS)或广度优先搜索(BFS)找到图中的所有连通分量。 3. **排序**&…

探索人工智能的未来:埃里克·施密特2024斯坦福大学分享六

代理与文本生成模型的未来展望 您认为明年代理或文本生成模型会出现通货膨胀点吗? 不,不会。 我听到了类似的观点,尤其是埃里克科维茨的看法。他有一个很好的方式来阐述这三个趋势。虽然我之前也听说过这些趋势,但将它们整合起…

C语言破墙镐对称飞迷宫

目录 开头程序程序的流程图程序游玩的效果(gif)结尾 开头 大家好&#xff0c;我叫这是我58。 程序 #define _CRT_SECURE_NO_WARNINGS 1 #include <stdio.h> #include <stdlib.h> #include <Windows.h> enum WASD {W,A,S,D }; void printmaze(const char s…

【CTF Web】CTFShow cookie泄露 Writeup(cookie泄露+URL解码)

cookie泄露 10 cookie 只是一块饼干&#xff0c;不能存放任何隐私数据 解法 按 F12 打开开发者工具&#xff0c;点击网络&#xff0c;刷新页面。 flag 在 响应标头的 Set-Cookie 中。 用 URL 解码工具转换。 Flag ctfshow{8483acdb-a677-4c77-8aff-438d44ff1a3e}声明 本博客…

论文翻译软件哪个好用?如何将论文转化?

在学术海洋里遨游&#xff0c;每一篇论文都是思想的灯塔。 但当这座灯塔用外语构建&#xff0c;如何让它在中国读者面前同样熠熠生辉&#xff1f;别担心&#xff0c;把论文翻译成中文的旅程&#xff0c;不仅可以轻松启航&#xff0c;还能优雅靠岸&#xff01; 不知道怎么把论…

【Android笔记】Android APK编译打包流程

前言 本文将介绍Android从一个项目打包成APK的过程&#xff0c;其中涉及Android Java和Kotlin文件、资源文件、清单文件、依赖jar包和so库等在打包过程中处理。 步骤 总体的打包流程如下图&#xff0c;下面就介绍下详细的打包步骤。 1、将aidl文件编译成java文件 在构建过程中…

OpenAI API VBA function returns #Value! but MsgBox displays response

题意&#xff1a;“OpenAI API VBA 函数返回 #Value!&#xff0c;但 MsgBox 显示响应” 问题背景&#xff1a; I am trying to integrate the OpenAI API into Excel. The http request to OpenAI chat completion works correctly and the response is OK. When I display it…

esp32c3 luaos

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、介绍二、相关介绍2.1helloworld——2.2任务框架2.3消息传递 与消息订阅2.4uart2.5二进制数据/c结构体的打包与解析2.6 zbuffer库2.8 uart 485 数据解析2.9 …

Ubuntu 20.04安装中文输入法

本文旨在详细介绍在Ubuntu 20.04操作系统中安装中文输入法的步骤和方法。我们将从选择适合的中文输入法软件、下载与安装过程、配置输入法设置以及解决可能遇到的问题等方面展开讲解&#xff0c;帮助用户轻松实现在Ubuntu 20.04系统下流畅输入中文的需求。无论你是Ubuntu的新手…