3 tensorflow构建模型详解

news2025/1/10 17:20:27

上一篇:2 用TensorFlow构建一个简单的神经网络-CSDN博客

1、神经网络概念

接上一篇,用tensorflow写了一个猜测西瓜价格的简单模型,理解代码前先了解下什么是神经网络。

下面是百度AI对神经网络的解释:

神经网络是一种运算模型,由大量的节点(或称神经元)之间相互联接构成,每个节点代表一种特定的输出函数,称为激励函数(activation function)。每两个节点间的连接都代表一个对于通过该连接信号的加权值,称之为权重,这相当于人工神经网络的记忆。网络的输出则依网络的连接方式,权重值和激励函数的不同而不同。而网络自身通常都是对自然界某种算法或者函数的逼近,也可能是对一种逻辑策略的表达。
神经网络是一种广泛并行互连的网络,它的组织能够模拟生物神经系统对真实世界物体所做出的交互反应。

2、密集层

在上一篇我们创建了预测价格模型,其中代码为:

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[1])
])

意思是构建了一个模型,里面添加了一层神经网络,只有一个神经元。(Sequential是顺序的意思,Dense是密集层。)

密集层(也叫全连接层),在神经网络中指的是每个神经元都与前一层的所有神经元相连的层。

举个例子,如下图所示:神经元a1与所有输入层数据相连(X1,X2,X3),其他神经元也一样都与上一层神经元相连。

它们之间的数学关系为:

某个神经元是由连接的上一层神经元分别乘上权重(w),再加上偏差(b)得到,例如计算a1:

权重w的数字下标可以按照顺序命名,比如第一个神经元计算的权重可以为w11、w12……,第二个神经元计算的权重可以为w21、w22……

a2、a3计算以此类推。

3、西瓜费用预测模型详解

上一篇西瓜费用计算公式 :费用=1.2元/斤*重量+0.5元

即:y=1.2x+0.5

这是一个一元线性回归问题,只有一个自变量x和一个因变量y,机器学习要推算出权重w=1.2, 偏差b=0.5,才能准确预测费用。

代码如下:

import numpy as np
import tensorflow as tf

# 西瓜的重量
weight = np.array([1, 3, 4, 5, 6, 8], dtype=float)

# 对应的费用
total_cost = np.array([1.7, 4.1, 5.3, 6.5, 7.7, 10.1], dtype=float)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[1])
])

model.compile(loss=tf.losses.mean_squared_error, optimizer='SGD')

history = model.fit(weight, total_cost, epochs=500)

# 训练完成后,预测10斤西瓜的总费用
print(model.predict([10]))

下面是代码详解:

(1)训练数据准备

西瓜重量 weight=[1, 3, 4, 5, 6, 8]

对应的费用 total_cost=[1.7, 4.1, 5.3, 6.5, 7.7, 10.1]

(2)构建模型

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[1])
])

  • tf.keras.layers.Dense(1, input_shape=[1]),参数1表示1个神经元,我们只要预测费用y,所以输出层只要一个神经元就可以了(注意:神经元不用包含输入层)。
  • input_shape=[1],表示输入数据的形状为单元素列表,即每个输入数据只有一个值。因为只有一个变量x(西瓜的重量),所以此处输入形状是[1]

该模型的示意图:

可以用model.summary()查看模型摘要,代码如下:

import numpy as np
import tensorflow as tf

# 西瓜的重量
weight = np.array([1, 3, 4, 5, 6, 8], dtype=float)

# 对应的费用
total_cost = np.array([1.7, 4.1, 5.3, 6.5, 7.7, 10.1], dtype=float)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[1])
])


# 查看模型摘要
model.summary()

运行结果:

可以看到可训练参数有2个,即公式中的w1和b1。

(3)设置损失函数和优化器
model.compile(loss=tf.losses.mean_squared_error, optimizer='SGD')
  • mean_squared_error是均方误差,指的是预测值与真实值差值的平方然后求和再平均。公式为:

                    MSE=1/n Σ(P-G)^2 (P为预测值,G为真实值)

  • SGD即随机梯度下降(Stochastic Gradient Descent),是一种迭代优化算法。

(4)设置训练数据
history = model.fit(weight, total_cost, epochs=500)
  • 设置训练数据的特征和标签,在上述代码中分别是西瓜的重量和费用:weight、total_cost
  • 设置训练轮次epochs=500,1个epochs是指使用所有样本训练一次。

(5) 查看训练结果

看下面的训练过程,第8个epoch的时候损失值loss已经很小了,训练轮次不需要设置到500就可以有很好的预测效果了。

刚开始loss很高,使用优化算法慢慢调整了权重,loss值可以很好地衡量我们的模型有多好。

