TensorFlow实现逻辑回归模型

news2025/4/21 22:00:11

逻辑回归是一种经典的分类算法,广泛应用于二分类问题。本文将介绍如何使用TensorFlow框架实现逻辑回归模型,并通过动态绘制决策边界和损失曲线来直观地观察模型的训练过程。

数据准备

首先,我们准备两类数据点,分别表示两个不同的类别。这些数据点将作为模型的输入特征。

# 1.散点输入
class1_points=np.array([[1.9,1.2],
                        [1.5,2.1],
                        [1.9,0.5],
                        [1.5,0.9],
                        [0.9,1.2],
                        [1.1,1.7],
                        [1.4,1.1]])
class2_points=np.array([[3.2,3.2],
                        [3.7,2.9],
                        [3.2,2.6],
                        [1.7,3.3],
                        [3.4,2.6],
                        [4.1,2.3],
                        [3.0,2.9]])

将两类数据点合并为一个矩阵,并为每个数据点分配相应的标签(0或1)。

#不用单独提取出x1_data 和x2_data
#框架会根据输入特征数自动提取
x_train=np.concatenate((class1_points,class2_points),axis=0)
y_train=np.concatenate((np.zeros(len(class1_points)),np.ones(len(class2_points))))

将数据转换为TensorFlow张量,以便在模型中使用。

import tensorflow as tf

x_train_tensor = tf.convert_to_tensor(x_train, dtype=tf.float32)
y_train_tensor = tf.convert_to_tensor(y_train, dtype=tf.float32)

模型定义

使用TensorFlow的tf.keras模块定义逻辑回归模型。模型包含一个输入层和一个输出层,输出层使用sigmoid激活函数。

def LogisticRegreModel():
    input = tf.keras.Input(shape=(2,))
    fc = tf.keras.layers.Dense(1, activation='sigmoid')(input)
    lr_model = tf.keras.models.Model(inputs=input, outputs=fc)
    return lr_model

model = LogisticRegreModel()

定义优化器和损失函数。这里使用随机梯度下降优化器和二元交叉熵损失函数。

opt = tf.keras.optimizers.SGD(learning_rate=0.01)
model.compile(optimizer=opt, loss="binary_crossentropy")

训练过程

训练模型时,我们记录每个epoch的损失值,并动态绘制决策边界和损失曲线。

 

import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2)

epochs = 500
epoch_list = []
epoch_loss = []

for epoch in range(1, epochs + 1):
    y_pre = model.fit(x_train_tensor, y_train_tensor, epochs=50, verbose=0)
    epoch_loss.append(y_pre.history["loss"][0])
    epoch_list.append(epoch)

    w1, w2 = model.get_weights()[0].flatten()
    b = model.get_weights()[1][0]

    slope = -w1 / w2
    intercept = -b / w2

    x_min, x_max = 0, 5
    x = np.array([x_min, x_max])
    y = slope * x + intercept

    ax1.clear()
    ax1.plot(x, y, 'r')
    ax1.scatter(x_train[:len(class1_points), 0], x_train[:len(class1_points), 1])
    ax1.scatter(x_train[len(class1_points):, 0], x_train[len(class1_points):, 1])

    ax2.clear()
    ax2.plot(epoch_list, epoch_loss, 'b')
    plt.pause(1)

结果展示

训练完成后,决策边界图将显示模型如何将两类数据分开,损失曲线图将显示模型在训练过程中的损失值变化。生成结果基本如图所示:

通过动态绘制决策边界和损失曲线,我们可以直观地观察模型的训练过程,了解模型如何逐渐学习数据的分布并优化决策边界。

总结

本文介绍了如何使用TensorFlow实现逻辑回归模型,并通过动态绘制决策边界和损失曲线来观察模型的训练过程。逻辑回归是一种简单而有效的分类算法,适用于二分类问题。通过TensorFlow框架,我们可以轻松地实现和训练逻辑回归模型,并利用其强大的功能来优化模型的性能。


完整代码

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
# 1.散点输入
class1_points=np.array([[1.9,1.2],
                        [1.5,2.1],
                        [1.9,0.5],
                        [1.5,0.9],
                        [0.9,1.2],
                        [1.1,1.7],
                        [1.4,1.1]])
class2_points=np.array([[3.2,3.2],
                        [3.7,2.9],
                        [3.2,2.6],
                        [1.7,3.3],
                        [3.4,2.6],
                        [4.1,2.3],
                        [3.0,2.9]])

