TensorFlow神经网络中间层的可视化

news2025/1/11 2:25:39

TensorFlow神经网络中间层的可视化

  • TensorFlow神经网络中间层的可视化
    • 1. 训练网络并保存为.h5文件
    • 2. 通过.h5文件导入网络
    • 3. 可视化网络中间层结果
      • (1)索引取层可视化
      • (2)通过名字取层可视化

TensorFlow神经网络中间层的可视化

1. 训练网络并保存为.h5文件

我们使用AlexNet为例,任务是手写数字识别,训练集使用手写数字集(mnist)。

网络的结构(我们使用的是28x28的黑白图):
在这里插入图片描述

网络搭建和训练的代码

# 最终版
import os.path
import tensorflow as tf
import matplotlib.pyplot as plt
import cv2


# 画出训练过程的准确率和损失值的图像
def plotTrainHistory(history, train, val):
	plt.plot(history[train])
	plt.plot(history[val])
	plt.title('Train History')
	plt.xlabel('Epoch')
	plt.ylabel(train)
	plt.legend(['train', 'validation'], loc = 'upper left')
	plt.show()


(xTrain, yTrain), (xTest, yTest) = tf.keras.datasets.mnist.load_data()

xTrain = tf.expand_dims(xTrain, axis = 3)
xTest = tf.expand_dims(xTest, axis = 3)
print(f"训练集数据大小:{xTrain.shape}")
print(f"训练集标签大小:{yTrain.shape}")
print(f"测试集数据大小:{xTest.shape}")
print(f"测试集标签大小:{yTest.shape}")

# 归一化
xTrainNormalize = tf.cast(xTrain, tf.float32) / 255
xTestNormalize = tf.cast(xTest, tf.float32) / 255
# 数据独热编码
yTrainOneHot = tf.keras.utils.to_categorical(yTrain)
yTestOneHot = tf.keras.utils.to_categorical(yTest)

model = tf.keras.models.Sequential([
	tf.keras.layers.Conv2D(
		filters = 96, kernel_size = 11, strides = 4, input_shape = (28, 28, 1),
		padding = 'SAME', activation = tf.keras.activations.relu
	),
	tf.keras.layers.BatchNormalization(),
	tf.keras.layers.MaxPool2D(pool_size = 3, strides = 2, padding = 'SAME'),
	tf.keras.layers.Conv2D(
		filters = 256, kernel_size = 5, strides = 1,
		padding = 'SAME', activation = tf.keras.activations.relu
	),
	tf.keras.layers.BatchNormalization(),
	tf.keras.layers.MaxPool2D(pool_size = 3, strides = 2, padding = 'SAME'),
	tf.keras.layers.Conv2D(
		filters = 384, kernel_size = 3, strides = 1,
		padding = 'SAME', activation = tf.keras.activations.relu
	),
	tf.keras.layers.Conv2D(
		filters = 384, kernel_size = 3, strides = 1,
		padding = 'SAME', activation = tf.keras.activations.relu
	),
	tf.keras.layers.Conv2D(
		filters = 256, kernel_size = 3, strides = 1,
		padding = 'SAME', activation = tf.keras.activations.relu
	),
	tf.keras.layers.MaxPool2D(pool_size = 3, strides = 2, padding = 'SAME'),
	tf.keras.layers.Flatten(),
	tf.keras.layers.Dense(4096, activation = tf.keras.activations.relu),
	tf.keras.layers.Dropout(0.5),
	tf.keras.layers.Dense(4096, activation = tf.keras.activations.relu),
	tf.keras.layers.Dropout(0.5),
	tf.keras.layers.Dense(10, activation = tf.keras.activations.softmax)
])


weightsPath = './AlexNetModel/'

callback = tf.keras.callbacks.ModelCheckpoint(
	filepath = weightsPath,
	save_best_only = True,
	save_weights_only = True,
	verbose = 1
)

model.compile(
	loss = tf.keras.losses.CategoricalCrossentropy(),
	optimizer = tf.keras.optimizers.Adam(),
	metrics = ['accuracy']
)

model.summary()

# 不存在就训练模型
print('参数文件不存在,即将训练模型')
modelTrain = model.fit(
	xTrainNormalize, yTrainOneHot, validation_split = 0.2,
	epochs = 20, batch_size = 300, verbose = 1, callbacks = [callback]
)
model.save("./model.h5")
plotTrainHistory(modelTrain.history, 'loss', 'val_loss')
plotTrainHistory(modelTrain.history, 'accuracy', 'val_accuracy')

