【机器学习】PyTorch手动实现Logistic算法

news2024/11/26 12:51:37

参考地址:点击打开

计算较为繁琐,需要用到sigmoid函数和梯度下降算法,步骤主要如下:

  1. 二项分布概率公式表示
  2. 最大似然估计和对数化计算
  3. 求道
  4. 带入梯度下降算法计算和优化

在这里插入图片描述

代码:

import numpy as np
import matplotlib.pyplot as plt

def sigmoid(z):
    return 1.0/(1+np.exp(-z))


# datas NxD
# labs Nx1
# w    Dx1

def weight_update(datas,labs,w,alpha=0.01):
    z = np.dot(datas,w) # Nx1
    h = sigmoid(z)        # Nx1
    Error = labs-h        # Nx1 
    w = w + alpha*np.dot(datas.T,Error)
    return w
    
def train_LR(datas,labs,n_epoch=2,alpha=0.005):
    N,D = np.shape(datas)   
    w = np.ones([D,1])  # Dx1
    # 进行n_epoch轮迭代
    for i in range(n_epoch):
        w = weight_update(datas,labs,w,alpha)
        error_rate=test_accuracy(datas,labs,w)
        print("epoch %d error %.3f%%"%(i,error_rate*100))
    return w
    
def test_accuracy(datas,labs,w):
    N,D = np.shape(datas)
    z = np.dot(datas,w) # Nx1
    h = sigmoid(z)        # Nx1
    lab_det = (h>0.5).astype(np.float)
    error_rate=np.sum(np.abs(labs-lab_det))/N
    return error_rate
    
def draw_desion_line(datas,labs,w,name="0.jpg"):
    dic_colors={0:(.8,0,0),1:(0,.8,0)}
  
    # 画数据点
    for i in range(2):
        index = np.where(labs==i)[0]
        sub_datas = datas[index]
        plt.scatter(sub_datas[:,1],sub_datas[:,2],s=16.,color=dic_colors[i])
    
    # 画判决线
    min_x = np.min(datas[:,1])
    max_x = np.max(datas[:,1])
    w = w[:,0]
    x = np.arange(min_x,max_x,0.01)
    y = -(x*w[1]+w[0])/w[2]
    plt.plot(x,y)
    
    plt.savefig(name)
    
def load_dataset(file):    
    with open(file,"r",encoding="utf-8") as f:
        lines = f.read().splitlines()
    
    # 取 lab 维度为 N x 1
    labs = [line.split("\t")[-1] for line in lines]
    labs = np.array(labs).astype(np.float32)
    labs= np.expand_dims(labs,axis=-1) # Nx1
    
    # 取数据 增加 一维全是1的特征
    datas = [line.split("\t")[:-1] for line in lines]
    datas = np.array(datas).astype(np.float32)
    N,D = np.shape(datas)
    # 增加一个维度
    datas = np.c_[np.ones([N,1]),datas]
    return datas,labs
    
if __name__ == "__main__":
    ''' 实验1 基础测试数据'''
    # 加载数据
    file = "testset.txt"
    datas,labs = load_dataset(file)
    
    weights = train_LR(datas,labs,alpha=0.001,n_epoch=800)
    print(weights)
    draw_desion_line(datas,labs,weights,name="test_1.jpg")
    

训练需要的数据集testset.txt文件:

-0.017612	14.053064	0
-1.395634	4.662541	1
-0.752157	6.538620	0
-1.322371	7.152853	0
0.423363	11.054677	0
0.406704	7.067335	1
0.667394	12.741452	0
-2.460150	6.866805	1
0.569411	9.548755	0
-0.026632	10.427743	0
0.850433	6.920334	1
1.347183	13.175500	0
1.176813	3.167020	1
-1.781871	9.097953	0
-0.566606	5.749003	1
0.931635	1.589505	1
-0.024205	6.151823	1
-0.036453	2.690988	1
-0.196949	0.444165	1
1.014459	5.754399	1
1.985298	3.230619	1
-1.693453	-0.557540	1
-0.576525	11.778922	0
-0.346811	-1.678730	1
-2.124484	2.672471	1
1.217916	9.597015	0
-0.733928	9.098687	0
-3.642001	-1.618087	1
0.315985	3.523953	1
1.416614	9.619232	0
-0.386323	3.989286	1
0.556921	8.294984	1
1.224863	11.587360	0
-1.347803	-2.406051	1
1.196604	4.951851	1
0.275221	9.543647	0
0.470575	9.332488	0
-1.889567	9.542662	0
-1.527893	12.150579	0
-1.185247	11.309318	0
-0.445678	3.297303	1
1.042222	6.105155	1
-0.618787	10.320986	0
1.152083	0.548467	1
0.828534	2.676045	1
-1.237728	10.549033	0
-0.683565	-2.166125	1
0.229456	5.921938	1
-0.959885	11.555336	0
0.492911	10.993324	0
0.184992	8.721488	0
-0.355715	10.325976	0
-0.397822	8.058397	0
0.824839	13.730343	0
1.507278	5.027866	1
0.099671	6.835839	1
-0.344008	10.717485	0
1.785928	7.718645	1
-0.918801	11.560217	0
-0.364009	4.747300	1
-0.841722	4.119083	1
0.490426	1.960539	1
-0.007194	9.075792	0
0.356107	12.447863	0
0.342578	12.281162	0
-0.810823	-1.466018	1
2.530777	6.476801	1
1.296683	11.607559	0
0.475487	12.040035	0
-0.783277	11.009725	0
0.074798	11.023650	0
-1.337472	0.468339	1
-0.102781	13.763651	0
-0.147324	2.874846	1
0.518389	9.887035	0
1.015399	7.571882	0
-1.658086	-0.027255	1
1.319944	2.171228	1
2.056216	5.019981	1
-0.851633	4.375691	1
-1.510047	6.061992	0
-1.076637	-3.181888	1
1.821096	10.283990	0
3.010150	8.401766	1
-1.099458	1.688274	1
-0.834872	-1.733869	1
-0.846637	3.849075	1
1.400102	12.628781	0
1.752842	5.468166	1
0.078557	0.059736	1
0.089392	-0.715300	1
1.825662	12.693808	0
0.197445	9.744638	0
0.126117	0.922311	1
-0.679797	1.220530	1
0.677983	2.556666	1
0.761349	10.693862	0
-2.168791	0.143632	1
1.388610	9.341997	0
0.317029	14.739025	0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/770217.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[CPU飙升排查]生产CPU飙升,YGC不断的事故

背景 最近给上线还未使用的服务配置监控,监控系统电话将我呼醒 ,导致原本就不多的头发一阵掉落. 还好系统还没有流量进入,先免打扰,第二天再处理. 查看面板情况如下: FGC正常 YGC不断 CPU飙升 思路确定 分析了下YGC不断,但是没有FGC,CPU飙升,可能出现的情况是哪里有活锁或者…

数据迁移卷不动了?Squids DBMotion新增多种数据库迁移能力

Squdis DBMotion新增了多种数据库的迁移能力:SQLServer to SQLServer、Redis to Redis、MySQL to Kafka,增加了列映射、校验任务独立、抽样校验、校验复检和限速等十多项功能。 本次版本更新,DBMotion新增了三种数据库迁移同步的场景。目前&…

el-table组件插槽“slot-scope”

目录 一、代码展示 二、返回的数组对象不含value或者ispass&#xff0c;不会报错 三、插槽里面放的是要手动输入的值时 一、代码展示 <el-table v-loading"loading" :data"checklistList" selection-change"handleSelectionChange"><…

轻松实现数据一体化:轻易云数据集成平台全解析

在当今快速发展的商业环境中&#xff0c;企业面临着大量来自多样数据源的数据。如何将这些数据进行高效集成和利用&#xff0c;成为企业数字化转型的关键挑战。轻易云数据集成平台提供了一个一站式的解决方案&#xff0c;帮助企业实现数据的无缝集成和高效利用。下面我们将通过…

Pandas Groupby:在Python中汇总、聚合和分组数据

GroupBy是一个非常简单的概念。我们可以创建一个类别分组&#xff0c;并对这些类别应用一个函数。这是一个简单的概念&#xff0c;但它是一种在数据科学中广泛使用的非常有价值的技术。在真实的的数据科学项目中&#xff0c;您将处理大量数据并一遍又一遍地尝试&#xff0c;因此…

Linux(CentOS7)下源码编译 PostgreSQL13.10 安装手册

Linux&#xff08;CentOS7&#xff09;下PostgreSQL安装手册 文章目录 一、准备PostgreSQL二、安装PostgreSQL2.1解压安装包2.2编译PG2.3查看PG安装目录2.4配置PG环境变量2.5查看PG版本2.6创建postgres用户2.7创建PG数据库数据存放目录2.8授权PG数据库数据存放目录2.9切换postg…

一起学SF框架系列5.9-spring-Beans-bean实例创建

bean实例化底层采用Java反射机制&#xff0c;但Spring根据框架需要提供了更多的增强功能。 类关系图 InstantiationStrategy&#xff1a;接口-定义了创建RootBeanDefinition对应bean实例方法 SimpleInstantiationStrategy&#xff1a;简单bean的实例化处理。实现了Instantiati…

WEB:题目名称-文件包含

背景知识 题目 题目了文件包含&#xff0c;所以想到了php伪协议 构造payload尝试读取flag.php /?filenamephp://filter/readconvert.base64-encode/resourceflag.php 页面提示“do not hack”猜测可能是黑名单检测敏感字符串。猜测字符串哪些被禁用&#xff0c;这里输入单个…

【算法与数据结构】144、94、145LeetCode二叉树的前中后遍历(递归法、迭代法)

