Logistic回归

news2025/1/12 15:46:56

通常,Logistic回归用于二分类问题,例如预测明天是否会下雨。当然它也可以用于多分类问题.

Logistic回归是分类方法,它利用的是Sigmoid函数阈值在[0,1]这个特性。Logistic回归进行分类的主要思想是:根据现有数据对分类边界线建立回归公式,以此进行分类。其实,Logistic本质上是一个基于条件概率的判别模型(Discriminative Model)。

所以要想了解Logistic回归,我们必须先看一看Sigmoid函数 ,我们也可以称它为Logistic函数。它的公式如下:

整合成一个公式,就变成了如下公式

 

θ是参数列向量(要求解的),x是样本列向量(给定的数据集).这样我们的数据集([x0,x1,...,xn]),不管是大于1或者小于0,都可以映射到[0,1]区间进行分类。hθ(x)给出了输出为1的概率

那么问题来了!如何得到合适的参数向量θ?

式即为在已知样本x和参数θ的情况下,样本x属于正样本(y=1)和负样本(y=0)的条件概率。理想状态下,根据上述公式,求出各个点的概率均为1,也就是完全分类都正确。但是考虑到实际情况,样本点的概率越接近于1,其分类效果越好。比如一个样本属于正样本的概率为0.51,那么我们就可以说明这个样本属于正样本。另一个样本属于正样本的概率为0.99,那么我们也可以说明这个样本属于正样本。但是显然,第二个样本概率更高,更具说服力。我们可以把上述两个概率公式合二为一:

为了简化问题,我们对整个表达式求对数

 这个损失函数,是对于一个样本而言的。给定一个样本,我们就可以通过这个损失函数求出,样本所属类别的概率,而这个概率越大越好,所以也就是求解这个损失函数的最大值。既然概率出来了,那么最大似然估计也该出场了。假定样本与样本之间相互独立,那么整个样本集生成的概率即为所有样本生成概率的乘积,便可得到如下公式

 其中,m为样本的总数,y(i)表示第i个样本的类别,x(i)表示第i个样本,需要注意的是θ是多维向量,x(i)也是多维向量。

综上所述,满足J(θ)的最大的θ值即是我们需要求解的模型。

怎么求解使J(θ)最大的θ值呢?因为是求最大值,所以我们需要使用梯度上升算法。如果面对的问题是求解使J(θ)最小的θ值,那么我们就需要使用梯度下降算法。

def Gradient_Ascent_test():
    def f_prime(x_old):  # f(x)的导数
        return -2 * x_old + 4

    x_old = -1  # 初始值,给一个小于x_new的值
    x_new = 0  # 梯度上升算法初始值,即从(0,0)开始
    alpha = 0.01  # 步长,也就是学习速率,控制更新的幅度
    presision = 0.00000001  # 精度,也就是更新阈值
    while abs(x_new - x_old) > presision:
        x_old = x_new
        x_new = x_old + alpha * f_prime(x_old)  # 上面提到的公式
    print(x_new)  # 打印最终求解的极值近似值

if __name__ == '__main__':
    Gradient_Ascent_test()

目标函数

 

 

 代码

-0.017612   14.053064  0
-1.395634  4.662541   1
-0.752157  6.538620   0
-1.322371  7.152853   0
0.423363   11.054677  0
0.406704   7.067335   1
0.667394   12.741452  0
-2.460150  6.866805   1
0.569411   9.548755   0
-0.026632  10.427743  0
0.850433   6.920334   1
1.347183   13.175500  0
1.176813   3.167020   1
-1.781871  9.097953   0
-0.566606  5.749003   1
0.931635   1.589505   1
-0.024205  6.151823   1
-0.036453  2.690988   1
-0.196949  0.444165   1
1.014459   5.754399   1
1.985298   3.230619   1
-1.693453  -0.557540  1
-0.576525  11.778922  0
-0.346811  -1.678730  1
-2.124484  2.672471   1
1.217916   9.597015   0
-0.733928  9.098687   0
-3.642001  -1.618087  1
0.315985   3.523953   1
1.416614   9.619232   0
-0.386323  3.989286   1
0.556921   8.294984   1
1.224863   11.587360  0
-1.347803  -2.406051  1
1.196604   4.951851   1
0.275221   9.543647   0
0.470575   9.332488   0
-1.889567  9.542662   0
-1.527893  12.150579  0
-1.185247  11.309318  0
-0.445678  3.297303   1

 数据可视化

import matplotlib.pyplot as plt
import numpy as np

"""
函数说明:加载数据

Parameters:
    无
Returns:
    dataMat - 数据列表
    labelMat - 标签列表
"""


