神经网络构成、优化、常用函数+激活函数

news2024/9/25 17:12:04

Iris分类

数据集介绍,共有数据150组,每组包括长宽等4个输入特征,同时给出输入特征对应的Iris类别,分别用0,1,2表示。

从sklearn包datasets读入数据集。

from sklearn import darasets
from pandas import DataFrame
import pandas as pd
x_data = datasets.load_iris().data #  输入特征
y_data = datasets.load_iris().target # 标签
x_data = DataFrame(x_data, columns=["花萼长度",'花萼宽度','花瓣长度','花瓣宽度'])
pd.set_option('display.unicode.east_asian_width',True) # 设置列名对齐
x_dara['类别'] = y_data # 新增一列

神经网络实现分类步骤。

1.准备数据:

数据集读入,数据集乱序,生成训练集和测试集,配成输入特征/标签对,每次读入一部分

2.搭建网络

定义神经网络中的所有可训练参数

3.参数优化

嵌套循环迭代,with结构更新参数,显示当前loss

4.测试效果

计算当前向前传播后的准确率,显示当前acc

5可视化acc/loss

# -*- coding: UTF-8 -*-
# 利用鸢尾花数据集,实现前向传播、反向传播,可视化loss曲线

# 导入所需模块
import tensorflow as tf
from sklearn import datasets
from matplotlib import pyplot as plt
import numpy as np

# 导入数据,分别为输入特征和标签
x_data = datasets.load_iris().data
y_data = datasets.load_iris().target

# 随机打乱数据(因为原始数据是顺序的,顺序不打乱会影响准确率)
# seed: 随机数种子,是一个整数,当设置之后,每次生成的随机数都一样(为方便教学,以保每位同学结果一致)
np.random.seed(116)  # 使用相同的seed,保证输入特征和标签一一对应
np.random.shuffle(x_data)
np.random.seed(116)
np.random.shuffle(y_data)
tf.random.set_seed(116)

# 将打乱后的数据集分割为训练集和测试集,训练集为前120行,测试集为后30行
x_train = x_data[:-30]
y_train = y_data[:-30]
x_test = x_data[-30:]
y_test = y_data[-30:]

# 转换x的数据类型,否则后面矩阵相乘时会因数据类型不一致报错
x_train = tf.cast(x_train, tf.float32)
x_test = tf.cast(x_test, tf.float32)

# from_tensor_slices函数使输入特征和标签值一一对应。(把数据集分批次,每个批次batch组数据)
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

# 生成神经网络的参数,4个输入特征故,输入层为4个输入节点;因为3分类,故输出层为3个神经元
# 用tf.Variable()标记参数可训练
# 使用seed使每次生成的随机数相同(方便教学,使大家结果都一致,在现实使用时不写seed)
w1 = tf.Variable(tf.random.truncated_normal([4, 3], stddev=0.1, seed=1))
b1 = tf.Variable(tf.random.truncated_normal([3], stddev=0.1, seed=1))

lr = 0.1  # 学习率为0.1
train_loss_results = []  # 将每轮的loss记录在此列表中,为后续画loss曲线提供数据
test_acc = []  # 将每轮的acc记录在此列表中,为后续画acc曲线提供数据
epoch = 500  # 循环500轮
loss_all = 0  # 每轮分4个step,loss_all记录四个step生成的4个loss的和