我们把epoch的值调小,看看程序猜测的权重(w)和偏差(b)是多少,以及loss值的计算。

 

代码改动如下:

  •  epochs=5
  • 用model.get_weights()获取程序猜测的权重数据
import numpy as np
import tensorflow as tf

# 西瓜的重量
weight = np.array([1, 3, 4, 5, 6, 8], dtype=float)

# 对应的费用
total_cost = np.array([1.7, 4.1, 5.3, 6.5, 7.7, 10.1], dtype=float)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[1])
])

model.compile(loss=tf.losses.mean_squared_error, optimizer='SGD')

history = model.fit(weight, total_cost, epochs=5)

# 获取权重数据
w = model.get_weights()[0]
b = model.get_weights()[1]

print('w:')
print(w)
print('b: ')
print(b)

# 训练完成后,预测10斤西瓜的总费用
print(model.predict([10]))

运行结果:

训练了5个epoch后,程序猜测w是1.1807659,b为0.33192113

            y=wx+b=1.1807659*10+0.33192113=12.139581

所以预测10斤西瓜的总费用是12.139581

                 

4、创建更复杂一点的模型

现实生活中我们要预测的东西影响因素可能有很多个,如房价预测,房价可能受到房屋面积、房间数量等等因素影响。思考一下,下面的神经网络图创建模型时要如何设置参数呢?

model = tf.keras.Sequential([
    tf.keras.layers.Dense(2, input_shape=[3]),
    tf.keras.layers.Dense(1)
])


 

         

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1152276.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

音视频技术开发周刊 | 317

每周一期,纵览音视频技术领域的干货。 新闻投稿:contributelivevideostack.com。 MIT惊人再证大语言模型是世界模型!LLM能分清真理和谎言,还能被人类洗脑 MIT等学者的「世界模型」第二弹来了!这次,他们证明…

什么是Steam红锁?及红锁的原因

Steam红锁分为两种,一种是商业红,一种是欺诈红。 造成红锁的原因有哪些? 1.非正常玩家,大量囤货,就是你交易饰品的交易量太大了,而且频繁地买进同一个饰品,官方就会判定你是商业行为&#xff0…

批量重命名文件夹:用数字随机重命名法管理您的文件夹

在文件管理中,文件夹的命名是一项至关重要的任务。一个好的文件夹命名方案可以帮助我们更高效地组织和查找文件。然而,随着时间的推移,我们可能会遇到文件夹数量过多,难以管理和查找的问题。为了解决这个问题,我们可以…

ubuntu PX4 vscode stlink debug设置

硬件 stlink holybro debug板 pixhawk4 安装openocd 官方文档,但是第一步安装建议从源码安装,bug少很多 github链接 编译安装,参考 ./bootstrap (when building from the git repository)./configure [options]makesudo make install安装后…

SpringMVC Day 06 : 转发视图

前言 在SpringMVC框架中,视图解析器可以将逻辑视图名称转换为实际的视图对象。除了直接渲染视图,你还可以通过SpringMVC提供的转发和重定向机制来跳转到另一个视图。在本篇博客中,我们将学习SpringMVC中的转发视图技术,以及如何使…

sscanf 函数的使用

一、sscanf 函数介绍 头文件 #include <stdio.h> 原型&#xff1a; int sscanf(const char *str, const char *format, ...); 返回&#xff1a; On success, these functions return the number of input items success‐ fully matched and assigned; this can be few…

钡铼技术ARM工控机在机器人控制领域的应用

ARM工控机是一种基于ARM架构的工业控制计算机&#xff0c;用于在工业自动化领域中进行数据采集、监控、控制和通信等应用。ARM&#xff08;Advanced RISC Machine&#xff09;架构是一种低功耗、高性能的处理器架构&#xff0c;广泛应用于移动设备、嵌入式系统和物联网等领域。…

如何使用内网穿透工具,将Tomcat网页发布到公共互联网上

文章目录 前言1.本地Tomcat网页搭建1.1 Tomcat安装1.2 配置环境变量1.3 环境配置1.4 Tomcat运行测试1.5 Cpolar安装和注册 2.本地网页发布2.1.Cpolar云端设置2.2 Cpolar本地设置 3.公网访问测试4.结语 前言 Tomcat作为一个轻量级的服务器&#xff0c;不仅名字很有趣&#xff0…

API安全之《大话:API的前世今生》

写在前面&#xff1a;本文结合API使用的业界现状&#xff0c;系统性地阐述API的基本概念、发展历史、表现形式等基础内容&#xff0c;主要包含以下内容&#xff1a; 1.什么是API 2.API的发展历史 3.现代API常用消息格式 4.top N 互联网企业API 使用现状 当前的世界是一个信…