def loadDataSet():
    dataMat = []  # 创建数据列表
    labelMat = []  # 创建标签列表
    fr = open('testSet.txt')  # 打开文件
    for line in fr.readlines():  # 逐行读取
        lineArr = line.strip().split()  # 去回车,放入列表
        dataMat.append([1.0, float(lineArr[0]), float(lineArr[1])])  # 添加数据
        labelMat.append(int(lineArr[2]))  # 添加标签
    fr.close()  # 关闭文件
    return dataMat, labelMat  # 返回

"""
函数说明:绘制数据集

Parameters:
    无
Returns:
    无
"""

def plotDataSet():
    dataMat, labelMat = loadDataSet()  # 加载数据集
    dataArr = np.array(dataMat)  # 转换成numpy的array数组
    n = np.shape(dataMat)[0]  # 数据个数
    xcord1 = [];
    ycord1 = []  # 正样本
    xcord2 = [];
    ycord2 = []  # 负样本
    for i in range(n):  # 根据数据集标签进行分类
        if int(labelMat[i]) == 1:
            xcord1.append(dataArr[i, 1]);
            ycord1.append(dataArr[i, 2])  # 1为正样本
        else:
            xcord2.append(dataArr[i, 1]);
            ycord2.append(dataArr[i, 2])  # 0为负样本
    fig = plt.figure()
    ax = fig.add_subplot(111)  # 添加subplot
    ax.scatter(xcord1, ycord1, s=20, c='red', marker='s', alpha=.5)  # 绘制正样本
    ax.scatter(xcord2, ycord2, s=20, c='green', alpha=.5)  # 绘制负样本
    plt.title('DataSet')  # 绘制title
    plt.xlabel('x');
    plt.ylabel('y')  # 绘制label
    plt.show()  # 显示


if __name__ == '__main__':
    plotDataSet()

 

# -*- coding:UTF-8 -*-
import numpy as np

def loadDataSet():
    dataMat = []  # 创建数据列表
    labelMat = []  # 创建标签列表
    fr = open('testSet.txt')  # 打开文件
    for line in fr.readlines():  # 逐行读取
        lineArr = line.strip().split()  # 去回车,放入列表
        dataMat.append([1.0, float(lineArr[0]), float(lineArr[1])])  # 添加数据
        labelMat.append(int(lineArr[2]))  # 添加标签
    fr.close()  # 关闭文件
    return dataMat, labelMat  # 返回


def sigmoid(inX):
    return 1.0 / (1 + np.exp(-inX))

"""
函数说明:梯度上升算法

Parameters:
    dataMatIn - 数据集
    classLabels - 数据标签
Returns:
    weights.getA() - 求得的权重数组(最优参数)
"""

def gradAscent(dataMatIn, classLabels):
    dataMatrix = np.mat(dataMatIn)  # 转换成numpy的mat
    labelMat = np.mat(classLabels).transpose()  # 转换成numpy的mat,并进行转置
    m, n = np.shape(dataMatrix)  # 返回dataMatrix的大小。m为行数,n为列数。
    alpha = 0.001  # 移动步长,也就是学习速率,控制更新的幅度。
    maxCycles = 500  # 最大迭代次数
    weights = np.ones((n, 1))
    for k in range(maxCycles):
        h = sigmoid(dataMatrix * weights)  # 梯度上升矢量化公式
        error = labelMat - h
        weights = weights + alpha * dataMatrix.transpose() * error
    return weights.getA()  # 将矩阵转换为数组,返回权重数组

if __name__ == '__main__':
    dataMat, labelMat = loadDataSet()
    print(gradAscent(dataMat, labelMat))

 

import matplotlib.pyplot as plt
import numpy as np
def loadDataSet():
    dataMat = []  # 创建数据列表
    labelMat = []  # 创建标签列表
    fr = open('testSet.txt')  # 打开文件
    for line in fr.readlines():  # 逐行读取
        lineArr = line.strip().split()  # 去回车,放入列表
        dataMat.append([1.0, float(lineArr[0]), float(lineArr[1])])  # 添加数据
        labelMat.append(int(lineArr[2]))  # 添加标签
    fr.close()  # 关闭文件
    return dataMat, labelMat  # 返回

def sigmoid(inX):
    return 1.0 / (1 + np.exp(-inX))

def gradAscent(dataMatIn, classLabels):
    dataMatrix = np.mat(dataMatIn)  # 转换成numpy的mat
    labelMat = np.mat(classLabels).transpose()  # 转换成numpy的mat,并进行转置
    m, n = np.shape(dataMatrix)  # 返回dataMatrix的大小。m为行数,n为列数。
    alpha = 0.001  # 移动步长,也就是学习速率,控制更新的幅度。
    maxCycles = 500  # 最大迭代次数
    weights = np.ones((n, 1))
    for k in range(maxCycles):
        h = sigmoid(dataMatrix * weights)  # 梯度上升矢量化公式
        error = labelMat - h
        weights = weights + alpha * dataMatrix.transpose() * error
    return weights.getA()  # 将矩阵转换为数组,返回权重数组


