tensorflow卷积层操作

news2024/11/16 18:38:14

全连接NN:

每个神经元与前后相邻层的每一个神经元都有全连接关系。输入是特征,输出为预测结果。

参数个数\sum(前层*后层+后层)

实际应用时,会先对原始图像进行特征提取,再把提取到的特征送给全连接网络

会先进行若干层提取,把提取的特征放入全连接网络。

卷积计算可以认为是一个有效提取图像特征的方法。

一般会用一个正方形的卷积核,按指定步长,在输入特征图上滑动,遍历输入特征图中的每个像素点。

当前卷积核的个数,决定了输出特征图的深度 

卷积利用立体卷积核实现参数空间共享。对应元素相乘+偏置项b。

红绿蓝三层分别和各层卷积相乘+b求和,得到输出特征图的一项。

按指定步长滑动。

使用CNN实现离散数据的分类(以图像分类为例)

感受野:卷积神经网络各层输出特征图中的每个像素点,在原始输入图片上映射区域的大小,

 当输入特征图边长大于10个像素点时,3*3性能要比5*5好。

全零填充

当使用全零填充时=输入特征 5/1=5

不全零填充 特征-卷积核+1/步长 向上取整

tf.keras.layers.Conv2D( filters = 卷积核个数,kernel_size= 卷积核尺寸, strides=滑动步长,padding="same"or"valid",activation="relu"...)

批标准化

标准化:使数据符合0的均值,1为标准差的分布,】

批标准化:对一小撮数据做标准化处理

Hi' = HiK - uK/δ

 BN操作将原本偏移的特征数据拉回线性区域。使得输入数据的微小变化更明显的体现到激活函数的输出,提升激活函数对数据的区分力,但使激活函数丧失了非线性特征,因此在BN操作中引入缩放因子γ和偏移因子β,保证函数的非线性表达力。

池化操作

池化用于减少特征数据量,最大值池化可提取图片纹理,均值池化可保留背景特征。

tf.keras.layers.MaxPool2D(pool_size=, strides=, padding=)

舍弃

在神经网络训练时,将一部分神经元按照一定概率从神经元中暂时舍弃,神经网络使用时被舍弃的神经元恢复链接

tf.keras.layers.Dropout(舍弃的概率)

卷积神经网络:借助卷积核特征提取后,送入全连接网络

卷积神经网络的八股套路。

卷积就是特征提取器。CDAPB 

cifar10数据集

cifar10=tf.keras.datasets.cifar10

load_data()读取训练集。

x_train 是5万个32行32列3通道的RGB像素点

卷积神经网络搭建示例

使用6个5x5的卷积核,过2*2的池化核,池化步长是2

过128个神经元的全连接层,最后要过一个10个神经元的全连接层,因为有10个特征

C(核6*5*5,步长1,填充:same)

B(yes)

A(relu)

P(max, 核2*2 ,步长:2,填充:same)

Flatten

Dense(神经元:128,激活relu, Dropout:0.2)

Dense(深刻警员:10,激活softmax )

class

 Baseline(Model):

        def __init__(self):

                super(Baseline, self).__init__()

           C  self.c1 = Conv2D(filters+6, kernel_size=(5, 5),padding='same') #6个卷积核,5*5,使用全零填充

           B  self.b1 = BarchNoermalization()     #使用BN操作

          A   self.a1 = Activation('relu')  #激活函数

         P self.p1 = MaxPool12D(pool_size=(2,2), strides=2,padding='same')  # 池化核是2*2, 池化步长是2

         D self.d1 = Dropout(0.2)        #百分之20休眠神经元

            self.flatten = Flatten() #拉直

             self.f1 = Dense(128, activation='relu') #送入128个神经元的全连接

            self.d2 = Dropout(0.2)# 按照百分之20休眠神经元

           self.f2 = Dense(10, activation='softmax') #送入10个神经元的全连接

使用call调用init函数每个网络结构,从输入到输出,过一次前向传播。返回推理结果

#配置训练方法,选择哪种优化器,选择哪个损失函数,那种评测指标
model.compile

# 告知训练集的输入特征和标签,,每个batch是多少,要迭代,训练集,告知多少次迭代测试集验证准确率,使用回调函数完成断点续训
history = model.fit

import tensorflow as tf
import os
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
from tensorflow.keras import Model

np.set_printoptions(threshold=np.inf)

cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0