文章目录 一、题目二、递归算法三、迭代算法3.1 迭代算法13.2 迭代算法2 ——统一风格写法 四、完整代码 所有的LeetCode题解索引&#xff0c;可以看这篇文章——【算法和数据结构】LeetCode题解。 一、题目 二、递归算法 思路分析&#xff1a;这道题比较简单&#xff0c;不多说…

LCD—STM32液晶显示(2.使用FSMC模拟8080时序)

目录 使用STM32的FSMC模拟8080接口时序 FSMC简介 FSMC NOR/PSRAM中的模式B时序图 用FSMC模拟8080时序 重点&#xff1a;HADDR内部地址与FSMC地址信号线的转换&#xff08;实现地址对齐&#xff09; 使用STM32的FSMC模拟8080接口时序 ILI9341的8080通讯接口时序可以由STM32使…

PHP要怎么学--【强撸项目000】

强撸项目 总目录在000集 文章目录 本系列校训学习资源的选择环境的问题本人推荐 PHP视频的知识点分析总结题外话 本系列校训 用免费公开视频&#xff0c;卷飞培训班哈人&#xff01;打死不报班&#xff0c;赚钱靠狠干&#xff01; 只要自己有电脑&#xff0c;前后项目都能搞&a…

sqli-labs 堆叠注入 解析

打开网页首先判断闭合类型 说明为双引号闭合 我们可以使用单引号将其报错 先尝试判断回显位 可以看见输出回显位为2&#xff0c;3 尝试暴库爆表 这时候进行尝试堆叠注入&#xff0c;创造一张新表 ?id-1 union select 1,database(),group_concat(table_name) from informatio…

给你一个网站,你如何测试?

首先&#xff0c;查找需求说明、网站设计 等相关文档&#xff0c;分析测试需求。 制定测试计划&#xff0c;确定测试范围和测试策略&#xff0c;一般包括以下几个部分&#xff1a; 功能性测试&#xff1b;界面测试&#xff1b;性能测试&#xff1b;数据库测试&#xff1b;安全…

AD如何查看PCB完成度?快来看这篇文

在Altium Designer&#xff08;AD&#xff09;中&#xff0c;很多工程师通过使用Design Rule Check&#xff08;DRC&#xff0c;常用于检查PCB设计是否符合设计规范和要求&#xff09;功能来检查PCB设计的完成度&#xff0c;但很多小白不太熟悉怎么去使用DRC&#xff0c;下面来…

深入浅出C语言—【函数】下

目录 5. 函数的嵌套调用和链式访问5.1嵌套调用5.2 链式访问 6. 函数的声明和定义6.1 函数声明6.2 函数定义 7. 函数递归&#x1f451;7.1 什么是递归&#xff1f;7.2 递归的两个必要条件7.2.1 练习17.2.2 练习2 7.3 递归与迭代7.3.1 练习37.3.2 练习4 5. 函数的嵌套调用和链式访…

解决Missing cookie ‘JssionId‘ for method parameter of type String问题

错误描述如下所示&#xff1a; 上述错误是我在使用CookieValue注解&#xff0c;获取cookieID时出现的&#xff0c;错误原因是由于**CookieValue注解注解中的value值和浏览器中的cookie的jssionID不一致所导致的** 如下所示为浏览器中的CookieID的参数名 而我在注解中写的如下图…

浪涌保护器行业应用防雷选型方案

当今社会中&#xff0c;电气设备的使用范围越来越广泛&#xff0c;也越来越普及&#xff0c;而与之相关的浪涌保护器就显得尤为重要。在这个领域&#xff0c;有一种高品质的浪涌保护器 —— 地凯防雷SPD浪涌保护器&#xff0c;它可以为各种设备提供强大的保护&#xff0c;并在各…

YOLOv5——pytorch环境搭建

环境搭建是一个最最基础而又基本的事情&#xff0c;是一切工作开始前的基本要求。 由于YOLOv7和YOLOv5不兼容&#xff0c;这次用到了YOLOv5&#xff0c;我不得不再使用anaconda创建一个虚拟环境。 Tip&#xff1a;很多人不了解Anaconda存在的意义&#xff0c;就是为了弥补pyt…

四、DML-1.数据操作-添加

一、DML介绍 Data Manipulation Language 数据操作语言 用来对数据库中表的数据记录进行增删改操作。 二、添加数据 1、给指定字段添加数据 insert into employee(id, workno, name, gender, age, idcard,entrydate) values (1, 001,Itcast, 男, 18, 123456789012345678, 2…

kaggle新赛:学生摘要评估大赛赛题解析(NLP)

赛题名称&#xff1a;CommonLit - Evaluate Student Summaries 赛题链接&#xff1a; https://www.kaggle.com/competitions/commonlit-evaluate-student-summaries/ 赛题背景 摘要写作是所有年龄段学习者的一项重要技能。总结可以增强阅读理解能力&#xff0c;特别是在第二…