吴恩达deeplearning.ai:Tensorflow训练一个神经网络

news2025/1/20 20:09:25

以下内容有任何不理解可以翻看我之前的博客哦:吴恩达deeplearning.ai
在之前的博客中。我们陆续学习了各个方面的有关深度学习的内容,今天可以从头开始训练一个神经网络了。

Tensorflow训练神经网络模型

我们使用之前用过的例子:
在这里插入图片描述
这个神经网络有三层,第一层拥有25个神经元,第二层15个神经元,第三层为最终输出层。
现在提供一个训练集X,一个标签Y,该如何通过代码的形式来表现呢?

#1导入工具包
import tensrflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense

#2创建三个层并让Tensorflow按照顺序将几个层串联起来
  model = Sequential([
    Dense(units = 25, activation = 'sigmoid')
    Dense(units = 15, activation = 'sigmoid')
    Dense(units = 1, activation = 'sigmoid')
                     ])
 #3引入工具包,并且让损失函数使用分类交叉熵的形式
from tensorflow.keras.losses import
BinaryCrossentropy
  model.compile(loss = BinaryCrossentropy())

#调用拟合函数,epoch代表训练次数
  model.fit(X, Y, epochs=100)

模型中的一些细节讲解

框架相关

让我们先复习一下之前的内容,如何实现逻辑回归的:
第一步,如何在给定输入特征X和参数W,b的情况下计算输出(定义模型),我们这里经常使用的是sigmoid函数。
第二步,指定损失函数与成本函数
第三步,训练模型,最小化J(w,b)
让我们在训练神经网络的背景下来看看这几步:

#1导入工具包
import tensrflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense

#2创建三个层并让Tensorflow按照顺序将几个层串联起来
  model = Sequential([
    Dense(units = 25, activation = 'sigmoid')
    Dense(units = 15, activation = 'sigmoid')
    Dense(units = 1, activation = 'sigmoid')
                     ])

这几段代码说明了神经网络的整个架构体系,告诉你第一层有25个神经元,第二层有15个神经元,第三层一个,采用的激活函数均为sigmoid。

损失函数相关

