tensorflow06——正则化缓解过拟合

news2024/10/6 1:46:16

正则化主要是在损失函数中引入了第二个部分,模型复杂度,具体就是对w参数赋予了权值,并求和,再乘上一个超参数。
(利用给w加上权值,弱化训练数据的噪声)

	大概可以理解为这个意思
	假设模型有两个参数矩阵——w1,w2
	使用L2正则化		  
	
	loss = loss_mse + 超参数*loss_regularization
	# 其中loss_regularization就是对两个参数矩阵正则化后求和
	# loss_regularization=tf.reduce_sum[(tf.nn.l2_loss(w1),tf.nn.l2_loss(w2)]


01本文先介绍了训练集情况
02介绍了几个预备知识点
03进行实例展示(正则化之前,和正则化之后)

训练集dot.csv数据大概如下,可以随机生成
【包含特征值x1,x2和对应的标签y_c,共300组】
在这里插入图片描述

预备知识点

【01】 np.ravel()和np.flatten()

两者的功能是一致的,将多维数组降为一维
两者的区别是返回拷贝还是返回视图
np.flatten()返回一份拷贝,对拷贝所做修改不会影响原始矩阵
np.ravel()返回的是视图,修改时会影响原始矩阵(后面实例使用这个)

【02】np.r_[a,b]和np.c_[a,b]

np.r_是按列连接两个矩阵,要求列数相等。
np.c_是按行连接两个矩阵,要求行数相等。(后面实例使用这个)
(记忆:c是行!)

import numpy as np

a = np.array([[1,2,3],[7,8,9]])
b = np.array([[4,5,6],[1,2,3]])
     
print("a:\n",a) 
print("b:\n",b) 

c=np.c_[a,b]
d=np.r_[a,b]

print("c:\n",c)
print("d:\n",d)    

输出:
a:
[[1 2 3]
[7 8 9]]
b:
[[4 5 6]
[1 2 3]]
c:
[[1 2 3 4 5 6]
[7 8 9 1 2 3]]
d:
[[1 2 3]
[7 8 9]
[4 5 6]
[1 2 3]]

【03】np.mgrid[a: b :c]

返回多维结构,常见的如2D图形,3D图形

【04】np.squeeze()

squeeze v.挤压
np.squeeze()函数可以删除数组形状中的单维度条目,就是把数组shape中等于1的维度去掉,但是对非1的维度不起作用。

#例3
>>> d  = np.arange(10).reshape(1,2,5)
>>> d
array([[[0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9]]])
        
>>> d.shape
(1, 2, 5)

>>> np.squeeze(d)
array([[0, 1, 2, 3, 4],
       [5, 6, 7, 8, 9]])

>>> d.shape
(2, 5)

【05】plt.contour(X, Y, Z, [levels], 其他参数)

这个函数用于绘制等高线,也可以说是分界线

plt就是matplotlib.pyplot

X, Y表示的是坐标位置(这里是可选的,但是如果不传入的话就是python根据传入的高度数组(Z)的大小自动生成的坐标),一般很多会使用二维数组,但是实际上一维数组也可以的

Z代表每个坐标对应的高度值,是一个二维数组。其中每个值表示的是每个坐标对应的高度。就相当于函数值y。对应x1和x2两个参数。

levels有两种传入形式。
一种是传入一个整数,这个整数表示你想绘制的等高线的条数,但是显示结果可能并不是完全和传入的整数的条数一样,是大致差不多的条数(可能相差一两条)(为什么是大致条数呢?可能是python帮你默认生成的比较合适的几条等高线吧)。
另一种方式是传入一个包含高度值Z的一维数组,这样python便会画出传入的高度值对应的等高线。(下面代码使用的是这个方式)


实例展示——正则化以前

实例内容是:一个坐标点二分类问题。

训练集是上述表格文件,根据已知的数据x1,x2作为横纵坐标点训练
对应的标签1类标记为红色,0类标记为蓝色

测试集数据直接用二维坐标系上正负三范围内的坐标点
在这个区间内标记处红色蓝色的点,最终勾勒呈现出两种颜色的分界线

import tensorflow as tf
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

# 读入dot.csv文件的数据 300个
# x_date是两个特征值300*2
# y_date是特征值对应的标签300*1
df = pd.read_csv('/Users/sgfile/Downloads/BaiduNetdisk/人工智能实践:Tensorflow笔记20221224/class2(1)/class2/dot.csv')
x_date = np.array(df[['x1', 'x2']])
y_date = np.array(df['y_c'])

x_train = np.vstack(x_date).reshape(-1, 2)
# 参数-1表示自动匹配,也就是我只需要指定列数2,行数自动计算
# np.vstack():在竖直方向上堆叠
# np.hstack():在水平方向上堆叠
y_train = np.vstack(y_date).reshape(-1, 1)