#不用单独提取出x1_data 和x2_data
#框架会根据输入特征数自动提取
x_train=np.concatenate((class1_points,class2_points),axis=0)
y_train=np.concatenate((np.zeros(len(class1_points)),np.ones(len(class2_points))))
#转化为张量
x_train_tensor=tf.convert_to_tensor(x_train,dtype=tf.float32)
y_train_tensor=tf.convert_to_tensor(y_train,dtype=tf.float32)

#2.定义前向模型
# 使用类的方式
# 先设置一下随机数种子
seed=0
tf.random.set_seed(0)

def LogisticRegreModel():
    input=tf.keras.Input(shape=(2,))
    fc=tf.keras.layers.Dense(1,activation='sigmoid')(input)
    lr_model=tf.keras.models.Model(inputs=input,outputs=fc)
    return lr_model
#实例化网络
model=LogisticRegreModel()
#3.定义损失函数和优化器
#定义优化器
#需要输入模型参数和学习率
lr=0.1
opt=tf.keras.optimizers.SGD(learning_rate=0.01)
model.compile(optimizer=opt,loss="binary_crossentropy")



# 最后画图
fig,(ax1,ax2)=plt.subplots(1,2)
#训练
epoches=500
epoch_list=[]
epoch_loss=[]
for epoch in range(1,epoches+1):
    # verbose=0 进度条不显示  epochs迭代次数
    y_pre=model.fit(x_train_tensor,y_train_tensor,epochs=50,verbose=0)
    # print(y_pre.history["loss"])
    epoch_loss.append(y_pre.history["loss"][0])
    epoch_list.append(epoch)
    w1,w2=model.get_weights()[0].flatten()
    b=model.get_weights()[1][0]

    #画左图
    # 使用斜率和截距画直线
    #目前将x2当作y轴 x1当作x轴
    # w1*x1+w2*x2+b=0
    #求出斜率和截距
    slope=-w1/w2
    intercept=-b/w2
    #绘制直线 开始结束位置
    x_min,x_max=0,5
    x=np.array([x_min,x_max])
    y=slope*x+intercept
    ax1.clear()
    ax1.plot(x,y,'r')
    #画散点图
    ax1.scatter(x_train[:len(class1_points),0],x_train[:len(class1_points),1])
    ax1.scatter(x_train[len(class1_points):, 0],x_train[len(class1_points):, 1])


    #画右图
    ax2.clear()
    ax2.plot(epoch_list,epoch_loss,'b')
    plt.pause(1)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2284619.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

《十七》浏览器基础

浏览器:是安装在电脑里面的一个软件,能够将页面内容渲染出来呈现给用户查看,并让用户与网页进行交互。 常见的主流浏览器: 常见的主流浏览器有:Chrome、Safari、Firefox、Opera、Edge 等。 输入 URL,浏览…

网络安全 | F5-Attack Signatures-Set详解

关注:CodingTechWork 创建和分配攻击签名集 可以通过两种方式创建攻击签名集:使用过滤器或手动选择要包含的签名。  基于过滤器的签名集仅基于在签名过滤器中定义的标准。基于过滤器的签名集的优点在于,可以专注于定义用户感兴趣的攻击签名…

STranslate 中文绿色版即时翻译/ OCR 工具 v1.3.1.120

STranslate 是一款功能强大且用户友好的翻译工具,它支持多种语言的即时翻译,提供丰富的翻译功能和便捷的使用体验。STranslate 特别适合需要频繁进行多语言交流的个人用户、商务人士和翻译工作者。 软件功能 1. 即时翻译: 文本翻译&#xff…

基于微信小程序的助农扶贫系统设计与实现(LW+源码+讲解)

专注于大学生项目实战开发,讲解,毕业答疑辅导,欢迎高校老师/同行前辈交流合作✌。 技术范围:SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容:…

我谈区域偏心率

偏心率的数学定义 禹晶、肖创柏、廖庆敏《数字图像处理(面向新工科的电工电子信息基础课程系列教材)》P312 区域的拟合椭圆看这里。 Rafael Gonzalez的二阶中心矩的表达不说人话。 我认为半长轴和半短轴不等于特征值,而是特征值的根号。…

关于低代码技术架构的思考

我们经常会看到很多低代码系统的技术架构图,而且经常看不懂。是因为技术架构图没有画好,还是因为技术不够先进,有时候往往都不是。 比如下图: 一个开发者,看到的视角往往都是技术层面,你给用户讲React18、M…

若依路由配置教程