# 训练部分
for epoch in range(epoch):  #数据集级别的循环,每个epoch循环一次数据集
    for step, (x_train, y_train) in enumerate(train_db):  #batch级别的循环 ,每个step循环一个batch
        with tf.GradientTape() as tape:  # with结构记录梯度信息
            y = tf.matmul(x_train, w1) + b1  # 神经网络乘加运算
            y = tf.nn.softmax(y)  # 使输出y符合概率分布(此操作后与独热码同量级,可相减求loss)
            y_ = tf.one_hot(y_train, depth=3)  # 将标签值转换为独热码格式,方便计算loss和accuracy
            loss = tf.reduce_mean(tf.square(y_ - y))  # 采用均方误差损失函数mse = mean(sum(y-out)^2)
            loss_all += loss.numpy()  # 将每个step计算出的loss累加,为后续求loss平均值提供数据,这样计算的loss更准确
        # 计算loss对各个参数的梯度
        grads = tape.gradient(loss, [w1, b1])

        # 实现梯度更新 w1 = w1 - lr * w1_grad    b = b - lr * b_grad
        w1.assign_sub(lr * grads[0])  # 参数w1自更新
        b1.assign_sub(lr * grads[1])  # 参数b自更新

    # 每个epoch,打印loss信息
    print("Epoch {}, loss: {}".format(epoch, loss_all/4))
    train_loss_results.append(loss_all / 4)  # 将4个step的loss求平均记录在此变量中
    loss_all = 0  # loss_all归零,为记录下一个epoch的loss做准备

    # 测试部分
    # total_correct为预测对的样本个数, total_number为测试的总样本数,将这两个变量都初始化为0
    total_correct, total_number = 0, 0
    for x_test, y_test in test_db:
        # 使用更新后的参数进行预测
        y = tf.matmul(x_test, w1) + b1
        y = tf.nn.softmax(y)
        pred = tf.argmax(y, axis=1)  # 返回y中最大值的索引,即预测的分类
        # 将pred转换为y_test的数据类型
        pred = tf.cast(pred, dtype=y_test.dtype)
        # 若分类正确,则correct=1,否则为0,将bool型的结果转换为int型
        correct = tf.cast(tf.equal(pred, y_test), dtype=tf.int32)
        # 将每个batch的correct数加起来
        correct = tf.reduce_sum(correct)
        # 将所有batch中的correct数加起来
        total_correct += int(correct)
        # total_number为测试的总样本数,也就是x_test的行数,shape[0]返回变量的行数
        total_number += x_test.shape[0]
    # 总的准确率等于total_correct/total_number
    acc = total_correct / total_number
    test_acc.append(acc)
    print("Test_acc:", acc)
    print("--------------------------")

# 绘制 loss 曲线
plt.title('Loss Function Curve')  # 图片标题
plt.xlabel('Epoch')  # x轴变量名称
plt.ylabel('Loss')  # y轴变量名称
plt.plot(train_loss_results, label="$Loss$")  # 逐点画出trian_loss_results值并连线,连线图标是Loss
plt.legend()  # 画出曲线图标
plt.show()  # 画出图像

# 绘制 Accuracy 曲线
plt.title('Acc Curve')  # 图片标题
plt.xlabel('Epoch')  # x轴变量名称
plt.ylabel('Acc')  # y轴变量名称
plt.plot(test_acc, label="$Accuracy$")  # 逐点画出test_acc值并连线,连线图标是Accuracy
plt.legend()
plt.show()

 根据MP模型可以看出,求出的y实际上就是计算出属于哪一种分类的概率

 

 其求出的loss就是概率和0/1相减,再将loss和w与b求偏导,通过公式运算得到w和b

预备函数

tf.where

a=tf.constant([1,2,3,1,1])

b=tf.constant([0,1,3,4,5])

c=tf.where(tf.greater(a,b),a,b)

2.np.random.RandomState.rand(维度)返回[0,1]的随机数

3.np.vstack() 将两个数组按垂直方向叠加

4np.mgridp[起始值:结束值:步长,....] [)

5.x.ravel() 将x变为一维数组

6.np.c_ [数组1,数组2]返回的间隔数值点配对

神经网络复杂度

指数衰减学习率

可以先用较大的学习率,快速得到较优解,然后逐步减小学习率,使模型在训练后期稳定

指数衰减学习率=初始学习率*学习率衰减率^(当前轮数/多少轮衰减一次)更新频率

epoch = 40
LR_BASE = 0.2
LR_DECAY = 0.99
LR_STEP = 1
for epoch in range(epoch):
    lr = LR_BASE*LR_DECAY**(epoch/LR_STEP)
    with tf.GradientTape() as tape:
        loss = tf.square(w + 1)
       grads = tape.gradient(loss, w)
       w.assign_sub(;lr*grads)

