3 tensorflow构建的模型详解

news2025/1/13 11:42:50

上一篇:2 用TensorFlow构建一个简单的神经网络-CSDN博客

1、神经网络概念

接上一篇,用tensorflow写了一个猜测西瓜价格的简单模型,理解代码前先了解下什么是神经网络。

下面是百度AI对神经网络的解释:

这里不赘述太多概念相关的东西,不理解可先跳过,后面例子看懂了在返回看可能就理解了。

2、密集层

再看看上一篇预测西瓜价格的代码

import numpy as np
import tensorflow as tf

# 西瓜的重量
weight = np.array([1, 3, 4, 5, 6, 8], dtype=float)

# 对应的费用
total_cost = np.array([1.7, 4.1, 5.3, 6.5, 7.7, 10.1], dtype=float)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[1])
])

model.compile(loss=tf.losses.mean_squared_error, optimizer='SGD')

history = model.fit(weight, total_cost, epochs=500)

# 训练完成后,预测10斤西瓜的总费用
print(model.predict([10]))

其中这行代码:

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[1])
])

意思是构建了一个模型,里面添加了一层神经网络,只有一个神经元。(Sequential是顺序的意思,Dense是密集层。)

密集层(也叫全连接层),在神经网络中指的是每个神经元都与前一层的所有神经元相连的层。

举个例子,如下图所示:神经元a1与所有输入层数据相连(X1,X2,X3),其他神经元也一样都与上一层神经元相连。

它们之间的数学关系为:

某个神经元是由连接的上一层神经元分别乘上权重(w),再加上偏差(b)得到,例如计算a1:

权重w的数字下标可以按照顺序命名,比如第一个神经元计算的权重可以为w11、w12……,第二个神经元计算的权重可以为w21、w22……

a2、a3计算以此类推。

3、西瓜预测模型详解

上一篇西瓜费用计算公式 :费用=1.2元/斤*重量+0.5元

即:y=1.2x+0.5

这是一个一元线性回归问题,只有一个自变量x和一个因变量y,机器学习要推算出权重w=1.2, 偏差b=0.5,才能准确预测费用。

下面是实现这个算法的流程:

(1)训练数据准备

西瓜重量 weight=[1, 3, 4, 5, 6, 8]

对应的费用 total_cost=[1.7, 4.1, 5.3, 6.5, 7.7, 10.1]

(2)构建模型

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[1])
])

  • tf.keras.layers.Dense(1, input_shape=[1]),参数1表示1个神经元,我们只要预测费用y,所以输出层只要一个神经元就可以了(注意:神经元不用包含输入层)。
  • input_shape=[1],表示输入数据的形状为单元素列表,即每个输入数据只有一个值。因为只有一个变量x(西瓜的重量),所以此处输入形状是[1]

该模型的示意图:

可以用model.summary()查看模型摘要,代码如下:

import numpy as np
import tensorflow as tf

# 西瓜的重量
weight = np.array([1, 3, 4, 5, 6, 8], dtype=float)

# 对应的费用
total_cost = np.array([1.7, 4.1, 5.3, 6.5, 7.7, 10.1], dtype=float)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[1])
])


# 查看模型摘要
model.summary()

运行结果:

可以看到可训练参数有2个,即公式中的w1和b1。

(3)设置损失函数和优化器
model.compile(loss=tf.losses.mean_squared_error, optimizer='SGD')
  • mean_squared_error是均方误差,指的是预测值与真实值差值的平方然后求和再平均。公式为:

                    MSE=1/n Σ(P-G)^2 (P为预测值,G为真实值)

  • SGD即随机梯度下降(Stochastic Gradient Descent),是一种迭代优化算法。

(4)设置训练数据
history = model.fit(weight, total_cost, epochs=500)
  • 设置训练数据的特征和标签,在上述代码中分别是西瓜的重量和费用:weight、total_cost
  • 设置训练轮次epochs=500,1个epochs是指使用所有样本训练一次。

(5) 查看训练结果

看下面的训练过程,第8个epoch的时候损失值loss已经很小了,训练轮次不需要设置到500就可以有很好的预测效果了。

刚开始loss很高,使用优化算法慢慢调整了权重,loss值可以很好地衡量我们的模型有多好。

