【机器学习】PyTorch手动实现Logistic回归算法

news2024/11/25 5:21:37

参考地址:点击打开

计算较为繁琐,需要用到sigmoid函数和梯度下降算法,步骤主要如下:

  1. 二项分布概率公式表示
  2. 最大似然估计和对数化计算
  3. 求道
  4. 带入梯度下降算法计算和优化

在这里插入图片描述

代码:



import numpy as np
import matplotlib.pyplot as plt

def sigmoid(z):
    return 1.0/(1+np.exp(-z))


# datas NxD
# labs Nx1
# w    Dx1

def weight_update(datas,labs,w,alpha=0.01):
    z = np.dot(datas,w) # Nx1
    h = sigmoid(z)        # Nx1
    Error = labs-h        # Nx1 
    w = w + alpha*np.dot(datas.T,Error)
    return w

# 随机梯度下降
def train_LR_batch(datas,labs,batchsize=80,n_epoch=2,alpha=0.005):
    
    print("epoch:%d,alpha:%.8f batchsize:%d"%(n_epoch,alpha,batchsize))
    
    N,D = np.shape(datas)
    # weight 初始化
    w = np.ones([D,1])  # Dx1
    N_batch = N//batchsize  #取整数
    
    print("n:%d d:%d batchsize:%d"%(N,D,batchsize))
    
    for i in range(n_epoch):
        
        # 数据打乱
        rand_index  = np.random.permutation(N).tolist()
        # 每个batch 更新一下weight
        for j in range(N_batch):
            print("i:%d j:%d N_batch:%d\r\n"%(i,j,N_batch))
            # alpha = 4.0/(i+j+1) +0.01
            index = rand_index[j*batchsize:(j+1)*batchsize]
            batch_datas = datas[index]
            batch_labs = labs[index]
            w=weight_update(batch_datas,batch_labs,w,alpha)
    
        error = test_accuracy(datas,labs,w)
        print("epoch %d  error  %.2f%%"%(i,error*100))
    return w


def train_LR(datas,labs,n_epoch=2,alpha=0.005):
    N,D = np.shape(datas)   
    w = np.ones([D,1])  # Dx1
    # 进行n_epoch轮迭代
    for i in range(n_epoch):
        w = weight_update(datas,labs,w,alpha)
        error_rate=test_accuracy(datas,labs,w)
        print("epoch %d error %.3f%%"%(i,error_rate*100))
    return w


def test_accuracy(datas,labs,w):
    N,D = np.shape(datas)
    z = np.dot(datas,w) # Nx1
    h = sigmoid(z)        # Nx1
    lab_det = (h>0.5).astype(np.float)
    error_rate=np.sum(np.abs(labs-lab_det))/N
    return error_rate


def draw_desion_line(datas,labs,w,name="0.jpg"):
    dic_colors={0:(.8,0,0),1:(0,.8,0)}
  
    # 画数据点
    for i in range(2):
        index = np.where(labs==i)[0]
        sub_datas = datas[index]
        plt.scatter(sub_datas[:,1],sub_datas[:,2],s=16.,color=dic_colors[i])
    
    # 画判决线
    min_x = np.min(datas[:,1])
    max_x = np.max(datas[:,1])
    w = w[:,0]
    x = np.arange(min_x,max_x,0.01)
    y = -(x*w[1]+w[0])/w[2]
    plt.plot(x,y)
    
    plt.savefig(name)

    
def load_dataset(file):    
    with open(file,"r",encoding="utf-8") as f:
        lines = f.read().splitlines()

    # 取 lab 维度为 N x 1
    labs = [line.split("\t")[-1] for line in lines]
    labs = np.array(labs).astype(np.float32)
    labs= np.expand_dims(labs,axis=-1) # Nx1

    # 取数据 增加 一维全是1的特征
    datas = [line.split("\t")[:-1] for line in lines]
    datas = np.array(datas).astype(np.float32)
    N,D = np.shape(datas)
    # 增加一个维度
    datas = np.c_[np.ones([N,1]),datas]
    return datas,labs


if __name__ == "__main__":
    ''' 实验1 基础测试数据'''
    # 加载数据
    file = "testset.txt"
    datas,labs = load_dataset(file)
    
    #weights = train_LR(datas,labs,alpha=0.001,n_epoch=800)
    
    weights = train_LR_batch(datas,labs,batchsize=80,alpha=0.001,n_epoch=800)
    print(weights)
    draw_desion_line(datas,labs,weights,name="test_1.jpg")
    


    

训练需要的数据集testset.txt文件:

-0.017612	14.053064	0
-1.395634	4.662541	1
-0.752157	6.538620	0
-1.322371	7.152853	0
0.423363	11.054677	0
0.406704	7.067335	1
0.667394	12.741452	0
-2.460150	6.866805	1
0.569411	9.548755	0
-0.026632	10.427743	0
0.850433	6.920334	1
1.347183	13.175500	0
1.176813	3.167020	1
-1.781871	9.097953	0
-0.566606	5.749003	1
0.931635	1.589505	1
-0.024205	6.151823	1
-0.036453	2.690988	1
-0.196949	0.444165	1
1.014459	5.754399	1
1.985298	3.230619	1
-1.693453	-0.557540	1
-0.576525	11.778922	0
-0.346811	-1.678730	1
-2.124484	2.672471	1
1.217916	9.597015	0
-0.733928	9.098687	0
-3.642001	-1.618087	1
0.315985	3.523953	1
1.416614	9.619232	0
-0.386323	3.989286	1
0.556921	8.294984	1
1.224863	11.587360	0
-1.347803	-2.406051	1
1.196604	4.951851	1
0.275221	9.543647	0
0.470575	9.332488	0
-1.889567	9.542662	0
-1.527893	12.150579	0
-1.185247	11.309318	0
-0.445678	3.297303	1
1.042222	6.105155	1
-0.618787	10.320986	0
1.152083	0.548467	1
0.828534	2.676045	1
-1.237728	10.549033	0
-0.683565	-2.166125	1
0.229456	5.921938	1
-0.959885	11.555336	0
0.492911	10.993324	0
0.184992	8.721488	0
-0.355715	10.325976	0
-0.397822	8.058397	0
0.824839	13.730343	0
1.507278	5.027866	1
0.099671	6.835839	1
-0.344008	10.717485	0
1.785928	7.718645	1
-0.918801	11.560217	0
-0.364009	4.747300	1
-0.841722	4.119083	1
0.490426	1.960539	1
-0.007194	9.075792	0
0.356107	12.447863	0
0.342578	12.281162	0
-0.810823	-1.466018	1
2.530777	6.476801	1
1.296683	11.607559	0
0.475487	12.040035	0
-0.783277	11.009725	0
0.074798	11.023650	0
-1.337472	0.468339	1
-0.102781	13.763651	0
-0.147324	2.874846	1
0.518389	9.887035	0
1.015399	7.571882	0
-1.658086	-0.027255	1
1.319944	2.171228	1
2.056216	5.019981	1
-0.851633	4.375691	1
-1.510047	6.061992	0
-1.076637	-3.181888	1
1.821096	10.283990	0
3.010150	8.401766	1
-1.099458	1.688274	1
-0.834872	-1.733869	1
-0.846637	3.849075	1
1.400102	12.628781	0
1.752842	5.468166	1
0.078557	0.059736	1
0.089392	-0.715300	1
1.825662	12.693808	0
0.197445	9.744638	0
0.126117	0.922311	1
-0.679797	1.220530	1
0.677983	2.556666	1
0.761349	10.693862	0
-2.168791	0.143632	1
1.388610	9.341997	0
0.317029	14.739025	0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/781066.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

05.计算机网络——TCP协议

文章目录 TCP协议段格式TCP交付过程TCP解包过程确认应答机制\[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-kDvQFCTM-1689855767485)(C:\Users\11794\AppData\Roaming\Typora\typora-user-images\image-20230719204622485.png)\] 32位序号/32位确认…

深度学习anaconda+pycharm+虚拟环境迁移