2. 通过.h5文件导入网络

把刚才训练得到的模型重新读取,并且重新加载数据集

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np


def plot_images(images, number, path, title, gray = False):
    plt.figure()
    plt.title(title)
    order = 1
    for i in range(0, number):

        plt.subplot(3, 3, order)
        if gray:
            plt.imshow(images[:, :, 0, i], cmap = 'gray')
        else:
            plt.imshow(images[:, :, 0, i])
        plt.colorbar()
        order = order + 1

    plt.savefig("./{}.png".format(path))
    plt.show()


if __name__ == '__main__':
    weightsPath = './AlexNetModel/'
    (xTrain, yTrain), (xTest, yTest) = tf.keras.datasets.mnist.load_data()

    xTrain = tf.expand_dims(xTrain, axis = 3)
    xTest = tf.expand_dims(xTest, axis = 3)
    # print(f"训练集数据大小:{xTrain.shape}")
    # print(f"训练集标签大小:{yTrain.shape}")
    # print(f"测试集数据大小:{xTest.shape}")
    # print(f"测试集标签大小:{yTest.shape}")

    # 归一化
    xTrainNormalize = tf.cast(xTrain, tf.float32) / 255
    xTestNormalize = tf.cast(xTest, tf.float32) / 255
    # 数据独热编码
    yTrainOneHot = tf.keras.utils.to_categorical(yTrain)
    yTestOneHot = tf.keras.utils.to_categorical(yTest)

    model = tf.keras.models.load_model("model.h5")
    model.summary()
    print('Layer Number', len(model.layers))

    sample = xTrainNormalize[0]
    plt.imshow(sample)
    plt.colorbar()
    plt.savefig('./train.png')

3. 可视化网络中间层结果

测试的数字,5

在这里插入图片描述

(1)索引取层可视化

model.layers中存放着这个神经网络的全部层,它是一个list类型变量

AlexNet一共16层(卷积层、全连接层、池化层等都算入),全部存储在里面

model = tf.keras.models.load_model("model.h5")
print('Layer Number', len(model.layers))

可视化的时候我们取出一部分层,然后来预测,预测结果就是取出来这部分层的结果,因此就看到了中间层的结果

output = tf.keras.models.Sequential([
    tf.keras.layers.InputLayer(input_shape = (28, 28, 1)),
    model.layers[0],
    model.layers[1],
    model.layers[2],
]).predict(sample)
print('output.shape', output.shape)
plot_images(output, 9, '5_Conv2D_BN_MP_1', str(output.shape))

查看三层的结果,即Conv2D+BN+MaxPool,结果是 (28, 4, 1, 96),这里画出前9个

在这里插入图片描述
把这96个叠加在一起的结果

t = output[:, :, 0, 0]
for i in range(1, output.shape[3]):
    t = t + output[:, :, 0, i]
plt.imshow(t)
plt.colorbar()
plt.savefig('./5_Conv2D_BN_MP_1_All.png')

在这里插入图片描述
下面的代码是画出神经网络三个中间层的结果

在这里插入图片描述

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np


def plot_images(images, number, path, title, gray = False):
    plt.figure()
    plt.title(title)
    order = 1
    for i in range(0, number):

        plt.subplot(3, 3, order)
        if gray:
            plt.imshow(images[:, :, 0, i], cmap = 'gray')
        else:
            plt.imshow(images[:, :, 0, i])
        plt.colorbar()
        order = order + 1

    plt.savefig("./{}.png".format(path))
    plt.show()