我们把epoch的值调小,看看程序猜测的权重(w)和偏差(b)是多少,以及loss值的计算。

 

代码改动如下:

  •  epochs=5
  • 用model.get_weights()获取程序猜测的权重数据
import numpy as np
import tensorflow as tf

# 西瓜的重量
weight = np.array([1, 3, 4, 5, 6, 8], dtype=float)

# 对应的费用
total_cost = np.array([1.7, 4.1, 5.3, 6.5, 7.7, 10.1], dtype=float)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[1])
])

model.compile(loss=tf.losses.mean_squared_error, optimizer='SGD')

history = model.fit(weight, total_cost, epochs=5)

# 获取权重数据
w = model.get_weights()[0]
b = model.get_weights()[1]

print('w:')
print(w)
print('b: ')
print(b)

# 训练完成后,预测10斤西瓜的总费用
print(model.predict([10]))

运行结果:

训练了5个epoch后,程序猜测w是1.1807659,b为0.33192113

            y=wx+b=1.1807659*10+0.33192113=12.139581

所以预测10斤西瓜的总费用是12.139581

                 


 

         

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1149497.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

安防监控项目---CGI接口的移植和使用

文章目录 前言一、CGI二、CGI的具体移植步骤2.1 cgi源码下载2.2 搭建交叉编译环境2.3 注意事项 三、测试结果总结 前言 书接上期,上期与大家分享的是boa服务器的移植,那么几天要和大家介绍的呢是一款接口,哈哈哈,用起来也是有点难…

vue使用百度富文本

🔥博客主页: 破浪前进 🔖系列专栏: Vue、React、PHP ❤️感谢大家点赞👍收藏⭐评论✍️ 1、下载UEditor 链接已放到文章中了 2、上传到项目目录中 一般上传到public下,方便到时候打包进去,以免…

骨传导耳机怎么佩戴,骨传导蓝牙耳机什么牌子好用

市面上的传统耳机一直以来都存在一些问题,比如长时间佩戴会导致耳朵不适,或者声音过大可能会伤害到耳膜。但是,现在有一种独特的耳机正在迅速走红,它被称为骨传导耳机,而骨传导耳机是怎么佩戴的呢,它在佩戴…

Ionic 7 版本发布 - 免费开源、超受欢迎的移动应用开发 UI 工具包,主题优雅且完美支持 Vue.js

Ionic 是一款优秀的移动 UI 框架,迭代也很快,现在也支持了 Vue,是时候向大家推荐用来开发 APP 了。 Ionic 全称是 Ionic Framework,是一个功能强大的开源 UI 工具库,用来帮助前端开发者构建跨平台的移动应用。 Ionic …

干洗店小程序上门洗鞋店管理软件功能介绍;

干洗店小程序上门洗鞋店管理软件功能介绍; 营销工具-洗鞋店管理软件多渠道玩法,拓客留客 支付-会员管理系统多种支付方式,灵活经营 ​ ​提供洗鞋店管理软件服务,实现会员精细化运营 会员档案-洗鞋店管理软件记录会员的全方位信…

Pytorch 猫狗识别案例

猫狗识别数据集https://download.csdn.net/download/Victor_Li_/88483483?spm=1001.2014.3001.5501 训练集图片路径 测试集图片路径 训练代码如下 import torch import torchvision import matplotlib.pyplot as plt import torchvision.models as models import torch.nn a…

IntelliJ IDEA 安装mybaits当前运行sql日志插件在线与离线安装方法

先安装好idear 去网上找找这个安装包下载下来,注意版本要完全一致! 比如: https://www.onlinedown.net/soft/1233409.htm手动安装离线插件方法举例 提前下载好插件的安装包 可以去网上下载这个安装包 搜索离线安装包的资源,包…

【文末送书】AI时代数据的重要性

欢迎关注博主 Mindtechnist 或加入【智能科技社区】一起学习和分享Linux、C、C、Python、Matlab,机器人运动控制、多机器人协作,智能优化算法,滤波估计、多传感器信息融合,机器学习,人工智能等相关领域的知识和技术。关…

亚马逊美国加拿大电动移动设备合规标准是什么?如何办理?