1. 路由配置文件 2. 配置内容介绍 { path: "/tool/gen-edit", component: Layout, //在路由下,引用组件的名称,在页面中包括这个组件的内容(页面框架内容) hidden: true, //此页面的内容,在左边的菜单中不用显示。 …

基于ESP8266的多功能环境监测与反馈系统开发指南

项目概述 本系统集成了物联网开发板、高精度时钟模块、环境传感器和可视化显示模块,构建了一个智能环境监测与反馈装置。通过ESP8266 NodeMCU作为核心控制器,结合DS3231实时时钟、DHT11温湿度传感器、光敏电阻和OLED显示屏,实现了环境参数的…

HTML5 Web Worker 的使用与实践

引言 在现代 Web 开发中,用户体验是至关重要的。如果页面在执行复杂计算或处理大量数据时变得卡顿或无响应,用户很可能会流失。HTML5 引入了 Web Worker,它允许我们在后台运行 JavaScript 代码,从而避免阻塞主线程,保…

flutter_学习记录_00_环境搭建

1.参考文档 Mac端Flutter的环境配置看这一篇就够了 flutter的中文官方文档 2. 本人环境搭建的背景 本人的电脑的是Mac的,iOS开发,所以iOS开发环境本身是可用的;外加Mac电脑本身就会配置Java的环境。所以,后面剩下的就是&#x…

自助设备系统设置——对接POS支付

输入管理员密码 一、录入POS网关信息 填写网关信息后保存,重新启动软件

Calibre(阅读转换)-官方开源中文版[完整的电子图书馆系统,包括图书馆管理,格式转换,新闻,材料转换为电子书]

Calibre(阅读&转换)-官方开源中文版 链接:https://pan.xunlei.com/s/VOHbKYUwd3ASVXTi2Ok1vkK3A1?pwd92ny#

【unity游戏开发之InputSystem——06】PlayerInputManager组件实现本地多屏的游戏(基于unity6开发介绍)

文章目录 PlayerInputManager 简介1、PlayerInputManager 的作用2、主要功能一、PlayerInputManager组件参数1、Notification Behavior 通知行为2、Join Behavior:玩家加入的行为3、Player Prefab 玩家预制件4、Joining Enabled By Default 默认启用加入5、Limit Number Of Pl…

算法刷题Day29:BM67 不同路径的数目(一)

题目链接 描述 解题思路: 二维dp数组初始化。 dp[i][0] 1, dp[0][j] 1 。因为到达第一行第一列的每个格子只能有一条路。状态转移 dp[i][j] dp[i-1][j] dp[i][j-1] 代码: class Solution: def uniquePaths(self , m: int, n: int) -> int: #…

美国本科申请文书PS写作中的注意事项

在完成了introduction之后,便可进入到main body的写作之中。美国本科申请文书PS的写作不同于学术论文写作,要求你提出论点进行论证之类。PS更多的注重对你自己的经历或者motivation的介绍和描述。而这一描述过程只能通过对你自己的过往的经历的展现才能体…

内存泄漏的通用排查方法

本文聊一聊如何系统性地分析查找内存泄漏的具体方法,但不会具体到哪种语言和具体业务代码逻辑中,而是会从 Linux 系统上通用的一些分析方法来入手。这样,不论你使用什么开发语言,不论你在开发什么,它总能给你提供一些帮…

【Python】第五弹---深入理解函数:从基础到进阶的全面解析

✨个人主页: 熬夜学编程的小林 💗系列专栏: 【C语言详解】 【数据结构详解】【C详解】【Linux系统编程】【MySQL】【Python】 目录 1、函数 1.1、函数是什么 1.2、语法格式 1.3、函数参数 1.4、函数返回值 1.5、变量作用域 1.6、函数…

读书笔记--分布式服务架构对比及优势

本篇是在上一篇的基础上,主要对共享服务平台建设所依赖的分布式服务架构进行学习,主要记录和思考如下,供大家学习参考。随着企业各业务数字化转型工作的推进,之前在传统的单一系统(或单体应用)模式中&#…

关于WPF中ComboBox文本查询功能

一种方法是使用事件&#xff08;包括MVVM的绑定&#xff09; <ComboBox TextBoxBase.TextChanged"ComboBox_TextChanged" /> 然而运行时就会发现&#xff0c;这个事件在疯狂的触发&#xff0c;很频繁 在实际应用中&#xff0c;如果关联查询数据库&#xff0…

LockSupport概述、阻塞方法park、唤醒方法unpark(thread)、解决的痛点、带来的面试题

目录 ①. 什么是LockSupport? ②. 阻塞方法 ③. 唤醒方法(注意这个permit最多只能为1) ④. LockSupport它的解决的痛点 ⑤. LockSupport 面试题目 ①. 什么是LockSupport? ①. 通过park()和unpark(thread)方法来实现阻塞和唤醒线程的操作 ②. LockSupport是一个线程阻塞…