TensorFlow 简单的二分类神经网络的训练和应用流程

news2025/2/1 23:09:32

展示了一个简单的二分类神经网络的训练和应用流程。主要步骤包括:

1. 数据准备与预处理

2. 构建模型

3. 编译模型

4. 训练模型

5. 评估模型

6. 模型应用与部署

加载和应用已训练的模型


1. 数据准备与预处理

在本例中,数据准备是通过两个 Numpy 数组来完成的:

  • x:输入特征,形状为 (8, 2),包含 8 个数据点,每个数据点有 2 个特征。
  • y:标签,形状为 (8,),包含对应的 0 或 1 标签,表示每个输入点的类别。
x = np.array([[1, 1], [1, -1], [-1, 1], [-1, -1], [0.7, 0.7], [0.7, -0.7], [-0.7, -0.7], [-0.7, 0.7]])
y = np.array([1, 1, 1, 1, 0, 0, 0, 0])

2. 构建模型

使用 Keras 的 Sequential 模型来构建神经网络。此模型包含两个全连接层(Dense 层):

  • 第一个 Dense 层有 3 个单位,激活函数是 Sigmoid。
  • 第二个 Dense 层有 1 个单位,激活函数是 Sigmoid,输出层的激活函数将模型输出的值映射到 0 到 1 之间,适合二分类任务。
l1 = tf.keras.layers.Dense(units=3, activation='sigmoid')
l2 = tf.keras.layers.Dense(units=1, activation='sigmoid')
model = tf.keras.Sequential([l1, l2])

3. 编译模型

在编译阶段,我们选择了优化器、损失函数和评估指标:

  • 优化器:SGD(随机梯度下降),学习率设置为 0.9。
  • 损失函数:binary_crossentropy,适用于二分类任务。
  • 评估指标:accuracy,表示训练过程中对分类准确率的衡量。
sgd = tf.keras.optimizers.SGD(learning_rate=0.9)
model.compile(optimizer=sgd, loss='binary_crossentropy', metrics=['accuracy'])

4. 训练模型

通过 model.fit() 函数来训练模型。我们传入训练数据 x 和标签 y,并设置训练的 epoch 数量为 2000。

model.fit(x, y, epochs=2000)

5. 评估模型

在此示例中,评估部分通过训练后的 model 来进行,并没有显式写出 evaluate() 函数。评估通常是在训练之后,通过测试集或验证集对模型性能进行评估,具体可以使用 model.evaluate() 来查看损失和准确度。

6. 模型应用与部署

训练完成后,我们保存了训练好的模型。保存后的模型可以被加载和应用于新的数据集。

model.save('my_model.keras')  # 保存模型

7.加载和应用已训练的模型

加载保存的模型,并用其对新数据进行预测。model.predict() 方法返回的是预测的概率值,我们通过设置阈值(如 0.9)将其转换为类别(0 或 1)。

model = tf.keras.models.load_model('my_model.keras')  # 加载模型
nx = np.array([[2, 2], [0.1, 0.1], [1.1, 1.2], [0.3, 0.3]])  # 新的输入数据
predictions = model.predict(nx)  # 获取预测结果
print(predictions)  # 输出概率

# 将概率转化为类别
predicted_classes = (predictions > 0.9).astype(int)
print(predicted_classes)  # 输出最终的类别预测

8.完整代码
test.py 训练模型

import tensorflow as tf
import numpy as np
# 创建int32类型的0维张量,即标量
l1=tf.keras.layers.Dense(units=3,activation='sigmoid')
l2=tf.keras.layers.Dense(units=1,activation='sigmoid')
model=tf.keras.Sequential([l1,l2])
sgd = tf.keras.optimizers.SGD(learning_rate=0.9)
model.compile(optimizer=sgd, loss='binary_crossentropy', metrics=['accuracy'])
x=np.array([[1,1],[1,-1],[-1,1],[-1,-1],[0.7,0.7],[0.7,-0.7],[-0.7,-0.7],[-0.7,0.7]])
y=np.array([1,1,1,1,0,0,0,0])
model.fit(x,y,epochs=2000)
# 保存训练好的模型(Keras 格式)
model.save('my_model.keras')

 test2.py加载模型并进行预测:

import tensorflow as tf
import numpy as np

# 加载训练好的模型
model = tf.keras.models.load_model('my_model.keras')