亚马逊美国站电动移动设备合规标准是什么? 加拿大站电动移动设备合规标准 办理流程: 1.填写申请表 2.提供产品的资料(说明书,电路原理图,如是多个型号的,提供型号差异列表) 3.寄样 4.测试 …

电商生态圈:跨境电商的商业合作新模式

随着数字化浪潮的不断崛起,电子商务行业正经历着前所未有的革命性变革。在这个变革的过程中,跨境电商已经成为全球贸易的推动力量。然而,跨境电商并非孤立存在,而是在日益扩大的电商生态圈内蓬勃发展。本文将探讨跨境电商的商业合…

AT8548 双通道有刷直流电机驱动芯片的作用

AT8548为玩具、打印机和其它机电应用提供一种双通道电机驱动方案。亿胜盈科AT8548内置两路H桥驱动,可以驱动两个直流有刷电机,或者通过输出并接驱动一个直流有刷电机,或者一个双J步进电机,或者螺线管及其它感性负载。 亿胜盈科AT8…

计算机考研 | 2011年 | 计算机组成原理真题

文章目录 【计算机组成原理2011年真题43题-11分】【第一步:信息提取】【第二步:具体解答】 【计算机组成原理2011年真题44题-12分】【第一步:信息提取】【第二步:具体解答】 【计算机组成原理2011年真题43题-11分】 (1…

【23真题】大神凭这套拿452分!看看你能拿多少?

今天分享的是23年福州大学866的信号与系统试题及解析。23年福州大学新一代电子信息的最高分是452分!但是我看不到单科分数。按照75,75,150,150。也就是只有450,说明这个同学,专业课和数学几乎拿满&#xff…

【设计模式】第17节:行为型模式之“解释器模式”

一、简介 解释器模式为某个语言定义它的语法(或者叫文法)表示,并定义一个解释器用来处理这个语法。 二、适用场景 领域特定语言复杂输入解释可扩展的语言结构 三、UML类图 四、案例 对输入的特定格式的打印语句进行解析并执行。 packag…

3D模型格式转换工具HOOPS Exchange:更快、更准确的CAD数据转换工具

HOOPS Exchange是一个开发平台,可以帮助快速开发高性能,跨平台的工程应用程序,是一款更快、更准确的CAD数据转换工具包,是3D数据格式转换首选解决方案。 ▷ 工业级3D数据格式转换 通过单个界面即可读取和写入30多种CAD文件格式&…

虚拟人裸眼3D动画宣传片:品牌营销的流量密码

在数字化转型的大背景下,行业竞争越来越激烈,品牌迫切需要一种新颖的、差异化的宣传片方式提升流量。而依靠户外大屏播放的虚拟人裸眼3D动画宣传片,具有强地标性和网红属性,成为推动文旅、城市营销、品牌营销的重要渠道。 虚拟人裸…

PWA 是属于谷歌的“小程序”!有哪些核心技术

在国内由于小程序的风生水起,PWA 应用在国内的状况一直都不是很好,PWA 和小程序有很多的相似性,但是 PWA 是由谷歌发起的技术,小程序是微信发起的技术,所以小程序在国内得到了大力的扶持,很快就在国内技术界…

Day 3 登录页以及路由 (一)

登录页以及路由 需求 作为一个后台管理系统,登录页是必不可少的。登录页的需求也很简单,输入账号密码,有登录、重置按钮即可。主要界面类似这种: 登录提交到后台,校验成功后,跳转到系统主页。 另外一个需…

Spring MVC的常用注解(接收请求数据篇)

目录 RequestMapping 例子: RequestMapping 支持什么类型的请求 使 RequestMapping 只支持特定的类型 RestController 通过 HTTP 请求传递参数给后端 1.传递单个参数 注意使⽤基本类型来接收参数的情况 2.传递多个参数 3.传递对象 4.RequestParam 后端参数…

分布式:一文吃透分布式锁,Redis/Zookeeper/MySQL实现

目录 一、项目准备spring项目数据库 二、传统锁演示超卖现象使用JVM锁解决超卖解决方案JVM失效场景 使用一个SQL解决超卖使用mysql悲观锁解决超卖使用mysql乐观锁解决超卖四种锁比较Redis乐观锁集成Redis超卖现象redis乐观锁解决超卖 三、分布式锁概述四、Redis分布式锁实现方案…