# 给标签值标记颜色 如果标签为1则是红色,0为蓝色
Y_c = [['red' if y else 'blue'] for y in y_train]

# 转换x数据的类型,方便后面矩阵相乘
x_train = tf.cast(x_train, dtype=tf.float32)
y_train = tf.cast(y_train, dtype=tf.float32)

# 切片数据,将特征与标签配对,并打包,生成训练集
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)

# 构建神经网络
# 输入层2个神经元——————隐藏层11个神经元————输出层1个神经元
#          w1=[2*11] b1=11    w2=[11*1],b2=1
#

w1 = tf.Variable(tf.random.normal([2, 11], dtype=tf.float32))
b1 = tf.Variable(tf.constant(0.01, shape=[11]))

w2 = tf.Variable(tf.random.normal([11, 1], dtype=tf.float32))
b2 = tf.Variable(tf.constant(0.01, shape=[1]))

lr = 0.005
epoch = 800

for epoch in range(epoch):
    for step, (x_train, y_train) in enumerate(train_db):
        # 记录梯度信息
        with tf.GradientTape() as tape:
            # 记录神经网络乘加运算第一层
            h1 = tf.matmul(x_train, w1) + b1
            # 激活函数relu
            h1 = tf.nn.relu(h1)
            # 记录神经网络乘加运算第二层
            y = tf.matmul(h1, w2) + b2
            # 采用均方误差损失函数mse
            loss_mse = tf.reduce_mean(tf.square(y_train - y))
					
        # 梯度更新四个参数

        # step1——计算梯度信息
        grads = tape.gradient(loss, [w1, b1, w2, b2])

        # step2——实现梯度更新参数
        w1.assign_sub(lr * grads[0])
        b1.assign_sub(lr * grads[1])
        w2.assign_sub(lr * grads[2])
        b2.assign_sub(lr * grads[3])

	# 没有使用正则化,loss只有一个部分
	loss=loss_mse
		
    # 每20轮,打印loss信息
    if epoch % 20 == 0:
        print("epoch:", epoch, "loss:", float(loss))

# ------------------------至此,神经网络四个参数训练完成

输出结果:
epoch: 0 loss: 2.4265878200531006
epoch: 20 loss: 0.43145981431007385
epoch: 40 loss: 0.16190886497497559
epoch: 60 loss: 0.08229479938745499
epoch: 80 loss: 0.05829337239265442
epoch: 100 loss: 0.048353154212236404
epoch: 120 loss: 0.040947575122117996
epoch: 140 loss: 0.03630096837878227
epoch: 160 loss: 0.03280305117368698
epoch: 180 loss: 0.030033906921744347
epoch: 200 loss: 0.02776075340807438
epoch: 220 loss: 0.026936165988445282
epoch: 240 loss: 0.02629946731030941
epoch: 260 loss: 0.025844631716609
epoch: 280 loss: 0.02554425597190857
epoch: 300 loss: 0.025302037596702576
epoch: 320 loss: 0.02517053484916687
epoch: 340 loss: 0.02507457695901394
epoch: 360 loss: 0.024986503645777702
epoch: 380 loss: 0.02490798570215702
epoch: 400 loss: 0.024857262149453163
epoch: 420 loss: 0.024823719635605812
epoch: 440 loss: 0.02479807287454605
epoch: 460 loss: 0.0247130636125803
epoch: 480 loss: 0.024589741602540016
epoch: 500 loss: 0.024555539712309837
epoch: 520 loss: 0.02451319247484207
epoch: 540 loss: 0.02456258237361908
epoch: 560 loss: 0.02463117428123951
epoch: 580 loss: 0.024696240201592445
epoch: 600 loss: 0.024768143892288208
epoch: 620 loss: 0.024842776358127594
epoch: 640 loss: 0.02491491474211216
epoch: 660 loss: 0.024984782561659813
epoch: 680 loss: 0.02505091391503811
epoch: 700 loss: 0.025109589099884033
epoch: 720 loss: 0.025105630978941917
epoch: 740 loss: 0.025101713836193085
epoch: 760 loss: 0.0251059141010046
epoch: 780 loss: 0.025111591443419456

# 预测部分
print("**************[predict]****************")

# 生成网格坐标点,规格是:正负3的坐标系,0.1*0.1的网格
xx, yy = np.mgrid[-3:3:0.1, -3:3:0.1]
# 将xx,yy拉直,合并配对为二维张量grid,表示二维坐标点
grid = np.c_[xx.ravel(), yy.ravel()]
grid = tf.cast(grid, tf.float32)



# 把网格的坐标当作数据集送入训练好的神经网络,进行预测,输出存入列表probs
probs = []
for x_test in grid:
    # 使用训练好的参数进行预测
    h1 = tf.matmul([x_test], w1)+b1
    h1 = tf.nn.relu(h1)
    y = tf.matmul(h1, w2)+b2
    probs.append(y)