# 预测数据
nx = np.array([[2, 2], [0.1, 0.1], [1.1, 1.2], [0.3, 0.3]])

# 获取预测结果
predictions = model.predict(nx)

# 输出预测结果
print(predictions)

# 如果需要将概率转化为类别(0或1)
predicted_classes = (predictions > 0.9).astype(int)

# 输出最终的类别预测
print(predicted_classes)

9.视频分享


初识TensorFlow 
https://v.douyin.com/ifG2mmLH/
复制此链接,打开Dou音搜索,直接观看视频!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2289504.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

docker安装Redis:docker离线安装Redis、docker在线安装Redis、Redis镜像下载、Redis配置、Redis命令

一、镜像下载 1、在线下载 在一台能连外网的linux上执行docker镜像拉取命令 docker pull redis:7.4.0 2、离线包下载 两种方式: 方式一: -)在一台能连外网的linux上安装docker执行第一步的命令下载镜像 -)导出 # 导出镜像…

Retrieval-Augmented Generation for Large Language Models: A Survey——(1)Overview

Retrieval-Augmented Generation for Large Language Models: A Survey——(1)Overview 文章目录 Retrieval-Augmented Generation for Large Language Models: A Survey——(1)Overview1. Introduction&Abstract1. LLM面临的问题2. RAG核心三要素3. RAG taxonomy 2. Overv…

LabVIEW透镜多参数自动检测系统

在现代制造业中,提升产品质量检测的自动化水平是提高生产效率和准确性的关键。本文介绍了一个基于LabVIEW的透镜多参数自动检测系统,该系统能够在单一工位上完成透镜的多项质量参数检测,并实现透镜的自动搬运与分选,极大地提升了检…

什么是Maxscript?为什么要学习Maxscript?

MAXScript是Autodesk 3ds Max的内置脚本语言,它是一种与3dsMax对话并使3dsMax执行某些操作的编程语言。它是一种脚本语言,这意味着您不需要编译代码即可运行。通过使用一系列基于文本的命令而不是使用UI操作,您可以完成许多使用UI操作无法完成的任务。 Maxscript是一种专有…

Redis|前言

文章目录 什么是 Redis?Redis 主流功能与应用 什么是 Redis? Redis,Remote Dictionary Server(远程字典服务器)。Redis 是完全开源的,使用 ANSIC 语言编写,遵守 BSD 协议,是一个高性…

LeetCode:63. 不同路径 II