def plotBestFit(weights):
    dataMat, labelMat = loadDataSet()  # 加载数据集
    dataArr = np.array(dataMat)  # 转换成numpy的array数组
    n = np.shape(dataMat)[0]  # 数据个数
    xcord1 = [];
    ycord1 = []  # 正样本
    xcord2 = [];
    ycord2 = []  # 负样本
    for i in range(n):  # 根据数据集标签进行分类
        if int(labelMat[i]) == 1:
            xcord1.append(dataArr[i, 1]);
            ycord1.append(dataArr[i, 2])  # 1为正样本
        else:
            xcord2.append(dataArr[i, 1]);
            ycord2.append(dataArr[i, 2])  # 0为负样本
    fig = plt.figure()
    ax = fig.add_subplot(111)  # 添加subplot
    ax.scatter(xcord1, ycord1, s=20, c='red', marker='s', alpha=.5)  # 绘制正样本
    ax.scatter(xcord2, ycord2, s=20, c='green', alpha=.5)  # 绘制负样本
    x = np.arange(-3.0, 3.0, 0.1)
    y = (-weights[0] - weights[1] * x) / weights[2]
    ax.plot(x, y)
    plt.title('BestFit')  # 绘制title
    plt.xlabel('X1');
    plt.ylabel('X2')  # 绘制label
    plt.show()

if __name__ == '__main__':
    dataMat, labelMat = loadDataSet()
    weights = gradAscent(dataMat, labelMat)
    plotBestFit(weights)

机器学习实战教程(六):Logistic回归基础篇之梯度上升算法 (cuijiahua.com) 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/70722.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

啊?我这手速也太差了吧?——C++Easyx“挑战六秒”小游戏

🐑本文作者:C橙羊🐑 🎮🔊本文代码适合编译环境:DEV-C💻 ✨🧨温馨提示:此文转载于codebus🎉🎠 最近橙羊在Easyx官网的codebus里随便逛逛的时候&am…

SpringMVC从入门到精通(一)

文章目录1. SpringMVC基本概念1.1 三层架构1.2 MVC架构1.3 什么是SpringMVC1.4 SpringMVC的优势2. SpringMVC 的入门2.1 入门程序2.2 SpringMVC执行原理刨析2.3 SpringMVC的核心执行流程2.4 SpringMVC的组件3. RequestMapping注解4.请求参数绑定4.1 参数绑定4.2 请求参数乱码问…

磨金石摄影技能干货分享|优秀纪实摄影作品欣赏—北京记事

1、蜂窝煤 三名青年男子踏着三轮车拉着满满一车蜂窝煤。脸上流露出清澈的笑容。这是九十年代的北京,背后的天安门格外的显眼。那时候处于改革开放的初期,虽然还不是很富裕,但大家脸上洋溢着幸福与希望的笑容。 蜂窝煤是冬天必备,九…

【强化学习论文合集】十一.2018国际表征学习大会论文(ICLR2018)