if __name__ == '__main__':
    weightsPath = './AlexNetModel/'
    (xTrain, yTrain), (xTest, yTest) = tf.keras.datasets.mnist.load_data()

    xTrain = tf.expand_dims(xTrain, axis = 3)
    xTest = tf.expand_dims(xTest, axis = 3)
    # print(f"训练集数据大小:{xTrain.shape}")
    # print(f"训练集标签大小:{yTrain.shape}")
    # print(f"测试集数据大小:{xTest.shape}")
    # print(f"测试集标签大小:{yTest.shape}")

    # 归一化
    xTrainNormalize = tf.cast(xTrain, tf.float32) / 255
    xTestNormalize = tf.cast(xTest, tf.float32) / 255
    # 数据独热编码
    yTrainOneHot = tf.keras.utils.to_categorical(yTrain)
    yTestOneHot = tf.keras.utils.to_categorical(yTest)

    model = tf.keras.models.load_model("model.h5")
    model.summary()
    print('Layer Number', len(model.layers))

    sample = xTrainNormalize[0]
    plt.imshow(sample)
    plt.colorbar()
    plt.savefig('./train.png')

    output = tf.keras.models.Sequential([
        tf.keras.layers.InputLayer(input_shape = (28, 28, 1)),
        model.layers[0],
        model.layers[1],
        model.layers[2],
    ]).predict(sample)
    print('output.shape', output.shape)
    plot_images(output, 9, '5_Conv2D_BN_MP_1', str(output.shape))

    t = output[:, :, 0, 0]
    for i in range(1, output.shape[3]):
        t = t + output[:, :, 0, i]
    plt.imshow(t)
    plt.colorbar()
    plt.savefig('./5_Conv2D_BN_MP_1_All.png')

    output = tf.keras.models.Sequential([
        tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
        model.layers[0],
        model.layers[1],
        model.layers[2],
        model.layers[3],
        model.layers[4],
        model.layers[5],
    ]).predict(sample)
    print('output.shape', output.shape)
    plot_images(output, 9, '5_Conv2D_BN_MP_2', str(output.shape))

    t = output[:, :, 0, 0]
    for i in range(1, output.shape[3]):
        t = t + output[:, :, 0, i]
    plt.imshow(t)
    plt.colorbar()
    plt.savefig('./5_Conv2D_BN_MP_2_All.png')

    output = tf.keras.models.Sequential([
        tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
        model.layers[0],
        model.layers[1],
        model.layers[2],
        model.layers[3],
        model.layers[4],
        model.layers[5],
        model.layers[6],
        model.layers[7],
        model.layers[8],
        model.layers[9],
    ]).predict(sample)
    print('output.shape', output.shape)
    plot_images(output, 9, '5_Conv2D_3_MP', str(output.shape))

    t = output[:, :, 0, 0]
    for i in range(1, output.shape[3]):
        t = t + output[:, :, 0, i]
    plt.imshow(t)
    plt.colorbar()
    plt.savefig('./5_Conv2D_3_MP_All.png')

0和5的结果

在这里插入图片描述

(2)通过名字取层可视化

模型的**summary()**成员函数可以查看网络每一层名字和参数情况

model.summary()

博客中使用的AlexNet每一层名字和参数情况
在这里插入图片描述
通过名字来取中间层,并且预测得到中间层可视化结果

在这里插入图片描述
如果我们要看这个池化层的结果,这样写代码

model = tf.keras.models.load_model("../model.h5")
model.summary()

sample = xTrainNormalize[0]
plt.imshow(sample)
plt.colorbar()
plt.savefig('./train.png')

output = tf.keras.models.Model(
    inputs=model.get_layer('conv2d').input,
    outputs=model.get_layer('max_pooling2d').output
).predict(sample)

通过get_layer获取指定名字的层

inputs指定输入层,outputs指定输出层

每一层的名字可以在创建的时候使用name参数指定

...
tf.keras.layers.Conv2D(
		filters = 96, kernel_size = 11, strides = 4, input_shape = (28, 28, 1),
		padding = 'SAME', activation = tf.keras.activations.relu, name = 'Conv2D_1'
	),
...

每一层的名字红色框框出

在这里插入图片描述

下面是例子:

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np


def plot_images(images, number, path, title, gray = False):
    plt.figure()
    plt.title(title)
    order = 1
    for i in range(0, number):

        plt.subplot(3, 3, order)
        if gray:
            plt.imshow(images[:, :, 0, i], cmap = 'gray')
        else:
            plt.imshow(images[:, :, 0, i])
        plt.colorbar()
        order = order + 1

    plt.savefig("./{}.png".format(path))
    plt.show()