跟着carl学算法,本系列博客仅做个人记录,建议大家都去看carl本人的博客,写的真的很好的! 代码随想录 LeetCode:63. 不同路径 II 给定一个 m x n 的整数数组 grid。一个机器人初始位于 左上角(即 grid[0][0]…

Redis-布隆过滤器

文章目录 布隆过滤器的特点:实践布隆过滤器应用 布隆过滤器的特点: 就可以把布隆过滤器理解为一个set集合,我们可以通过add往里面添加元素,通过contains来判断是否包含某个元素。 布隆过滤器是一个很长的二进制向量和一系列随机映射函数。 可以用来检索…

【视频+图文详解】HTML基础3-html常用标签

图文教程 html常用标签 常用标签 1. 文档结构 <!DOCTYPE html>&#xff1a;声明HTML文档类型。<html>&#xff1a;定义HTML文档的根元素。<head>&#xff1a;定义文档头部&#xff0c;包含元数据。<title>&#xff1a;设置网页标题&#xff0c;浏览…

【B站保姆级视频教程:Jetson配置YOLOv11环境(五)Miniconda安装与配置】

Jetson配置YOLOv11环境&#xff08;5&#xff09;Miniconda安装与配置 文章目录 0. Anaconda vs Miniconda in Jetson1. 下载Miniconda32. 安装Miniconda33. 换源3.1 conda 换源3.2 pip 换源 4. 创建环境5. 设置默认启动环境 0. Anaconda vs Miniconda in Jetson Jetson 设备资…

【PLL】杂散生成和调制

时钟生成 --》 数字系统 --》峰值抖动频率生成 --》无线系统 --》 频谱纯度、 周期信号的相位不确定性 随机抖动&#xff08;random jitter, RJ&#xff09;确定性抖动&#xff08;deterministic jitter,DJ&#xff09; 时域频域随机抖动积分相位噪声确定性抖动边带 杂散生成和…

游戏引擎 Unity - Unity 启动(下载 Unity Editor、生成 Unity Personal Edition 许可证)

Unity Unity 首次发布于 2005 年&#xff0c;属于 Unity Technologies Unity 使用的开发技术有&#xff1a;C# Unity 的适用平台&#xff1a;PC、主机、移动设备、VR / AR、Web 等 Unity 的适用领域&#xff1a;开发中等画质中小型项目 Unity 适合初学者或需要快速上手的开…

【C++动态规划 离散化】1626. 无矛盾的最佳球队|2027

本文涉及知识点 C动态规划 离散化 LeetCode1626. 无矛盾的最佳球队 假设你是球队的经理。对于即将到来的锦标赛&#xff0c;你想组合一支总体得分最高的球队。球队的得分是球队中所有球员的分数 总和 。 然而&#xff0c;球队中的矛盾会限制球员的发挥&#xff0c;所以必须选…

python-leetcode-从中序与后序遍历序列构造二叉树

106. 从中序与后序遍历序列构造二叉树 - 力扣&#xff08;LeetCode&#xff09; # Definition for a binary tree node. # class TreeNode: # def __init__(self, val0, leftNone, rightNone): # self.val val # self.left left # self.right r…

Java实战:图像浏览器

文章目录 1. 实战概述2. 知识准备3. 实现步骤3.1 创建Java项目3.2 创建图像浏览器类3.2.1 声明变量与常量3.2.2 创建构造方法3.2.3 创建初始化界面方法3.2.4 创建处理事件方法3.2.5 创建主方法3.2.6 查看完整代码 3.3 运行程序&#xff0c;查看结果 4. 实战小结5. 扩展练习 1. …

I.MX6ULL 中断介绍上

i.MX6ULL是NXP&#xff08;原Freescale&#xff09;推出的一款基于ARM Cortex-A7内核的微处理器&#xff0c;广泛应用于嵌入式系统。在i.MX6ULL中&#xff0c;中断&#xff08;Interrupt&#xff09;是一种重要的机制&#xff0c;用于处理外部或内部事件&#xff0c;允许微处理…

(即插即用模块-特征处理部分) 十九、(NeurIPS 2023) Prompt Block 提示生成 / 交互模块

文章目录 1、Prompt Block2、代码实现 paper&#xff1a;PromptIR: Prompting for All-in-One Blind Image Restoration Code&#xff1a;https://github.com/va1shn9v/PromptIR 1、Prompt Block 在解决现有图像恢复模型时&#xff0c;现有研究存在一些局限性&#xff1a; 现有…

MySQL数据库(二)- SQL

目录 ​编辑 一 DDL (一 数据库操作 1 查询-数据库&#xff08;所有/当前&#xff09; 2 创建-数据库 3 删除-数据库 4 使用-数据库 (二 表操作 1 创建-表结构 2 查询-所有表结构名称 3 查询-表结构内容 4 查询-建表语句 5 添加-字段名数据类型 6 修改-字段数据类…

数据分析系列--⑦RapidMiner模型评价(基于泰坦尼克号案例含数据集)

一、前提 二、模型评估 1.改造⑥ 2.Cross Validation算子说明 2.1Cross Validation 的作用 2.1.1 模型评估 2.1.2 减少过拟合 2.1.3 数据利用 2.2 Cross Validation 的工作原理 2.2.1 数据分割 2.2.2 迭代训练与测试 ​​​​​​​ 2.2.3 结果汇总 ​​​​​​​ …

gentoo中利用ollama运行DeepSeek-R1

一、安装ollama gentoo linux中 1.安装步骤&#xff1a; Step1. #cd /usr/local/src Step2. #wget2 -o -V https://ollama.com/install.sh Setp3. #sh ./install.sh 2.ollama完成安装。查看ollama版本&#xff1a; 3.查看ollama服务运行状态&#xff1a; 二、安装&#xf…

【NEXT】网络编程——上传文件(不限于jpg/png/pdf/txt/doc等),或请求参数值是file类型时,调用在线服务接口

最近在使用华为AI平台ModelArts训练自己的图像识别模型&#xff0c;并部署了在线服务接口。供给客户端&#xff08;如&#xff1a;鸿蒙APP/元服务&#xff09;调用。 import核心能力&#xff1a; import { http } from kit.NetworkKit; import { fileIo } from kit.CoreFileK…