【Tensorflow学习二】神经网络优化方法学习率、激活函数、损失函数、正则化

news2024/9/24 5:33:28

文章目录

  • 预备知识
    • tf.where
    • tf.random.RandomState.rand() 返回一个[0,1)之间的随机数
    • np.vstack() 将数组按照垂直方向叠加
    • np.mgrid[ ] np.ravel( ) np.c_[ ] 一起使用可以生成网格坐标点
  • 复杂度、学习率
    • 复杂度
    • 指数衰减学习率
  • 激活函数
    • Sigmoid激活函数
    • Tanh激活函数
    • ReLu激活函数
    • Leaky ReLu激活函数
  • 损失函数
    • 均方误差
    • 交叉熵损失函数
    • 自定义损失函数
    • Softmax与交叉熵结合
  • 欠拟合与过拟合
    • 正则化缓解过拟合

预备知识

tf.where

#条件语句真返回A,条件语句假返回B
#tf.where(条件语句,真返回A,假返回B)

import tensorflow as tf

a=tf.constant([1,2,3,1,1])
b=tf.constant([0,1,3,4,5])
c=tf.where(tf.greater(a,b),a,b)#若a>b,返回a对应位置的元素,否则返回b对应位置的元素
print("c:",c)

>>>
c: tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)

tf.random.RandomState.rand() 返回一个[0,1)之间的随机数

#返回一个[0,1)之间的随机数
#np.random.RandomState.rand(维度)#维度为空,返回标量
import numpy as np
rdm=np.random.RandomState(seed=1)#seed=常数,每次生成的随机数相同
a=rdm.rand()#返回一个随机标量
b=rdm.rand(2,3)#返回维度为2行3列的随机矩阵

print("a:",a)
print("b:",b)
>>>
a: 0.417022004702574
b: [[7.20324493e-01 1.14374817e-04 3.02332573e-01]
 [1.46755891e-01 9.23385948e-02 1.86260211e-01]]

np.vstack() 将数组按照垂直方向叠加

import numpy as np
a=np.array([1,2,3])
b=np.array([4,5,6])
c=np.vstack((a,b))

print("c:\n",c)
>>>
c:
 [[1 2 3]
 [4 5 6]]

np.mgrid[ ] np.ravel( ) np.c_[ ] 一起使用可以生成网格坐标点

#np.mgrid[] np.ravel() np.c_[]一起使用可以生成网格坐标点
# np.mgrid[起始值:结束值:步长,起始值:步长,......]
# x.ravel()将x变为一维数组,“把.向量拉直”
# np.c_[数组1,数组2,...]
import numpy as np

x,y=np.mgrid[1:3:1,2:4:0.5]
grid=np.c_[x.ravel(),y.ravel()]
print("x:",x)
print("y:",y)
print("x.ravel():\n", x.ravel())
print("y.ravel():\n", y.ravel())
print("grid:\n",grid)

>>>
x: [[1. 1. 1. 1.]
 [2. 2. 2. 2.]]
y: [[2.  2.5 3.  3.5]
 [2.  2.5 3.  3.5]]
x.ravel():
 [1. 1. 1. 1. 2. 2. 2. 2.]
y.ravel():
 [2.  2.5 3.  3.5 2.  2.5 3.  3.5]
grid:
 [[1.  2. ]
 [1.  2.5]
 [1.  3. ]
 [1.  3.5]
 [2.  2. ]
 [2.  2.5]
 [2.  3. ]
 [2.  3.5]]

复杂度、学习率

复杂度

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-1yMGPdwh-1670312759174)(C:\Users\98306\AppData\Roaming\Typora\typora-user-images\image-20221204210242419.png)]

指数衰减学习率

可以先用较大的学习率,快速得到最优解,然后逐步减小学习率,使得模型在训练后期稳定
指 数 衰 减 学 习 率 = 初 试 学 习 率 ∗ 学 习 率 衰 减 率 当 前 轮 数 / 多 少 轮 衰 减 一 次 指数衰减学习率=初试学习率*学习率衰减率^{当前轮数/多少轮衰减一次} =/
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-T6BsV8Yk-1670312759176)(C:\Users\98306\AppData\Roaming\Typora\typora-user-images\image-20221204213250545.png)]

激活函数

Sigmoid激活函数

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-G30KRZP3-1670312759176)(C:\Users\98306\AppData\Roaming\Typora\typora-user-images\image-20221204213531093.png)]

特点:

(1)容易造成梯度消失
(2)输出非0均值,收敛慢(我们希望输入每层神经网络的特征是以0为均值的小数值)
(3)幂运算复杂,训练时间长