if __name__ == '__main__':
    (xTrain, yTrain), (xTest, yTest) = tf.keras.datasets.mnist.load_data()

    xTrain = tf.expand_dims(xTrain, axis = 3)
    xTest = tf.expand_dims(xTest, axis = 3)

    # 归一化
    xTrainNormalize = tf.cast(xTrain, tf.float32) / 255
    xTestNormalize = tf.cast(xTest, tf.float32) / 255
    # 数据独热编码
    yTrainOneHot = tf.keras.utils.to_categorical(yTrain)
    yTestOneHot = tf.keras.utils.to_categorical(yTest)

    model = tf.keras.models.load_model("../model.h5")
    model.summary()

    sample = xTrainNormalize[0]
    plt.imshow(sample)
    plt.colorbar()
    plt.savefig('./train.png')

    output = tf.keras.models.Model(
        inputs=model.get_layer('conv2d').input,
        outputs=model.get_layer('max_pooling2d').output
    ).predict(sample)

    # output = tf.keras.models.Sequential([
    #     tf.keras.layers.InputLayer(input_shape = (28, 28, 1)),
    #     model.layers[0],
    #     model.layers[1],
    #     model.layers[2],
    # ]).predict(sample)
    print('output.shape', output.shape)
    # plot_images(output, 9, '5_Conv2D_BN_MP_1', str(output.shape))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1316350.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

nodejs+vue+微信小程序+python+PHP技术下的音乐推送系统-计算机毕业设计推荐

3.2.1前台用户功能 前台注册用户的功能如下: 注册登录:用户填写个人信息,并验证手机号码进行账户注册,注册成功后方可登录系统。 歌手介绍:用户可以在线进行歌手介绍信息查看等。 音乐库:用户可以在音乐库查…

Flink的处理函数