# 绘制散点图
# 取x_data的第0列给x1,第1列给x2
x1 = x_date[:, 0]
x2 = x_date[:, 1]

# probs的形状调整为和xx一致
probs =  np.array(probs).reshape(xx.shape)
plt.scatter(x1,x2,color=np.squeeze(Y_c))
# squeeze去掉纬度是1的纬度
# 相当于去掉[['red'],[''blue]],内层括号变为['red','blue']


# 把坐标xx yy和对应的值probs放入plt.contour()函数————用于画等高线的(也可以看做边界线)
# 给probs值为0.5的所有点上色——也就是上面那条线
# plt点show后 显示的是红蓝点的分界线
plt.contour(xx,yy,probs,levels=[0.5]) #预测值y为0.5
plt.show()

输出:
在这里插入图片描述可以看到分割线轮廓不够平滑,存在过拟合现象


实例展示——正则化以后

仅需在以上代码的with结构中加入正则化

以下是更新后的with结构,对参数w1,w2使用了L2正则化

        with tf.GradientTape() as tape:
            # 记录神经网络乘加运算第一层
            h1 = tf.matmul(x_train, w1) + b1
            # 激活函数relu
            h1 = tf.nn.relu(h1)
            # 记录神经网络乘加运算第二层
            y = tf.matmul(h1, w2) + b2
            # 采用均方误差损失函数mse
            loss_mse = tf.reduce_mean(tf.square(y_train - y))

			# 添加L2正则化————tf.nn.l2_loss(w)=sum(w**2)/2
			loss_regularization=[]
			loss_regularization.append(tf.nn.l2_loss(w1))
			loss_regularization.append(tf.nn.l2_loss(w2))
			# 求和
			loss_regularization=tf.reduce_sum(loss_regularization)
			  
			# 经过正则化后的loss包含两部分
			# 一个是衡量预测值与标准答案差距的均方误差loss_mse
			# 一个是表示各个参数权重和的正则化loss_regularization*超参数
			loss = loss_mse+0.03*loss_regularization

在这里插入图片描述可以见到,加入L2正则化后的曲线更平缓,缓解了过拟合现象

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/126843.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

从0-1搭建流媒体系统之ZLMediaKit 安装、运行、推流、拉流

音视频开发系列 文章目录音视频开发系列前言一、ZLMediaKit是什么?二、使用过程1.编译、安装、运行2.推流、拉流总结前言 目前、比较有名的流媒体服务器有ZLMediaKit、srs、live555、eadydarwin等。因为srs是单线程服务、对于多核服务器的支持需要通过部署多个服务…

蓝牙学习七(MAC地址)

1.简介 一个BLE设备,可以使用两种类型的地址(一个BLE设备可以同时具备两种地址):Public Device Address(公共设备地址)和Random Device Address(随机设备地址)。而Random Device Add…

如何用 java 实现【二叉搜索树】

文章目录搜索树概念1. 查找操作2. 插入操作3. 删除操作4. 以上三种操作的测试5. 性能分析搜索树概念 二叉搜索树 又称 二叉排序树,它或者是一棵空树,或者是具有以下性质的二叉树: 若它的 左 子树 不为空,则 左 子树上所有节点的值…

自定义神经网络入门-----Pytorch

文章目录目标检测的相关评价指标IoUmAP正例和负例准确率P召回率R准确率ACCP-R曲线--APnn.Module类全连接层感知机类使用nn.Sequential进行构造使用randn函数进行简单测试损失函数nn.functionalnn.optim模型处理网络模型库torchvision.models模型Fine-tune和save参考目标检测的相…

【STM32F4系列】【HAL库】【自制库】模拟IIC从机

介绍 本项目是利用GPIO模拟I2C的从机 网上常见的是模拟I2C主机 本项目是作为一个两个单片机之间低速通信的用法 协议介绍请看,传送门 模拟主机请看这里 从机 功能 实现I2C从机端读写寄存器 编程思路 I2C的从机实现比起主机来麻烦一些 因为SCL的时序是由主机发送,从机需…

【nowcoder】笔试强训Day12

目录 一、选择题 二、编程题 2.1二进制插入 2.2 查找组成一个偶数最接近的两个素数 一、选择题 1.以下方法,哪个不是对add方法的重载? public class Test {public void add( int x,int y,int z){} } A. public int add(int x,int y,float z){return 0;} B.…

Go语言设计与实现 -- WaitGroup, Once, Cond