激活函数

1.Signmoid函数

特点:容易造成梯度消失,输出非0均值,收敛慢,幂运算复杂,训练时间长

2.Tanh函数

特点:输出是0均值,容易造成梯度消失,幂运算复杂,训练时间长

3.Relu函数

解决梯度消失问题正区间,容易造成神经元死亡,改变随机初始化,避免过多设置更小学习率,减少参数的巨大变化,避免训练中产生过多负数特征进入函数

Leaky Rely

1首选relu函数

2学习率设置较小值

3输入特征标准化,既让输入特征满足以0为均值,1为标准差的正态分布

4初始参数中心化,既让随机生成的参数满足以0为均值,sqrt(2/当前层输入特征个数)为标准差的正态分布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1914790.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Puppeteer 是什么以及如何在网络抓取中使用它 | 2024 完整指南

网页抓取已经成为任何处理网页数据提取的人都必须掌握的一项重要技能。无论你是开发者、数据科学家还是希望从网站收集信息的爱好者,Puppeteer都是你可以使用的最强大工具之一。本完整指南将深入探讨什么是Puppeteer以及如何有效地在网页抓取中使用它。 Puppeteer简…

真正高水平的一流领导,从不和员工打成一片,这3点原因太真实

真正高水平的一流领导,从不和员工打成一片,这3点原因太真实 第一个:分化团队 在团队管理过程中,如果人不多,那还好。 可一旦人数多了,领导就不可能面面俱到,顾及到每一个人。 肯定会出现&am…

基于Java技术的网上图书商城系统

你好呀,我是计算机学姐码农小野!如果有相关需求,可以私信联系我。 开发语言:Java 数据库:MySQL 技术:Java技术、SpringBoot框架 工具:Eclipse、Navicat、Maven 系统展示 首页 用户注册界面…

Java中的公平锁和非公平锁

1、什么是公平锁和非公平锁 公平锁和非公平锁是指在多线程环境下,如何对锁进行获取的顺序和策略的不同。 公平锁是指多个线程按照申请锁的顺序来获取锁,即先到先得的策略。当一个线程释放锁之后,等待时间最长的线程将获得锁。公平锁的优点是保…

第2章 信息技术知识

第2章 信息技术知识 本章简要叙述了信息技术相关基础知识,包含软件工程、面向对象系统分析与设计、应用集成技术、计算机网络技术和新一代信息技术内容。 2.1 软件工程 随着所开发软件的规模越来越大、复杂度越来越高,加之用户需求又并不十分明确&…

【RIP实验-熟悉基础配置】

实验拓扑 实验要求 根据实验拓扑的IP地址分配,为所有设备配置对应的IP地址和环回地址。全网运行RIPv2,将R1、R2、R3和R4的物理端口、Loopback地址和10.1.00网段进行宣告。并在rip协议下配置路由自动汇总,观察R1/R2是否能够收到10.0.0.0的详细…

【Pytorch实用教程】transformer中创建嵌入层的模块nn.Embedding的用法

文章目录 1. nn.Embedding的简单介绍1.1 基本用法1.2 示例代码1.3 注意事项2. 通俗的理解num_embeddings和embedding_dim2.1 num_embeddings2.2 embedding_dim2.3 使用场景举例结合示例1. nn.Embedding的简单介绍 nn.Embedding 是 PyTorch 中的一个模块,用于创建一个嵌入层。…

MATLAB engine for python调用m文件函数输出变量值python调用MATLAB函数混合编程

MATLAB engine for python调用m文件函数输出变量值python调用MATLAB函数混合编程 说明(废话)解决方案总结 说明(废话) python调用MATLAB函数,MATLAB函数实现在m文件,python直接调用MATLAB中的函数。 首先还是要安装好MATLAB engine python setup.py ins…

3d已经做好的模型怎么改单位?---模大狮模型网