Tanh激活函数

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-78o59Ng1-1670312759177)(C:\Users\98306\AppData\Roaming\Typora\typora-user-images\image-20221204215210928.png)]

特点:

(1)输出是0均值(由于sigmoid的地方)
(2)易造成梯度消失
(3)幂运算复杂,训练时间长

ReLu激活函数

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ULgKb6c2-1670312759177)(C:\Users\98306\AppData\Roaming\Typora\typora-user-images\image-20221204220112455.png)]

优点:

(1)解决了梯度消失问题(在正区间)
(2)只需要判断输入是否大于0,计算速度快
(3)收敛速度远快于Sigmoid和Tanh

缺点:

(1)输出非0均值,收敛慢
(2)Dead ReLu问题:某些神经元可能永远无法被激活,导致相应的参数无法被更新(可以改变随机初始化,避免过多的负数特征送入relu函数;可以设置更小的学习率,减小参数分布的巨大变化,避免训练中产生过多负数特征)

Leaky ReLu激活函数

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ZN01M9Tg-1670312759177)(C:\Users\98306\AppData\Roaming\Typora\typora-user-images\image-20221204221053650.png)]

为了解决ReLu负区间为0,引起神经元死亡问题而设计的

Leaky ReLu在负区间引入了一个固定的斜率a,使得Leaky ReLu负区间不恒等于0

理论上来讲,Leaky ReLuReLu的所有优点,外加不会有Dead Relu问题,但是在实际操作当中,并没有完全证明Leaky ReLu总是好于ReLu

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-vLvXikD2-1670312759178)(C:\Users\98306\AppData\Roaming\Typora\typora-user-images\image-20221204221549300.png)]

损失函数

标签:y_

预测:y

均方误差

均 方 误 差 m s e : M S E ( y ^ , y ) = ∑ i = 1 n ( y − y ^ ) 2 n 均方误差mse:MSE(\hat{y},y)=\frac{\sum_{i=1}^{n}(y-\hat{y})^2}{n} mseMSE(y^,y)=ni=1n(yy^)2

loss_mse=tf.reduce_mean(tf.square(y_-y))

交叉熵损失函数

H = − ∑ y l o g y ^ H=-\sum{}ylog\hat{y} H=ylogy^

tf.losses.categorical_crossentropy(y_-y)

import tensorflow as tf

loss_ce1 = tf.losses.categorical_crossentropy([1, 0], [0.6, 0.4])
loss_ce2 = tf.losses.categorical_crossentropy([1, 0], [0.8, 0.2])
print("loss_ce1:", loss_ce1)
print("loss_ce2:", loss_ce2)
>>>
loss_ce1: tf.Tensor(0.5108256, shape=(), dtype=float32)
loss_ce2: tf.Tensor(0.22314353, shape=(), dtype=float32)

自定义损失函数

可以用tf.where构建损失函数

Softmax与交叉熵结合

分类问题中,输出先经过Softmax函数,再计算yy_的交叉熵损失函数

Tensorflow提供一个函数,将两者结合

tf.nn.softmax_cross_entropy_with_logits(y_,y)

import tensorflow as tf
import numpy as np

y_ = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0]])
y = np.array([[12, 3, 2], [3, 10, 1], [1, 2, 5], [4, 6.5, 1.2], [3, 6, 1]])
y_pro = tf.nn.softmax(y)
loss_ce1 = tf.losses.categorical_crossentropy(y_,y_pro)
loss_ce2 = tf.nn.softmax_cross_entropy_with_logits(y_, y)

print('分步计算的结果:\n', loss_ce1)
print('结合计算的结果:\n', loss_ce2)
>>>
分步计算的结果:
 tf.Tensor(
[1.68795487e-04 1.03475622e-03 6.58839038e-02 2.58349207e+00
 5.49852354e-02], shape=(5,), dtype=float64)
结合计算的结果:
 tf.Tensor(
[1.68795487e-04 1.03475622e-03 6.58839038e-02 2.58349207e+00
 5.49852354e-02], shape=(5,), dtype=float64)

欠拟合与过拟合

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-vonGk6xd-1670312759178)(C:\Users\98306\AppData\Roaming\Typora\typora-user-images\image-20221205103952286.png)]

欠拟合的解决方法:

增加输入特征项
增加网络参数
减少正则化参数

过拟合的解决方法:

数据清洗
增大训练集
采用正则化
增大正则化参数

正则化缓解过拟合

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-3lJzzHQS-1670312759179)(C:\Users\98306\AppData\Roaming\Typora\typora-user-images\image-20221205105333057.png)]

tf.nn.l2_loss(w)

过拟合情况:

正则化之后:

曲线变得更为平缓,有效缓解了过拟合

下面为L2正则化的例子:

with tf.GradientTape() as tape:  # 记录梯度信息

    h1 = tf.matmul(x_train, w1) + b1  # 记录神经网络乘加运算
    h1 = tf.nn.relu(h1)
    y = tf.matmul(h1, w2) + b2

    # 采用均方误差损失函数mse = mean(sum(y-out)^2)
    loss_mse = tf.reduce_mean(tf.square(y_train - y))
    # 添加l2正则化
    loss_regularization = []
    # tf.nn.l2_loss(w)=sum(w ** 2) / 2
    loss_regularization.append(tf.nn.l2_loss(w1))
    loss_regularization.append(tf.nn.l2_loss(w2))
    # 求和
    # 例:x=tf.constant(([1,1,1],[1,1,1]))
    #   tf.reduce_sum(x)
    # >>>6
    loss_regularization = tf.reduce_sum(loss_regularization)
    loss = loss_mse + 0.03 * loss_regularization  # REGULARIZER = 0.03

# 计算loss对各个参数的梯度
variables = [w1, b1, w2, b2]
grads = tape.gradient(loss, variables)

# 实现梯度更新
# w1 = w1 - lr * w1_grad
w1.assign_sub(lr * grads[0])
b1.assign_sub(lr * grads[1])
w2.assign_sub(lr * grads[2])
b2.assign_sub(lr * grads[3])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/67168.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

今天面了个00后测试员,让我见识到了内卷届的天花板

深耕IT行业多年,我们发现,对于一个程序员而言,能去到一线互联网公司,会给我们以后的发展带来多大的影响。 很多人想说,这个我也知道,但是进大厂实在是太难了,简历投出去基本石沉大海&#xff0…

Linux安装KVM

一、虚拟化技术 1、全虚拟化和半虚拟化技术 如果给KVM、XEN简单归类的话,KVM是完全虚拟化技术又叫硬件辅助虚拟化技术(Full Virtualization)。相反,XEN是半虚拟化技术(paravirtualization),也叫做准虚拟化…

线上环境内存溢出-OutOfMemoryError

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 线上环境内存溢出-OutofMemoryError前言一、OutOfMemoryError是什么?二、实际情况(参考)解决方案1.实战总结前言 公司线上环境&#xff0…

getBoundingClientRect属性研究

getBoundingClientRect属性研究 概念 getBoundingClientRect 返回 width、height和下图中的6个属性 实测总结: 抓住一个核心点,就是height、width的值: box-sizing 是 content-box时,width和height 内容borderpaddingbox-siz…

国家级专新特精“小巨人”「皖仪科技」携手企企通,打造采购数字化平台成功上线

近日,安徽皖仪科技股份有限公司(以下简称“皖仪科技”)携手企企通共同打造的数字化采购管理系统成功上线。基于皖仪科技的采购业务流程和规则,形成全新的数字化采购体系,在推动企业降本增效的同时,实现企业…

单商户商城系统功能拆解42—应用中心—商城公告

单商户商城系统,也称为B2C自营电商模式单店商城系统。可以快速帮助个人、机构和企业搭建自己的私域交易线上商城。 单商户商城系统完美契合私域流量变现闭环交易使用。通常拥有丰富的营销玩法,例如拼团,秒杀,砍价,包邮…

冠状病毒疾病优化算法 (COVIDOA)附matlab代码

​✅作者简介:热爱科研的Matlab仿真开发者,修心和技术同步精进,matlab项目合作可私信。 🍎个人主页:Matlab科研工作室 🍊个人信条:格物致知。 更多Matlab仿真内容点击👇 智能优化算法…

npp各个平台npp数据比较

文章目录1GEE中的npp数据2 与其他数据的比较1GEE中的npp数据 在GEE上查阅npp,可以看到有连个数据集,一个是Terra的,另一个是Aqua的。 我比较了两个的不同,发现Terra是2000-目前的,而Aqua是2002-目前的,都是…

2022吴恩达机器学习课程——第一课

注:参考B站视频教程 视频链接:【(强推|双字)2022吴恩达机器学习Deeplearning.ai课程】 文章目录第一周一、监督学习与无监督学习二、线性回归三、梯度下降第二周一、向量化二、特征缩放第三周一、逻辑回归二、训练逻辑回归模型三、逻辑回归中的梯度下降四…

[运维]如何快速压缩一个数据库的硬盘占用大小(简单粗暴但有效)