一、下载好anaconda和pycharm安装包。 下载anaconda:Index of /anaconda/archive/ | 清华大学开源软件镜像站 | Tsinghua Open Source Mirror pycharm汉化包 二、安装anaconda 深度学习环境配置-Anaconda以及pytorch1.2.0的环境配置(Bubbliiiing 深度学习 教程&…

Pycharm远程服务器连接教程

第一步 只有Pycharm专业版才能远程连接服务器 第二步:远程连接部分 点击左上角的号新建一个连接,起一个名字,比如叫dilab191: 设置SSH参数 Tools-Development-Options 第三步, 添加远程服务器解释器部分 File-settings-Project …

spring复习:(50)@Configuration注解配置的singleton的bean是什么时候被创建出来并缓存到容器的?

一、主类: 二、配置类: 三、singleton bean的创建流程 运行到context.refresh(); 进入refresh方法: 向下运行到红线位置时: 会实例化所有的singleton bean.进入finisheBeanFactoryInitialization方法: 向下拖动代…

旧版Xcode文件较大导致下载总是失败但又不能断点续传重新开始的解决方法

问题: 旧版mac下载旧版Xcode时需要进入https://developer.apple.com/download/all/?qxcode下载,但是下载这些文件需要登录。登录后下载中途很容易失败,失败后又必须重新下载。 解决方案: 下载这里面的内容都需要登录&#xff0…

华为、阿里巴巴、字节跳动 100+ Python 面试问题总结(五)

系列文章目录 个人简介:机电专业在读研究生,CSDN内容合伙人,博主个人首页 Python面试专栏:《Python面试》此专栏面向准备面试的2024届毕业生。欢迎阅读,一起进步!🌟🌟🌟 …

苹果手机IOS自带科学计算器冷门功能使用

前言 事件是这样的,前几天有人想买个斜坡枕,斜坡枕是个直角三角形,已知短直角边长度是14CM,长直角边长度是80CM,他想知道这个斜坡是多少度,我说这个不是很简单吗?计算一下 a r c t a n ( 14 80…

C# List 详解七

目录 42.Sort() 43.ToArray() 44.ToString() 45.TrimExcess() 46.TrueForAll(Predicate) C# List 详解一 1.Add(T),2.AddRange(IEnumerable),3.AsReadOnly(),4.BinarySearch(T), C# List 详解二 5.Cl…

Matlab 刚性问题求解器-ode23s

1、ode23s介绍 ode23s(stiff differential equation solver)是MATLAB中的一种求解刚性(stiff)微分方程的数值方法。刚性微分方程通常具有多个时间尺度差异较大的变量,并且其中至少有一个变量具有快速变化的特性。 od…

Antv G6 force分布式布局 icon“+“ “-“收缩自定义,关系图子节点

子节点收缩 const collapseIcon (x, y, r) > {// 折叠return [[M, x - r, y],[a, r, r, 0, 1, 0, r * 2, 0],[a, r, r, 0, 1, 0, -r * 2, 0],[M, x - r 4, y],[L, x - r 2 * r - 4, y]]}const expandIcon (x, y, r) > {// 拓展return [[M, x - r, y],[a, r, r, 0, 1,…

SQL优化——插入数据优化(load指令的使用)

插入数据时的优化主键优化order by优化group by优化limit优化count优化update优化 1.插入数据时的优化 批量插入数据时最好最多别超过一千条,如果一次批量插入几万条数据,可以将其分割成多条insert语句进行插入。 mysql的事务提交方式是默认自动提交的…

Linux 下centos 查看 -std 是否支持 C17

实际工作中,可能会遇到c的一些高级特性,例如std::invoke,此函数是c17才引入的,如何判断当前的gcc是否支持c17呢,这里提供两种办法。 1.根据gcc的版本号来推断 gcc --version,可以查看版本号,笔者…

15.矩阵运算与img2col方式的卷积

使用矩阵计算卷积 GEMM算法 矩阵乘法运算(General Matrix Multiplication),形如: C A B , A ∈ R m k , B ∈ R k n , C ∈ R m n C AB, A\in \mathbb{R}^{m\times k},B\in \mathbb{R}^{k\times n},C\in \mathbb{R}^{m\times n} CAB,A∈Rmk,B∈Rk…

vite4.x+vue3.x中使用装饰器语法,eslint校验不识别@的报错处理方法

在项目中,使用了pre-commit校验代码,eslint校验无法识别,导致一直无法提交代码,查找了资料,eslint版本过低,不能解决现在遇到的问题 最终正确的配置方法: 装饰器配置文件babel.config.js module.exports …

了解应用层

应用层 1. 概述2. 应用程序组织方式2.1 C/S方式2.1 P2P方式 3. 动态主机配置协议DHCP3.1 DHCP工作流程 4. 域名系统DNS4.1 域名结构4.2 域名分类4.3 域名服务器4.3.1 分类 4.4 DNS域名解析过程 5. 文件传输协议FTP5.1 FTP工作流程 6. 电子邮件系统6.1 邮件信息格式6.2 简单邮件…

EtherCAT转TCP/IP网关EtherCAT解决方案

你是否曾经为生产管理系统的数据互联互通问题烦恼过?曾经因为协议不同导致通讯问题而感到困惑?现在,我们迎来了突破性的进展! 介绍捷米特JM-TCPIP-ECT,一款自主研发的Ethercat从站功能的通讯网关。它能够连接到Etherc…

12.面板问题

面板问题 html部分 <h1>Lorem ipsum dolor sit, amet consectetur adipisicing.</h1><div class"container"><div class"faq"><div class"title-box"><h3 class"title">Lorem, ipsum dolor.<…

TypeScript 中的常用类型声明大全

文章目录 基本数据类型1.number类型2.String 类型3. Boolean 类型4. undefined 类型5.Null类型6.Symbol类型7.BigInt类型 引用数据类型8.Array 类型9.Object 类型 TS 新增特性数据类型4.联合类型5.字面量类型6.Any 类型7.unknown 类型8.Void 类型9.never 类型10.对象类型12 tup…

基于linux下的高并发服务器开发(第三章)- 3.11 读写锁

读写锁的类型 pthread_rwlock_t int pthread_rwlock_init(pthread_rwlock_t *restrict rwlock, const pthread_rwlockattr_t *restrict attr); int pthread_rwlock_destroy(pthread_rwlock_t *rwlock); int pthread_rwlock_rdlock(pthread_rwlock_t *rwlock); int pthread_rwlo…

macOS 下使用 brew 命令安装 Node.js

&#x1f468;&#x1f3fb;‍&#x1f4bb; 热爱摄影的程序员 &#x1f468;&#x1f3fb;‍&#x1f3a8; 喜欢编码的设计师 &#x1f9d5;&#x1f3fb; 擅长设计的剪辑师 &#x1f9d1;&#x1f3fb;‍&#x1f3eb; 一位高冷无情的编码爱好者 大家好&#xff0c;我是 DevO…