之前的流处理API,无论是基本的转换、聚合,还是更为复杂的窗口操作,其实都是基于DataStream进行转换的,所以可以统称为DataStream API。 在Flink更底层,我们可以不定义任何具体的算子(比如map,f…

Arrays.asList()方法:陷阱与解决之道

在Java编程中,Arrays类提供了一系列用于操作数组的实用方法。其中,​Arrays.asList()​方法是一个常用的方法,用于快速将数组转换为List集合。然而,这个方法存在一些潜在的陷阱,可能导致出现意外的行为。本文将介绍​A…

智能优化算法应用:基于和声算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用:基于和声算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用:基于和声算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.和声算法4.实验参数设定5.算法结果6.参考文献7.MA…

管理类联考——数学——真题篇——按知识分类——几何——解析几何

文章目录 解析几何2023真题(2023-07)-几何-解析几何-最值-画图求最值-两线相减求最大-联想三角形的“两边差小于第三边”,当为第三边为最大真题(2023-19)-几何-解析几何-最值-画图求最值-圆方程画出圆的形状-两点间距离…

Mr. Cappuccino的第67杯咖啡——MacOS通过PD安装Win11

MacOS通过PD安装Win11 下载ParallelsDesktop安装ParallelsDesktop激活ParallelsDesktop下载Windows11安装Windows11激活Windows11 下载ParallelsDesktop ParallelsDesktop下载地址 安装ParallelsDesktop 关闭上面的窗口,继续操作 激活ParallelsDesktop 关闭上面的…

在官网免费创建一个云mongoDB数据库

MongoDB的设计目标是提供高性能、高可用性、可扩展性和易用性。它采用了文档存储模型,将数据以类似JSON的BSON(Binary JSON)格式存储,并且支持动态模式,允许应用程序更灵活地存储和查询数据。MongoDB还支持水平扩展&am…

Postman接口测试工具使用总结

一、前言 在前后端分离开发时,后端工作人员完成系统接口开发后,需要与前端人员对接,测试调试接口,验证接口的正确性可用性。而这要求前端开发进度和后端进度保持基本一致,任何一方的进度跟不上,都无法及时…

three.js模拟太阳系

地球的旋转轨迹目前设置为了圆形&#xff0c;效果&#xff1a; <template><div><el-container><el-main><div class"box-card-left"><div id"threejs" style"border: 1px solid red"></div><div c…

实操Nginx(七层代理)+Tomcat多实例部署,实现负载均衡和动静分离

目录 Tomcat多实例部署&#xff08;192.168.17.27&#xff09; 1.安装jdk&#xff0c;设置jdk的环境变量 2.安装tomcat在一台已经部署了tomcat的机器上复制tomcat的配置文件取名tomcat1 ​编辑 编辑配置文件更改端口号&#xff0c;将端口号改为8081 启动 tomcat&#xff…

【机器学习】libsvm 简单使用示例(C++)

libsvm简单使用demo 一、libsvm使用说明 二、svm.h源码 #ifndef _LIBSVM_H //如果没有定义 _LIBSVM_H 宏 #define _LIBSVM_H //则定义 _LIBSVM_H 宏&#xff0c;用于防止重复包含#define LIBSVM_VERSION 317 //定义一个宏&#xff0c;表示 libsvm 的版本号#ifdef __cplusplus /…

uniapp之屏幕右侧出现滚动条去掉、隐藏、删除【好用!】

目录 问题解决大佬地址最后 问题 解决 在最外层view上加上class“content”;输入以下样式。注意&#xff1a;两个都必须存在在生效。 .content {/* 跟屏幕高度一样高,不管view中有没有内容,都撑开屏幕高的高度 */height: 100vh; overflow: auto; } .content::-webkit-scrollb…

计算机网络考研辨析(后续整理入笔记)

文章目录 体系结构物理层速率辨析交换方式辨析编码调制辨析 链路层链路层功能介质访问控制&#xff08;MAC&#xff09;信道划分控制之——CDMA随机访问控制轮询访问控制 扩展以太网交换机 网络层网络层功能IPv4协议IP地址IP数据报分析ICMP 网络拓扑与转发分析&#xff08;重点…

软件设计师——计算机网络(三)

&#x1f4d1;前言 本文主要是【计算机网络】——软件设计师——计算机网络的文章&#xff0c;如果有什么需要改进的地方还请大佬指出⛺️ &#x1f3ac;作者简介&#xff1a;大家好&#xff0c;我是听风与他&#x1f947; ☁️博客首页&#xff1a;CSDN主页听风与他 &#x1…

Mac安装Typora实现markdown自由

一、什么是markdown Markdown 是一种轻量级标记语言&#xff0c;创始人为约翰格鲁伯&#xff08;John Gruber&#xff09;。 它允许人们使用易读易写的纯文本格式编写文档&#xff0c;然后转换成有效的 XHTML&#xff08;或者HTML&#xff09;文档。这种语言吸收了很多在电子邮…

【AI工具】GitHub Copilot IDEA安装与使用

GitHub Copilot是一款AI编程助手&#xff0c;它可以帮助开发者编写代码&#xff0c;提供代码建议和自动完成功能。以下是GitHub Copilot在IDEA中的安装和使用步骤&#xff1a; 安装步骤&#xff1a; 打开IDEA&#xff0c;点击File -> Settings -> Plugins。在搜索框中输…

棋牌的电脑计时计费管理系统教程,棋牌灯控管理软件操作教程

一、前言 有的棋牌室在计时的时候&#xff0c;需要使用灯控管理&#xff0c;在开始计时的时候打开灯&#xff0c;在结账后关闭灯&#xff0c;也有的不需要用灯控&#xff0c;只用来计时。 下面以 佳易王棋牌计时计费管理系统软件为例说明&#xff1a; 软件试用版下载或技术支…

Selenium安装WebDriver:ChromeDriver与谷歌浏览器版本快速匹配_最新版120

最近在使用通过selenium操作Chrome浏览器时&#xff0c;安装中遇到了Chrome版本与浏览器驱动不匹配的的问题&#xff0c;在此记录安装下过程&#xff0c;如何快速找到与谷歌浏览器相匹配的ChromeDriver驱动版本。 1. 确定Chrome版本 我们首先确定自己的Chrome版本 Chrome设置…

NE555汽车防盗报警电路图

实用汽车防盗报警电路如图所示。它主要由防盗部分和报警两大部分电路组成。防盗电路&#xff1a;当汽车主人离开汽车时&#xff0c;将防盗开关S置于“B”位置&#xff0c;使汽车进入防盗状态。当有窃贼进入驾驶室企图发动汽车将其盗走时&#xff0c;只要拧动点火开关&#xff0…

python+requests+pytest 接口自动化实现

最近工作之余拿公司的项目写了一个接口测试框架&#xff0c;功能还不是很完善&#xff0c;算是抛砖引玉了&#xff0c;欢迎各位来吐槽。 主要思路&#xff1a; ①对 requests 进行二次封装&#xff0c;做到定制化效果 ②使用 excel 存放接口请求数据&#xff0c;作为数据驱动 ③…