class Baseline(Model):
    def __init__(self):
        super(Baseline, self).__init__()
        self.c1 = Conv2D(filters=6, kernel_size=(5, 5), padding='same')  # 卷积层
        self.b1 = BatchNormalization()  # BN层
        self.a1 = Activation('relu')  # 激活层
        self.p1 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')  # 池化层
        self.d1 = Dropout(0.2)  # dropout层

        self.flatten = Flatten()
        self.f1 = Dense(128, activation='relu')
        self.d2 = Dropout(0.2)
        self.f2 = Dense(10, activation='softmax')

    def call(self, x):
        x = self.c1(x)
        x = self.b1(x)
        x = self.a1(x)
        x = self.p1(x)
        x = self.d1(x)

        x = self.flatten(x)
        x = self.f1(x)
        x = self.d2(x)
        y = self.f2(x)
        return y


model = Baseline()
#配置训练方法,选择哪种优化器,选择哪个损失函数,那种评测指标
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

checkpoint_save_path = "./checkpoint/Baseline.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)
# 告知训练集的输入特征和标签
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
model.summary()

# print(model.trainable_variables)
file = open('./weights.txt', 'w')
for v in model.trainable_variables:
    file.write(str(v.name) + '\n')
    file.write(str(v.shape) + '\n')
    file.write(str(v.numpy()) + '\n')
file.close()

###############################################    show   ###############################################

# 显示训练集和验证集的acc和loss曲线
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1919641.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

复杂度(上卷)