强化学习(Reinforcement Learning, RL),又称再励学习、评价学习或增强学习,是机器学习的范式和方法论之一,用于描述和解决智能体(agent)在与环境的交互过程中通过学习策略以达成回报最大化或实现特定目标的问题。 本专栏整理了近几年国际顶级会议中,涉及强化学习(Rein…

历届青少年蓝桥杯python编程选拔赛 STEMA评测比赛真题解析【持续更新 已更新至34题】

蓝桥杯python选拔赛真题 历届青少年蓝桥杯python编程选拔赛真题解析 选拔赛 真题34-回文数升级 【蓝桥杯选拔赛真题34】python回文数升级 青少年组蓝桥杯python 选拔赛STEMA比赛真题解析_小兔子编程的博客-CSDN博客python回文数升级2020年青少年组python蓝桥杯选拔赛真题一、…

剑指Offer39——数组中出现次数超过一半的数字

摘要 剑指Offer39 数组中出现次数超过一半的数字 本题常见的三种解法: 哈希表统计法: 遍历数组 nums ,用 HashMap 统计各数字的数量,即可找出 众数 。此方法时间和空间复杂度均为 O(N) 。数组排序法: 将数组 nums 排…

Python学习-8.1.1 标准库(time库的基础与实例)

2.1 time库 time库是Python提供的处理时间标准库。time库提供系统级精确计时器的计时功能,可以用来分析程序性能,也可以让程序暂停运行时间。 2.1.1 时间处理函数 time.time()函数:获取当前时间戳。 代表着如今的时间与1970年1月1日0分0秒…

18.10 字节码指令集与解析举例 - 同步控制指令

同步控制指令 组成 java虚拟机支持两种同步结构:方法级的同步和方法内部一段指令序列的同步,这两种同步都是使用monitor来支持的。 方法级的同步 方法级的同步:是隐式的,即无须通过字节码指令来控制,它实现在方法调…

Java+SSM网上书城全套含微信支付电商购物(含源码+论文+答辩PPT等)

项目功能简介: 本项目含代码详细讲解视频,手把手带同学们敲代码从0到1完成项目 该项目采用技术Springmvc、Spring、MyBatis、Tomcat服务器、MySQL数据库 项目含有源码、配套开发软件、软件安装教程、项目发布教程以及代码讲解教程 项目功能介绍: 系统管理…

HTML做一个简单的页面(纯html代码)地球专题学习网站

🎉精彩专栏推荐 💭文末获取联系 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 💂 作者主页: 【主页——🚀获取更多优质源码】 🎓 web前端期末大作业: 【📚毕设项目精品实战案例 (10…

STM32F4 | 新建工程模板——寄存器版本 | HAL库入门 | 新建工程模板——库函数版本

文章目录一、新建工程模板——寄存器版本二、HAL入门1.固件库和寄存器的区别2.STM32CubeF43.HAL库包介绍三、新建HAL库工程模板一、新建工程模板——寄存器版本 开发环境:MDK5软件包:STM32CubeF4包 新建工程模板的一般步骤为: 新建工程目录&a…

【UE5】多用户协同编辑

UE5新出了一个多用户协同功能所以想搭一个来玩玩。 Epic已经将流程极度的简化了,在B站虚幻官方也放出了教程视频,[官方文档](多用户编辑入门 | 虚幻引擎文档 (unrealengine.com))也有教程。 这里做一下简要记录。 1.启用插件 首先打开Multi-User Edi…

SoftPerfect NetWorx中管理流量和宽带设备工具

SoftPerfect NetWorx中管理流量和宽带设备工具 NetWorx是用于在Windows中管理流量和宽带设备的简单工具和实用程序。如果我们利用交通设施,毫不拖延地利用教育系统,以及与各种驾驶员相关的学习,那么当加载互联网时,通过软件秘密使…

[附源码]计算机毕业设计酒店客房管理系统Springboot程序

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

BaiqiSoft MstHtmlEditor for .NET负责编辑的控制器

BaiqiSoft MstHtmlEditor for .NET负责编辑的控制器 BaiqiSoft MstHtmlEditor获取.NET for win表单被认为是一个元素,用户可以轻松灵活地将其融入到C#、VB.NET甚至WPF软件中。负责编辑的控制器,.NET Win Forms的MstHtmlEditor,允许用户和开发人员,甚至非技术用户使用该系列…

Docker入门第二期

写目一、宿主机与容器之间的文件拷贝二、数据卷三、数据卷容器四、Dockerfile一、宿主机与容器之间的文件拷贝 docker run -p 3307:3306 --name mysql1 -di -v /home/javaxl/data/mysql/mysql.conf.d/:/etc/mysql/mysql.conf.d/ -v /home/javaxl/data/mysql/data/:/var/lib/…

用Python把附近的足浴店都给采集了一遍,好兄弟:针不戳~

前言 嗨喽,大家好呀~这里是爱看美女的小编 又到了学Python时刻~ (文末送读者福利) 我又来了!今天整个好玩的,你们肯定喜欢~ 咱们上班累了,不得好好犒劳一下自己,是吧 ! 于是我整…

相控阵天线(十三):天线校准技术仿真介绍之换相法

目录简介换相法算法简介换相法校准对方向图的影响Hadamard控制矩阵的换相法仿真循环移相控制矩阵的换相法仿真简介 传统方法按照测试区域分,可分为远场、中场和近场测量。远场测量发展成熟,可直接测量方向图,但对条件要求较高,且…

SpringMVC从入门到精通(二)

文章目录6. 响应视图和结果数据6.1 返回值类型6.2 springmvc作用域传值6.3 转发和重定向6.4 json数据格式的请求与响应7.SpringMVC 实现文件上传7.1 文件上传三要素7.2 文件上传依赖7.2 文件上传示例(后端需要配置文件解析器)8. SpringMVC 中的异常处理8…

FMT航点飞行(一)

一、航点飞行前检查: (1)将飞行器在position模式下启动飞行,移动一个距离;在qgc地面站地图上观察移动的距离,记录大概移动的距离为D;打航向,观察地磁计转向是否正常; &a…