WaitGroup 我们可以通过 sync.WaitGroup 将原本顺序执行的代码在多个 Goroutine 中并发执行,加快程序处理的速度。 我们来看一下sync.WaitGroup的结构体: type WaitGroup struct {//保证WaitGroup不会被开发者通过再赋值的方式复制noCopy noCopy// 64-…

重学redux之Redux-Thunk高级使用(三)

这是第三篇了,哥们,如果没看过前两篇,可以去看看之前的两篇,有基础的可以直接看,不多说,直接开讲 默认情况下,Redux 的动作是同步调度的,对于任何需要与外部 API 通信或执行副作用的应用程序来说都是一个问题。 Redux 允许中间件位于被分派的动作和到达 reducer 的动…

抖音本地生活的蓬勃发展,离不开服务商的推波助澜

抖音本地生活,已经势不可挡01 抖音公布本地生活成绩单,交易额增长30倍抖音经过6年时间的演变,产品功能日益丰富,已经从内容消费,延续到线上购物、线下团购等领域,从最初的记录美好生活,成为一种…

统计分析工具-FineReport配置SQL Server外接数据库(2)

1. 配置外接数据库 1.1 外接数据库配置入口 外接数据库的配置入口,有三种形式: 1)超级管理员第一次登录数据决策系统时,即可为系统配置外接数据库。如下图所示: 2)对于使用内置数据库的系统,管…

站点能源低碳目标网,助力网络碳中和 | 华为发布站点能源十大趋势

2022年12月29日,华为今天举办站点能源十大趋势发布会并重磅发布白皮书。发布会上,华为站点能源领域总裁尧权全面解读了能源数字化、低碳网络、站点供电绿色化等站点能源十大趋势。 尧权表示,2022年是不平凡的一年,全球能源危机背…

十、通过网络服务将esp8266引脚状态显示在网页中

ESP8266在服务器模式运行时,我们可以使用浏览器来显示它的引脚状态。 1、实现目标 学习如何通过esp8266建立基本网站,在该网站上实时显示esp8266的引脚值。 2、原理图 FLASH按键与D3引脚连接,可以通过FLASH按键改变D3引脚的电平。当没有按…

中型企业适合用什么样的CRM管理软件,求推荐?

中型企业适合用什么样的CRM管理软件,求推荐? CRM管理软件是现代企业必不可少的管理软件之一,很多企业都会选择CRM管理软件来经营客户资源,但能够精准地选择到适合自己企业的CRM管理软件则是困难的。 中型企业需要与自己业务流程…

数据可视化之finebi和tableau电力系统分析实现对比

通过一个电力系统简单案例,尝试实际执行finebi和Tableau数据可视化设计的各项基本步骤,以熟悉Tableau和finebi数据可视化设计技巧,提高大数据可视化应用能力。 一、工具/准备工作 在开始本实验之前,请认真阅读课程的相关内容。 …

写给小白的TensorFlow的入门课

文章目录前言学习AI的必要性和业务的关系最简单的例子要做什么?数据图形化展示构建计算图形计算图形最小化误差MacOS 中配置运行环境安装验证安装简单模型训练识别数字图片的模型训练Softmax Regression算法大概步骤大致算法实现结语参考链接前言 深度学习就是从大…

抖音电商发布2023年食品健康行业8大趋势,新减负、新养生等成为关键词

2022抖音电商食品健康峰会暨年货盛典在杭州成功举行。抖音电商食品健康行业还联合欧睿共同发布了《2023年度食品健康行业趋势洞察报告》。图片来源:抖音电商抖音电商食品健康行业负责人白华在会上透露,过去一年,抖音电商食品健康行业呈现出有…

虚拟机数据库改密码ERROR 1396 (HY000): Operation ALTER USER failed for ‘root‘@‘localhost‘

注:原因为MySql 8.0.11 换了新的身份验证插件(caching_sha2_password), 原来的身份验证插件为(mysql_native_password)。而客户端工具Navicat Premium12 中找不到新的身份验证插件(caching_sha2_password&a…

Java实现多线程

目录 基本概念 1、程序、进程、线程 2、使用线程的优点 3、线程的分类 4、线程的生命周期 多线程的实现方法 1、继承Thread类 2、实现Runnable接口 3、实现Callable接口 4、使用线程池 线程同步 1、同步代码、同步方法 2、同步机制中的锁 3、锁(Lock&…

【电商】电商后台---采购管理模块

从供应商的管理到合同的管理,再到商品系统的模块的介绍、商品价格与税率维护策略,不知不觉已经完成了几篇文章,前期的准备工作完成后,接下来就应该进入到采购管理模块了。 几天来一直在构思如何写,写的内容让大家看过觉…

使用天地图加载Geoserver的图层

一、写在前面 在项目中往往使用地图作为底图(比如 天地图卫星图等),再其上覆盖你的通过geoserver发布自定义图层。本文记录了我的实现方法。 二、过程 2.1 我遇到的难题 遇到难题1:使用无人机拍摄制作的正射影像图有几百MB甚至1个G,直接展示图…