前言 在正式进入今天的主题之前,我们不妨先来回顾一下初步学习数据结构后必须知道的概念。🎶 数据结构 数据结构是计算机存储、组织数据的方式,指相互间存在一种或多种特定关系的数据元素的集合。 (没有一种单一的数据结构能够…

在centos7中安装MySQL5.7,是否必须卸载centos7自带的mariadb?

在CentOS 7 中安装 MySQL 5.7 时,不一定必须卸载系统自带的 MariaDB,但为了避免冲突和确保 MySQL 的正常运行,通常建议先卸载 MariaDB。以下是具体的步骤: 卸载 MariaDB(如果已经安装): sudo sy…

强化学习驱动的狼人游戏语言智能体战略玩法

Language Agents with Reinforcement Learning for Strategic Play in the Werewolf Game 论文地址: https://arxiv.org/abs/2310.18940https://arxiv.org/abs/2310.18940 1.概述 在AI领域,构建具备逻辑推理、战略决策以及人类沟通能力的智能体一直被视为长远追求。大规模语…

小阿轩yx-NoSQL 之 Redis 配置与优化

小阿轩yx-NoSQL 之 Redis 配置与优化 Redis 数据库介绍 是一个非关系型数据库 关系数据库与非关系型数据库 按照数据库结构划分的 关系型数据库 是一个结构化的数据库,创建在关系模型基础上,一般面向于记录借助集合代数等数学概念和方法处理数据库…

设计模式探索:责任链模式

1. 什么是责任链模式 责任链模式 (Chain of Responsibility Pattern) 是一种行为型设计模式。定义如下: 避免将一个请求的发送者与接收者耦合在一起,让多个对象都有机会处理请求。将接收请求的对象连接成一条链,并且沿着这条链传递请求&…

数列分块<2>

本期是数列分块入门<2>。该系列的所有题目来自hzwer在LOJ上提供的数列分块入门系列。 Blog:http://hzwer.com/8053.html sto hzwer orz %%% [转载] 好像上面的链接↑打不开&#xff0c;放一个转载:https://www.cnblogs.…

CUDA原子操作

代码 #include <cuda_runtime.h> #include <stdio.h>__global__ void atomicAddAndGet(int *result, int *valueToAdd) {// 原子加法int addedValue atomicAdd(result, *valueToAdd);// 通过原子操作后读取值&#xff0c;确保是加法后的值addedValue *valueToAd…

LabVIEW开发CAN总线多传感器液位检测系统

设计并实现了一个基于CAN总线和LabVIEW的多传感器液位检测系统。该系统利用STM32F107单片机进行模拟信号与数字信号的转换&#xff0c;通过TJA1050实现CAN总线通信&#xff0c;并使用USB-CAN分析仪连接PC。LabVIEW用于数据采集、人机交互界面的设计、数据分析和仪器标定。系统能…

前端必修技能:高手进阶核心知识分享 - 三万字帮你搞定CSS动画(形变动画、过渡动画、关键帧动画)

在CSS的世界里,存在着多种能体现动画效果的属性:CSS transform、CSS Transition 和 CSS Animation。让开始接触CSS的同学感到困惑。要搞清楚CSS的动画,我们就必须先把这几种属性做一下区别。 CSS transform 属性、CSS Transition 属性、 CSS Animation 属性的区别 CSS tra…

FL Studio21.5.3.21中文版破解安装包!音乐制作新神器,让创意无限飞扬!

&#x1f3b6; 音乐制作&#xff0c;轻松入门&#xff01;FL Studio21中文版本体验分享 嘿&#xff01;各位音乐小能手和创作小白们&#xff0c;今天我要给大家安利一个超酷炫的音乐制作软件——FL Studio21中文版&#xff01;&#x1f389; FL Studio21汉化版下载网盘链接: …

Python函数 之 模块和包---练习

题目 1 1.定义一个模块 toolls.py , 定义函数实现对两个数据进行加法操作的函数 add_2_num &#xff0c;并返回相加之和的结 果&#xff1b; 再定义一个实现对三个数据进行加法操作的函数 add_3_num &#xff0c;并返回相加之和的结果&#xff1b; 2.最后新定义一个代码文件 …

AutoMQ vs Kafka: 来自小红书的独立深度评测与对比

测试背景 当前小红书消息引擎团队与 AutoMQ 团队正在深度合作&#xff0c;共同推动社区建设&#xff0c;探索云原生消息引擎的前沿技术。本文基于 OpenMessaging 框架&#xff0c;对 AutoMQ 进行了全面测评。欢迎大家参与社区并分享测评体验。 01 测试结论 本文主要测评云…

JavaDS —— 单链表 与 LinkedList

顺序表和链表区别 ArrayList &#xff1a; 底层使用连续的空间&#xff0c;可以随机访问某下标的元素&#xff0c;时间复杂度为O&#xff08;1&#xff09; 但是在插入和删除操作的时候&#xff0c;需要将该位置的后序元素整体往前或者向后移动&#xff0c;时间复杂度为O&…

二分查找算法——部分OJ题详解

目录 关于二分查找算法 部分OJ题详解 704.二分查找 一&#xff0c;分析题目 二&#xff0c;细节处理 三&#xff0c;题目代码 四&#xff0c;*总结朴素模板 *34.在排序数组中查找元素的第一个和最后一个位置 一&#xff0c;查找左端点 二&#xff0c;处理左端点细…

ts实现将相同类型的数据通过排序放在一起

看下效果&#xff0c;可以将相同表名称的字段放在一起 排序适用于中英文、数字 // 排序 function sortByType(items: any) {// 先按照类型进行排序items.sort((a: any, b: any) > {if (a.label < b.label) return -1;if (a.label > b.label) return 1;return 0;});r…

【记录】LaTex|LaTex调整算法、公式、表格内的字体大小(10种内置字号)

文章目录 【记录】LaTex&#xff5c;LaTex调整算法、公式、表格内的字体大小&#xff08;10种内置字号&#xff09;省流版1 字体大小2 测试代码 详细版1 \tiny2 \scriptsize3 \footnotesize4 \small5 \normalsize6 \large7 \Large8 \LARGE9 \huge10 \Huge 【记录】LaTex&#x…

实验02 黑盒测试(组合测试、场景法)

1. 组合测试用例设计技术 指出等价类划分法和边界值分析法通常假设输入变量相互独立&#xff0c;但实际情况中变量间可能存在关联。全面测试&#xff1a;覆盖所有输入变量的所有可能组合&#xff0c;测试用例数量随输入变量的增加而指数增长。 全面测试需要对所有输入的各个取…

Geoserver源码解读六 插件

系列文章目录 Geoserver源码解读一 环境搭建 Geoserver源码解读二 主入口 Geoserver源码解读三 GeoServerBasePage Geoserver源码解读四 REST服务 Geoserver源码解读五 Catalog Geoserver源码解读六 插件&#xff08;怎么在开发模式下使用&#xff09; 目录 系列文章目…

ubuntu计划任务反弹

目录 实验环境 实验步骤 目标主机构造任务计划 构造语句 语句解释 kali开启监听 监听成功 问题 原因 实验环境 攻击者 操作系统&#xff1a;kali IP&#xff1a;192.168.244.141 目标主机 操作系统&#xff1a;ubuntu IP&#xff1a;192.168.244.151 实验步骤 目…

CSS 中的 ::before 和 ::after 伪元素

目录 一、CSS 伪元素 二、::before ::after 介绍 1、::before 2、::after 3、content 常用属性值 三、::before ::after 应用场景 1、设置统一字符 2、通过背景添加图片 3、添加装饰线 4、右侧展开箭头 5、对话框小三角 6、插入icon图标 一、CSS 伪元素 CSS伪元…