再写一遍 损失函数的一般数学表达式:
J ( W , B ) = 1 m ∑ L ( f ( x ( i ) , y ( i ) ) J(W,B) = \frac{1}{m}\sum L(f(x^{(i)},y^{(i)}) J(W,B)=m1L(fx(i),y(i))

 #3引入工具包,并且让损失函数使用分类交叉熵的形式
from tensorflow.keras.losses import
BinaryCrossentropy
  model.compile(loss = BinaryCrossentropy())

这个名叫keras的工具包其实是和tensorflow是完全不同的两个项目开发的,只是最后合入了tensorflow,所有它的工具包需要你单独import。另外,由于工具包的种类真的很多,所以不知道工具包的名字和使用方法时可以上网查找哦。
我们在之前的博客中,曾经学习过二元交叉熵(这是统计学上的叫法),二元的意思是说明这是个布尔值,要么为1要么为0.只是在之前的博客中不叫这个名字,而是为了能够在一个式子之中写出价代价函数:
L ( f ( x ) , y ) = − y l o g ( f ( x ) ) − ( 1 − y ) l o g ( ( 1 − f ( x ) ) L(f(x),y) = -ylog(f(x)) - (1-y)log((1-f(x)) L(f(x),y)=ylog(f(x))(1y)log((1f(x))
在制定了损失函数之后,Tensorflow就知道了你是希望最小化m个训练的平均值。
如果你是想解决其它类型的问题例如回归问题,你可以给tensorflow指定其它种类的损失函数:

from tensorflow.keras.losses import MeanSquareError
model.compile(loss = MeanSquareError())

这是最小化平方误差损失的损失函数。

梯度下降

梯度下降时,你需要重复公式:
w = w − α ∂ ∂ w j J ( w , b ) b = b − α ∂ ∂ b j J ( w , b ) w = w - \alpha\frac{\partial}{\partial w_j}J(w,b)\\ b = b - \alpha\frac{\partial}{\partial b_j}J(w,b) w=wαwjJ(w,b)b=bαbjJ(w,b)

#调用拟合函数,epoch代表训练次数
  model.fit(X, Y, epochs=100)

Tensorflow使用的是一种叫做反向传播的算法来计算这些偏导数项,只是在函数model.fit中完成的,并告诉它这样迭代100次。

很明显我们现在的代码严重依赖于Tensorflow库,随着技术的发展,大部分工程师都会使用库而非自己重头编起。现在你已经了解了如何自己训练一个神经网络了,在接下来的博客中我们讲讲到一些你可以改变的地方,使得你的神经网络更加强大。
为了给读者你造成不必要的麻烦,博主的所有视频都没开仅粉丝可见,如果想要阅读我的其他博客,可以点个小小的关注哦。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1469702.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

软件开发的艺术与科学

随着科技的飞速发展,软件开发已成为当今社会不可或缺的一部分。从智能手机应用程序到企业级管理系统,软件开发已经渗透到我们生活的方方面面。本文将探讨软件开发的重要性和现状,以及开发过程中涉及的关键环节和常见问题。 一、软件开发的重…

leetcode:491.递增子序列

1.误区:不能直接对数组排序再求解子集,因为那样就改变了原有数组的顺序 2.树形结构:一个一个取数,然后保证是递增序列,且不能重复。(数层上不可以重复取,树枝上可以重复取)收集的结…

Android BitmapDrawable.bitmap与BitmapFactory.decodeResource获取不到原始图像素级真实宽高,Kotlin

Android BitmapDrawable.bitmap与BitmapFactory.decodeResource获取不到原始图像素级真实宽高,Kotlin 当一个图片放在ImageView里面后,用以下方式获取图的宽高: val bmp1 (this.drawable as BitmapDrawable).bitmapLog.d("fly", &…

SpringBoot实现缓存预热方案

缓存预热是指在 Spring Boot 项目启动时,预先将数据加载到缓存系统(如 Redis)中的一种机制。 那么问题来了,在 Spring Boot 项目启动之后,在什么时候?在哪里可以将数据加载到缓存系统呢? 实现方案概述 在 Spring Boot 启动之后,可以通过以下手段实现缓存预热: 使用…

蓝桥杯《修剪灌木》

题目描述 爱丽丝要完成一项修剪灌木的工作。有 N 棵灌木整齐的从左到右排成一排。爱丽丝在每天傍晚会修剪一棵灌木,让灌木的高度变为 0 厘米。爱丽丝修剪灌木的顺序是从最左侧的灌木开始,每天向右修剪一棵灌木。当修剪了最右侧的灌木后,她会…

#FPGA(基础知识)

1.IDE:Quartus II 2.设备:Cyclone II EP2C8Q208C8N 3.实验:正点原子-verilog基础知识 4.时序图: 5.步骤 6.代码:

Java 存图方式

图最常见的两种存储方式是邻接表和邻接矩阵。 链式前向星其实就是静态建立的邻接表,时间效率为 O(n),空间效率也为 O(n)。遍历效率也为 O(n)。 一、邻接表 邻接表存储方式适合存储边稀疏的图,判断两点之间是否有边不方便; 邻接矩阵适合存储边稠密的,判断边和权值都很方…

如何使用移动端设备在公网环境远程访问本地黑群晖

文章目录 前言本教程解决的问题是:按照本教程方法操作后,达到的效果是前排提醒: 1. 搭建群晖虚拟机1.1 下载黑群晖文件vmvare虚拟机安装包1.2 安装VMware虚拟机:1.3 解压黑群晖虚拟机文件1.4 虚拟机初始化1.5 没有搜索到黑群晖的解…

使用Django的admin功能管理数据_vscode

之前的文章 项目 hello_django, app名 hello,已有的model LogMessage: https://blog.csdn.net/weixin_44741835/article/details/136202771?spm1001.2014.3001.5502 参考得到电子书:第八章。 https://www.dedao.cn/ebook/reader?idrEQKv6…

Windows上基于名称快速定位文件和文件夹的免费工具Everything

在Windows上搜索文件时,使用windows上内置搜索会很慢,这里推荐使用Everything工具进行搜索。 "Everything"是Windows上一款搜索引擎,它能够基于文件名快速定位文件和文件夹位置。不像Windows内置搜索,"Everything&…

好用的伪原创工具有哪些?

伪原创工具哪个好用?在互联网时代,内容创作是一项至关重要的工作。然而,随着信息爆炸式增长,内容创作者们往往面临着时间和灵感的压力。为了解决这一难题,越来越多的人开始寻找伪原创工具,这些工具可以帮助…

32单片机基础:对射式红外传感器计次

接线如下图: 在HardWare建立两个文件:如图 COuntSensor.c 如何配置外部中断,根据下面图,我们需要把外部中断从GPIO到NVIC这一路出现的外设模块都配置好。把这条信号打通就OK了。 1.配置RCC:把我们这里涉及的外设时钟都打开,不打…

用什么软件制作电子杂志

想要制作高大上的电子杂志?别再烦恼啦!今天给大家推荐一款超级实用的软件,让你轻松制作出专业水准的电子杂志! 这款软件功能强大,操作简单,适合所有对设计感兴趣的小伙伴们。无论是新手还是专业设计师&…

20.scala视图界定

目录 概述实践代码执行 结束 概述 scala 中的视图界定 实践 代码 /*** 视图界定*/ object Genericity03 {def main(args: Array[String]): Unit {println(new MaxInt(1,2).compare)println(new MaxLong(1L,2L).compare)// 不行 // println(new MaxValue(1,2).compare)// …

[c++] 深拷贝和浅拷贝,拷贝构造、赋值运算符

1 拷贝构造和赋值运算符 1.1 拷贝构造 拷贝构造在如下场景会被调用: (1)函数调用时,函数参数是对象的值传递 (2)声明对象同时初始化的时候(而不是声明和初始化分开,因为声明的时候就创建了对…

游戏配置内存“瘦身”策略

背景 游戏配置数据绝对是游戏服务器进程的内存大头,有些游戏服务器单纯数据配置的容量就超过一个G。因此,这部分内存优化也就放在首要位置了。 优化策略 在《服务器进程如何降低内存》一文中,我们讲述了可以通过“优化游戏配置缓存”来降低游戏服务器进程的内存使用量。本…

【电子通识】认识FMEA(失效模式和影响分析)

FMEA是Failure Mode and Effect Analysis的英文缩写,中文名称为失效模式和影响分析。主要应用于航空航天、食品、汽车和核电等行业。 FMEA讨论的是事先策划以及执行措施,预防问题的发生或控制问题的发展,降低设计和过程的风险。由于问题还没…

C语言------操作符的巧妙使用

1.计算一个数字二进制补码里面1的个数 (1)方法一 根据这个10进制的整数,对这个数进行%10,/10不断地进行下去, %10得到最后一位,/10得到舍去最后一位之后剩余的数; 同理得到:二进…

深入理解 CSS 定位与布局高级技巧

更多web开发知识欢迎访问我的专栏>>> CSS高级 目标:掌握定位的作用及特点;掌握 CSS 高级技巧 01-定位 作用:灵活的改变盒子在网页中的位置 实现: 1.定位模式:position 2.边偏移:设置盒子的位…

构造百万测试数据五大方法!

在测试的工作过程中,很多场景是需要构造一些数据在项目里的,方便测试工作的进行。比如下面的场景: 项目需要做性能测试,需要大量的数据就算是功能测试,比如测试搜索功能,需要有数据做搜索测试需要检查数据…