在展览3D模型设计行业中,经常会遇到需要将已完成的模型进行单位转换的需求。这可能涉及从一种度量单位转换为另一种,例如从英制单位转换为公制单位,或者根据特定的展览场地要求进行尺寸调整。本文将探讨如何有效地修改已完成的3D模型的单位&a…

标签印刷检测,如何做到百分百准确?

印刷标签是一种用于标识、识别或包装产品的平面印刷制品。这些标签通常在纸张、塑料膜、金属箔等材料上印刷产品信息、条形码、图像或公司标识,以便于产品识别和管理。印刷标签有各种形状、尺寸和材质,可以根据具体需求进行定制设计。常见的印刷标签包括…

分类模型的算法性能评价

一、概述 分类模型是机器学习中一种最常见的问题模型,在许多问题场景中有着广泛的运用,是模式识别问题中一种主要的实现手段。分类问题概况起来就是,对一堆高度抽象了的样本,由经验标定了每个样本所属的实际类别,由特定…

南京邮电大学运筹学课程实验报告1 线性规划求解 指导

一、题目描述 实验 一 线性规划求解 实验属性: 验证性     实验目的 1.理解线性规划解的基本概念; 2.掌握运筹学软件的使用方法; 3. 掌握线性规划的求解原理和方法。 实验内容 认…

MQTT是什么,物联网

写文思路: 以下从几个方面介绍MQTT,包括:MQTT是什么,MQTT和webSocket的结合,以及使用场景, 一、MQTT是什么 MQTT(Message Queuing Telemetry Transport)是一种轻量级的发布/订阅消息…

手慢无,速看︱PMO大会内部学习资料

全国PMO专业人士年度盛会 每届PMO大会,组委会都把所有演讲嘉宾的PPT印刷在了会刊里面,供大家会后回顾与深入学习。 第十三届中国PMO大会-会刊 《2024第十三届中国PMO大会-会刊》 (内含演讲PPT) 会刊:750个页码&…

泽众云真机-远程自助兼容性测试报告新功能已上线

之前泽众云真机平台有用户咨询,是否支持团队管理,因项目时间上线紧急,需要多人参与测试, 否则项目有延期的风险,后来用户选择了泽众云测试平台的第三方专业兼容性测试服务。事后技术团队认真评估用户需要,从…

京东商品历史价格查询

当前资料来源于网络,禁止用于商用,仅限于学习。 下载京东APP 登录后 打开商品详情就可以看到 要获取京东商品的历史价格,你可以在京东网站上搜索该商品,并进入该商品的详情页面。然后,在页面中找到“商品详情”一栏&…

电脑32位和62位是什么意思

在现代计算机世界中,32位和64位是两个常见的术语,但许多用户可能不太清楚它们的确切含义以及它们之间的区别。本文将详细介绍32位和64位计算机的基本概念、如何查看您的计算机是32位还是64位,以及它们对用户的实际影响。 32位与64位的基本概…

针对tcp不出网打——HTTP封装隧道代理(以CFS演示)

目录 上传工具到攻击机 使用说明 生成后门文件 由于电脑短路无法拖动文件,我就wget发送到目标主机tunnel.php文件​ 成功上传​ 可以访问上传的文件 启动代理监听 成功带出 后台私信获取弹药库工具reGeorg 上传工具到攻击机 使用说明 生成后门文件 pyt…

【Linux】进程(9):进程控制3(进程程序替换)

大家好,我是苏貝,本篇博客带大家了解Linux进程(9)进程控制1,如果你觉得我写的还不错的话,可以给我一个赞👍吗,感谢❤️ 目录 (A)什么是进程程序替换&#xf…

PLC数采网关在实际应用中有哪些效能?天拓四方

在工业自动化领域中,PLC扮演着至关重要的角色,它负责控制和监测生产线的各个环节。然而,随着工业4.0的推进和智能制造的快速发展,单纯依靠PLC进行现场控制已无法满足企业对数据集中管理、远程监控和智能分析的需求。因此&#xff…