【Go入门】GO流程与函数介绍(代码运行逻辑控制)

流程和函数 这小节我们要介绍Go里面的流程控制以及函数操作。 流程控制 流程控制在编程语言中是最伟大的发明了&#xff0c;因为有了它&#xff0c;你可以通过很简单的流程描述来表达很复杂的逻辑。Go中流程控制分三大类&#xff1a;条件判断&#xff0c;循环控制和无条件跳…

回文链表Java

我们可以采用双指针的办法进行,如下图: 如果链表长度为偶数,则直接从第二个指针的位置开始对链表进行反转;如果是奇数,则从第二指针的下一位进行链表反转 代码实现: public static void main(String[] args) {ListNode next4 new ListNode(1, null);ListNode next3 new Lis…

Jetson Xavier NX FFmpeg支持硬件编解码

最近在用Jetson Xavier NX板子做视频处理&#xff0c;但是CPU进行视频编解码&#xff0c;效率比较地下。 于是便考虑用硬解码来对视频进行处理。 通过jtop查看&#xff0c;发现板子是支持 NVENC硬件编解码的。 1、下载源码 因为需要对ffmpeg进行打补丁修改&#xff0c;因此需…

堆栈与队列算法-以数组来实现堆栈

目录 堆栈与队列算法-以数组来实现堆栈 C代码 扑克牌发牌算法 C代码 堆栈与队列算法-以数组来实现堆栈 以数组结构来实现堆栈的好处是设计的算法都相当简单。不过&#xff0c;如果堆栈本身的大小是变动的&#xff0c;而数组大小只能事先规划和声明好&#xff0c;那么数组规…

Qt QUrl详解

1.QUrl概述 QUrl 是Qt框架中用于处理URL的类&#xff0c;提供了一些方法来解析和构造URL。URL&#xff08;Uniform Resource Locator&#xff09;是用于定位和访问互联网资源的地址。QUrl类可以用于解析URL的各个部分&#xff0c;并提供了一些方法来获取和设置URL的各个部分。…

Android RecyclerView — 实现自动加载更多

在App中&#xff0c;使用列表来显示数据是十分常见的。使用列表来展示数据&#xff0c;最好不要一次加载太多的数据&#xff0c;特别是带图片时&#xff0c;页面渲染的时间会变长&#xff0c;常见的做法是进行分页加载。本文介绍一种无感实现自动加载更多的实现方式。 实现自动…

Windows原生蓝牙编程 第三章 配对后进行蓝牙通信【C++】

蓝牙系列文章目录 第一章 获取本地蓝牙并扫描周围蓝牙信息并输出 第二章 选取设备输入配对码并配对 第三章 配对后进行蓝牙通信 文章目录 前言头文件一、建立连接套接字二、设置发送信息函数三、全部代码四、测试服务端选择及蓝牙通信总结 前言 接着第二章&#xff0c;我们已经…

爱写bug的小邓程序员个人博客

博客网址: http://www.006969.xyz 欢迎来到我的个人博客&#xff0c;这里主要分享我对于前后端相关技术的学习笔记、项目实战经验以及一些技术感悟。 在我的博客中&#xff0c;你将看到以下主要内容&#xff1a; 技术文章 我将会分享我在学习前后端技术过程中的一些感悟&am…

【并发编程】进程与线程

主要知识点&#xff1a; 进程和线程的概念 并行和并发的概念 线程基本应用 一、进程与线程 进程 程序由指令和数据组成&#xff0c;但这些指令要运行&#xff0c;数据要读写&#xff0c;就必须将指令加载至 CPU&#xff0c;数据加载至内存。在指令运行过程中还需要用到磁盘、…

【SpringMVC篇】讲解RESTful相关知识

&#x1f38a;专栏【SpringMVC】 &#x1f354;喜欢的诗句&#xff1a;天行健&#xff0c;君子以自强不息。 &#x1f386;音乐分享【如愿】 &#x1f384;欢迎并且感谢大家指出小吉的问题&#x1f970; 文章目录 &#x1f384;REST简介&#x1f33a;RESTful入门案例⭐案例一⭐…

【Java笔试强训】Day7(WY22 Fibonacci数列、CM46 合法括号序列判断)

Fibonacci数列 链接&#xff1a;Fibonacci数列 题目&#xff1a; Fibonacci数列是这样定义的&#xff1a; F[0] 0 F[1] 1 for each i ≥ 2: F[i] F[i-1] F[i-2] 因此&#xff0c;Fibonacci数列就形如&#xff1a;0, 1, 1, 2, 3, 5, 8, 13, …&#xff0c;在Fibonacci数列…