文章目录前言一、数据库文件为什么会那么大?1.数据空间2.日志空间3.索引空间4.其他二、我的解决方案总结前言 在维护网站时我们经常会遇到数据库占用服务器磁盘空间的问题。高端的食材往往只需要采用最朴素的烹饪方式。本文我讲一个简单粗暴但有效的方法。本文以Sq…

RabbitMQ快速上手以及RabbitMQ交换机的四种模式

Win10安装&#xff1a; ​win10下安装 RabbitMQ​_柚几哥哥的博客-CSDN博客 Linux安装&#xff1a; Linux下载安装 RabbitMQ​_柚几哥哥的博客-CSDN博客 一、基础使用 1、导入依赖 <!--RabbitMQ--><dependency><groupId>org.springframework.boot</g…

JAVA12_06学习总结(JDBC,工具类优化)

今日内容 1. PreparedStatement PreparedStatement--预编译步骤1)注册驱动2)获取数据库连接对象3)准备sql语句--不需要拼接--需要的参数全部使用 ? 占位符4)通过数据库连接对象,获取预编译对象,同时将sql语句房费数据库,将参数和参数类型都存储在预编译中Connection中的方法…

均匀传输线的串扰和饱和长度

下图为串扰的电路模型&#xff0c;动态线与静态线之间通过互容与互感联系&#xff0c;这样也说明了动态线的信号耦合到静态线上的条件是存在di/dt或者dv/dt时&#xff0c;也就是说只在信号边沿上产生串扰&#xff0c;当电压或者电流为常数的时候静态线上就不会有串扰的信号。 信…

扩散模型:Diffusion models as plug-and-play priors作为即插即用先验的扩散模型

扩散模型&#xff1a;Diffusion models as plug-and-play priors作为即插即用先验的扩散模型0.摘要1.概述2.方法2.1.问题设置2.2.将去噪扩散概率模型作为先验3.实验&#xff1a;图像生成3.1.MNIST的简单说明3.2.使用现成组件条件生成脸部图像4.实验&#xff1a;语义分割附录B&a…

Ubuntu 20.04 系统最快安装WRF软件手册

前言 天气研究和预报&#xff08;WRF&#xff09;模型是一种中尺度数值天气预报系统&#xff0c;在全球范围内用于业务预报和研究目的。 这是在基于Intel的i7&#xff08;12核&#xff09;Linux Ubuntu 20.04 LTS系统上安装WRF 4.2.1的版本。这将有助于初学者在普通台式机上实现…

树莓派4b+mcp2515实现CAN总线通讯和系统编程(一.配置树莓派CAN总线接口)

文章目录前言硬件连线树莓派环境准备启用树莓派ssh启用mcp2515驱动下载can-utils工具测试CAN通讯开启CAN网卡测试发送和接收前言 树莓派本身是没有CAN通讯能力的&#xff0c;但他有mcp2515模块的驱动&#xff0c;可以通过SPI来控制mcp2515进行CAN的通讯。 本章主要讲,如何使能…

基于卡尔曼滤波的二维目标跟踪(Matlab代码实现)

&#x1f352;&#x1f352;&#x1f352;欢迎关注&#x1f308;&#x1f308;&#x1f308; &#x1f4dd;个人主页&#xff1a;我爱Matlab &#x1f44d;点赞➕评论➕收藏 养成习惯&#xff08;一键三连&#xff09;&#x1f33b;&#x1f33b;&#x1f33b; &#x1f34c;希…

双十二选哪个品牌led灯好一点?国产led灯这些品牌护眼好

现在绝大部分人造灯光都是使用led灯珠作为发光源了&#xff0c;所以led灯普遍的质量都比较好&#xff0c;也能护眼&#xff0c;特别是习惯晚上熬夜工作、学习、看书的人群&#xff0c;也都会选择led台灯来辅助照明&#xff0c;因为相比传统的家用室内顶灯&#xff0c;led护眼灯…

【小游戏】Unity游戏愤怒的足球(小鸟)

目录 1.弹弓逻辑 2.鸟的逻辑 3.GameManager主逻辑 文末有源工程地址 难度系数: ★★★★☆ 游戏玩法: 愤怒的足球,其实就是经典的愤怒的小鸟换图 项目简介: 功能完善,主要代码逻辑完整 本文内容: 记录一下这个工程,对内部代码逻辑没有深入了解有待以后发掘 1.弹弓逻…

workerman 聊天demo

1.demo下载 链接: https://pan.baidu.com/s/1MOqcDwvrZGgaYpZUBxxZiA 提取码: 2yqf 2.安装workerman 我这里使用的是tp5框架 下载官方压缩包解压到根目录 3.workerman 数据发送相关类 将worker目录放到项目extend文件夹中 4.启用workerman 登录服务器 linux启动方式&…