【神经网络结构可视化】使用 Visualkeras 可视化 Keras / TensorFlow 神经网络结构

news2025/1/6 23:32:13

文章目录

    • Visualkeras介绍
    • 下载安装
    • 代码示例
      • 1、导入必要的库
      • 2、创建VGG16神经网络模型
      • 3、可视化神经网络结构
      • 4、完整代码
      • 5、使用教程
    • 可视化自己创建的神经网络结构
      • 1、导入要的库
      • 2、创建自己的神经网络模型
      • 3、可视化神经网络结构图
      • 4、完整代码


Visualkeras介绍

Visualkeras是一个Python包,用于帮助可视化Keras(独立或包含在tensorflow中)神经网络架构。它允许简单的造型来满足大多数需求。该模块支持分层风格的架构生成,这对CNN(卷积神经网络)非常有用。


下载安装

Visualkeras源代码链接:https://github.com/paulgavrikov/visualkeras

使用清华源安装Visualkeras

pip install visualkeras -i https://pypi.tuna.tsinghua.edu.cn/simple

代码示例

使用CNN经典网络VGG16作为示例,可视化神经网络结构。

1、导入必要的库

from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Dense, Flatten, Conv2D, Dropout, MaxPooling2D, InputLayer, ZeroPadding2D
from collections import defaultdict
import visualkeras
from PIL import ImageFont

2、创建VGG16神经网络模型

# create VGG16
image_size = 224
model = Sequential()
model.add(InputLayer(input_shape=(image_size, image_size, 3)))

model.add(ZeroPadding2D((1, 1)))
model.add(Conv2D(64, activation='relu', kernel_size=(3, 3)))
model.add(ZeroPadding2D((1, 1)))
model.add(Conv2D(64, activation='relu', kernel_size=(3, 3)))
model.add(visualkeras.SpacingDummyLayer())

model.add(MaxPooling2D((2, 2), strides=(2, 2)))
model.add(ZeroPadding2D((1, 1)))
model.add(Conv2D(128, activation='relu', kernel_size=(3, 3)))
model.add(ZeroPadding2D((1, 1)))
model.add(Conv2D(128, activation='relu', kernel_size=(3, 3)))
model.add(visualkeras.SpacingDummyLayer())

model.add(MaxPooling2D((2, 2), strides=(2, 2)))
model.add(ZeroPadding2D((1, 1)))
model.add(Conv2D(256, activation='relu', kernel_size=(3, 3)))
model.add(ZeroPadding2D((1, 1)))
model.add(Conv2D(256, activation='relu', kernel_size=(3, 3)))
model.add(ZeroPadding2D((1, 1)))
model.add(Conv2D(256, activation='relu', kernel_size=(3, 3)))
model.add(visualkeras.SpacingDummyLayer())

model.add(MaxPooling2D((2, 2), strides=(2, 2)))
model.add(ZeroPadding2D((1, 1)))
model.add(Conv2D(512, activation='relu', kernel_size=(3, 3)))
model.add(ZeroPadding2D((1, 1)))
model.add(Conv2D(512, activation='relu', kernel_size=(3, 3)))
model.add(ZeroPadding2D((1, 1)))
model.add(Conv2D(512, activation='relu', kernel_size=(3, 3)))
model.add(visualkeras.SpacingDummyLayer())

model.add(MaxPooling2D((2, 2), strides=(2, 2)))
model.add(ZeroPadding2D((1, 1)))
model.add(Conv2D(512, activation='relu', kernel_size=(3, 3)))
model.add(ZeroPadding2D((1, 1)))
model.add(Conv2D(512, activation='relu', kernel_size=(3, 3)))
model.add(ZeroPadding2D((1, 1)))
model.add(Conv2D(512, activation='relu', kernel_size=(3, 3)))
model.add(MaxPooling2D())
model.add(visualkeras.SpacingDummyLayer())

model.add(Flatten())

model.add(Dense(4096, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(4096, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(1000, activation='softmax'))

3、可视化神经网络结构

# Now visualize the model!

color_map = defaultdict(dict)
color_map[Conv2D]['fill'] = 'orange'
color_map[ZeroPadding2D]['fill'] = 'gray'
color_map[Dropout]['fill'] = 'pink'
color_map[MaxPooling2D]['fill'] = 'red'
color_map[Dense]['fill'] = 'green'
color_map[Flatten]['fill'] = 'teal'

font = ImageFont.truetype("./Arial.ttf", 32)

visualkeras.layered_view(model, to_file='./figures/vgg16.png', type_ignore=[visualkeras.SpacingDummyLayer])
visualkeras.layered_view(model, to_file='./figures/vgg16_legend.png', type_ignore=[visualkeras.SpacingDummyLayer],
                         legend=True, font=font)
visualkeras.layered_view(model, to_file='./figures/vgg16_spacing_layers.png', spacing=0)
visualkeras.layered_view(model, to_file='./figures/vgg16_type_ignore.png',
                         type_ignore=[ZeroPadding2D, Dropout, Flatten, visualkeras.SpacingDummyLayer])
visualkeras.layered_view(model, to_file='./figures/vgg16_color_map.png',
                         color_map=color_map, type_ignore=[visualkeras.SpacingDummyLayer])
visualkeras.layered_view(model, to_file='./figures/vgg16_flat.png',
                         draw_volume=False, type_ignore=[visualkeras.SpacingDummyLayer])
visualkeras.layered_view(model, to_file='./figures/vgg16_scaling.png',
                         scale_xy=1, scale_z=1, max_z=1000, type_ignore=[visualkeras.SpacingDummyLayer])

4、完整代码

from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Dense, Flatten, Conv2D, Dropout, MaxPooling2D, InputLayer, ZeroPadding2D
from collections import defaultdict
import visualkeras
from PIL import ImageFont

# create VGG16
image_size = 224
model = Sequential()
model.add(InputLayer(input_shape=(image_size, image_size, 3)))

model.add(ZeroPadding2D((1, 1)))
model.add(Conv2D(64, activation='relu', kernel_size=(3, 3)))
model.add(ZeroPadding2D((1, 1)))
model.add(Conv2D(64, activation='relu', kernel_size=(3, 3)))
model.add(visualkeras.SpacingDummyLayer())

model.add(MaxPooling2D((2, 2), strides=(2, 2)))
model.add(ZeroPadding2D((1, 1)))
model.add(Conv2D(128, activation='relu', kernel_size=(3, 3)))
model.add(ZeroPadding2D((1, 1)))
model.add(Conv2D(128, activation='relu', kernel_size=(3, 3)))
model.add(visualkeras.SpacingDummyLayer())

model.add(MaxPooling2D((2, 2), strides=(2, 2)))
model.add(ZeroPadding2D((1, 1)))
model.add(Conv2D(256, activation='relu', kernel_size=(3, 3)))
model.add(ZeroPadding2D((1, 1)))
model.add(Conv2D(256, activation='relu', kernel_size=(3, 3)))
model.add(ZeroPadding2D((1, 1)))
model.add(Conv2D(256, activation='relu', kernel_size=(3, 3)))
model.add(visualkeras.SpacingDummyLayer())

model.add(MaxPooling2D((2, 2), strides=(2, 2)))
model.add(ZeroPadding2D((1, 1)))
model.add(Conv2D(512, activation='relu', kernel_size=(3, 3)))
model.add(ZeroPadding2D((1, 1)))
model.add(Conv2D(512, activation='relu', kernel_size=(3, 3)))
model.add(ZeroPadding2D((1, 1)))
model.add(Conv2D(512, activation='relu', kernel_size=(3, 3)))
model.add(visualkeras.SpacingDummyLayer())

model.add(MaxPooling2D((2, 2), strides=(2, 2)))
model.add(ZeroPadding2D((1, 1)))
model.add(Conv2D(512, activation='relu', kernel_size=(3, 3)))
model.add(ZeroPadding2D((1, 1)))
model.add(Conv2D(512, activation='relu', kernel_size=(3, 3)))
model.add(ZeroPadding2D((1, 1)))
model.add(Conv2D(512, activation='relu', kernel_size=(3, 3)))
model.add(MaxPooling2D())
model.add(visualkeras.SpacingDummyLayer())

model.add(Flatten())

model.add(Dense(4096, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(4096, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(1000, activation='softmax'))

# Now visualize the model!

color_map = defaultdict(dict)
color_map[Conv2D]['fill'] = 'orange'
color_map[ZeroPadding2D]['fill'] = 'gray'
color_map[Dropout]['fill'] = 'pink'
color_map[MaxPooling2D]['fill'] = 'red'
color_map[Dense]['fill'] = 'green'
color_map[Flatten]['fill'] = 'teal'

font = ImageFont.truetype("./Arial.ttf", 32)

visualkeras.layered_view(model, to_file='./figures/vgg16.png', type_ignore=[visualkeras.SpacingDummyLayer])
visualkeras.layered_view(model, to_file='./figures/vgg16_legend.png', type_ignore=[visualkeras.SpacingDummyLayer],
                         legend=True, font=font)
visualkeras.layered_view(model, to_file='./figures/vgg16_spacing_layers.png', spacing=0)
visualkeras.layered_view(model, to_file='./figures/vgg16_type_ignore.png',
                         type_ignore=[ZeroPadding2D, Dropout, Flatten, visualkeras.SpacingDummyLayer])
visualkeras.layered_view(model, to_file='./figures/vgg16_color_map.png',
                         color_map=color_map, type_ignore=[visualkeras.SpacingDummyLayer])
visualkeras.layered_view(model, to_file='./figures/vgg16_flat.png',
                         draw_volume=False, type_ignore=[visualkeras.SpacingDummyLayer])
visualkeras.layered_view(model, to_file='./figures/vgg16_scaling.png',
                         scale_xy=1, scale_z=1, max_z=1000, type_ignore=[visualkeras.SpacingDummyLayer])

5、使用教程

  • 创建一个项目文件夹(例如:Project)
  • 在创建的项目文件夹Project 中新建一个文件夹(文件夹名为 figures )
  • 通过链接(https://ultralytics.com/assets/Arial.ttf)下载 Arial.ttf 字体文件
  • 将下载的 Arial.ttf 字体文件 放在 项目文件夹Project 下
  • 在 项目文件夹Project 下新建一个py文件(如:examples.py)
  • 将上述的完整代码复制到 examples.py 中
  • 运行examples.py
  • 在 figures文件夹中查看生成的可视化图
  • vgg16.png
    在这里插入图片描述
  • vgg16_legend.png
    在这里插入图片描述

可视化自己创建的神经网络结构

1、导入要的库

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import models,layers
from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization, Flatten, Dense
from tensorflow.keras.callbacks import Callback, ModelCheckpoint
import visualkeras

2、创建自己的神经网络模型

将以下代码替换为自己的Keras / TensorFlow 神经网络结构。

model = models.Sequential()
# 第一层卷积层
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(48, 48, 1)))  # 假设输入图像大小为48x48,1为灰度图
model.add(layers.MaxPooling2D((2, 2)))
# 第二层卷积层
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
# 展平层
model.add(layers.Flatten())
# 全连接层
model.add(layers.Dense(64, activation='relu'))
# 输出层,假设分类任务有7个类别
model.add(layers.Dense(7, activation='softmax'))

3、可视化神经网络结构图

显示层风格图

visualkeras.layered_view(model).show() # 只显示图
# visualkeras.layered_view(model, to_file='output.png').show() # 保存和显示图

在这里插入图片描述
显示带有标签的层风格图

from PIL import ImageFont
font = ImageFont.truetype("./Arial.ttf", 32)

visualkeras.layered_view(model, legend=True, font=font).show() # 只显示图
# visualkeras.layered_view(model, to_file='output_legend.png', legend=True, font=font).show()  # 保存和显示图

在这里插入图片描述

4、完整代码

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import models,layers
from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization, Flatten, Dense
from tensorflow.keras.callbacks import Callback, ModelCheckpoint
import visualkeras

# 可以将下面这部分创建模型的代码更换你自己的神经网络结构
model = models.Sequential()
# 第一层卷积层
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(48, 48, 1)))  # 假设输入图像大小为48x48,1为灰度图
model.add(layers.MaxPooling2D((2, 2)))
# 第二层卷积层
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
# 展平层
model.add(layers.Flatten())
# 全连接层
model.add(layers.Dense(64, activation='relu'))
# 输出层,假设分类任务有7个类别
model.add(layers.Dense(7, activation='softmax'))

visualkeras.layered_view(model).show() # 只显示图
# visualkeras.layered_view(model, to_file='output.png').show() # 保存和显示图

from PIL import ImageFont
font = ImageFont.truetype("./Arial.ttf", 32)

visualkeras.layered_view(model, legend=True, font=font).show() # 只显示图
# visualkeras.layered_view(model, to_file='output_legend.png', legend=True, font=font).show()  # 保存和显示图

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1697793.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

大数据信用报告查询有哪些作用?哪个平台更好?

大数据信用是基于大数据技术,通过大数据系统生成的大数据信用报告,报告收集了查询人在非银环境下的申贷数据以及履约行为和信用风险的综合性报告。很多人都会问,大数据信用报告查询有哪些作用?哪个查询平台更好的疑问,下文就详细…

【教学类-58-04】黑白三角拼图04(2-10宫格,每个宫格随机1张-6张,带空格纸)

背景需求: 前期制作了黑白三角拼图2*2、3*3、4*4,确定了基本模板,就可以批量制作更多格子数 【教学类-58-01】黑白三角拼图01(2*2宫格)固定256种随机抽取10张-CSDN博客文章浏览阅读522次,点赞13次&#x…

【JavaEE】深入学习Spring MVC:掌握参数传递与映射

目录 3. 请求传递单个参数传递多个参数传递对象后端参数重命名传递数组传递集合 3. 请求 请求就是要学习如何传参 传递单个参数多个参数对象数组/集合…… 传递单个参数 RequestMapping("/m1") public String m1(String name){return "接收到的参数name:&qu…

python如何把字符串变成小写字母

Python中,将字符串中的字母转换成小写字母,字符串变量提供了2种方法,分别是title()、lower()。 Python title()方法 title()方法用于将字符串中每个单词的首字母转为大写,其他字母全部转为小写,转换完成后&#xff0…

RAG概述(二):Advanced RAG 高级RAG

目录 概述 Advanced RAG Pre-Retrieval预检索 优化索引 增强数据粒度 粗粒度 细粒度 展开说说 优化索引 Chunk策略 Small2Big方法 元数据 引入假设性问题 对齐优化 混合检索 查询优化 查询扩展 查询转换 Post-Retrieval后检索 参考 概述 Native RAG&#…

Kafka SASL_SSL集群认证

背景 公司需要对kafka环境进行安全验证,目前考虑到的方案有Kerberos和SSL和SASL_SSL,最终考虑到安全和功能的丰富度,我们最终选择了SASL_SSL方案。处于知识积累的角度,记录一下kafka SASL_SSL安装部署的步骤。 机器规划 目前测试环境公搭建了三台kafka主机服务,现在将详…

iOS--锁的学习

iOS--锁的学习 锁的介绍线程安全 锁的分类自旋锁和互斥锁OSSpinLockos_unfair_lockpthread_mutexpthread_mutex的属性 NSLockNSRecursiveLockNSConditionNSConditionLockdispatch_semaphoredispatch_queuesynchronizedatomicpthread_rwlock:读写锁dispatch_barrier_…

react【框架原理详解】JSX 的本质、SyntheticEvent 合成事件机制、组件渲染过程、组件更新过程

JSX 的本质 JSX 代码本身并不是 HTML,也不是 Javascript,在渲染页面前,需先通过解析工具(如babel)解析之后才能在浏览器中运行。 babel官网可查看 JSX 解析后的效果 更早之前,Babel 会把 JSX 转译成一个 R…

Linux 内核

查看内核的发行版 $ uname -r 5.4.0-150-genericcd /lib/modules/5.4.0-150-generic, 内核源码所在的位置:/usr/src 这里的内核源码路径(–kernel-source-path)即为: cd /usr/src/linux-headers-5.4.0-150-generic/ 临时生效: …

自建公式,VBA在Excel中轻松获取反义词

自建公式,VBA在Excel中轻松获取反义词 文章目录 前言一、爬取网站数据二、代码1.创建数据发送及返回方法2.汉字转UTF8编码2.获取反义词 三、运行效果截图 前言 小学语文中,近义词、反义词是必考内容之一。家长不能随时辅导怎么办?有VBA&…

dsPIC单片机buck-boost拓扑双向DC-DC电源变换器设计

为实现电池储能装置的双向DC-DC变换器,本系统以buck-boost拓扑电路为核心,通过DSPICFJ256GP710单片机最小系统控制拓扑的切换,从而进行buck恒流充电和boost恒压放电。充电时效率≥94%,放电时效率≥95.5%,具有过压保护及…

引流500+创业粉,抖音口播工具

在抖音平台运营一个专注于口播的工具号,旨在集结超过500位热衷于创业的粉丝,这需要精心筹划的内容策略和周到的运营计划。首先,明确你的口播工具号所专注的领域,无论是分享创业经验、财务管理技巧还是案例分析,确保你所…

springboot错误

错误总结 1、使用IDEA 的 initialalzer显示2、IDEA 新建文件 没有 java class3、java: 错误: 不支持发行版本 22解决方法4、IDEA-SpringBoot项目yml配置文件不自动提示解决办法 1、使用IDEA 的 initialalzer显示 IDEA创建SpringBoot项目时出现:Initialization fail…

秋招突击——算法——模板题——区间DP(1)——加分二叉树

文章目录 题目描述思路分析实现代码分析总结 题目描述 思路分析 实现代码 不过我的代码写的真的不够简洁&#xff0c;逻辑不够清晰&#xff0c;后续多练练吧。 // 组合数问题 #include <iostream> #include <algorithm>using namespace std;const int N 35; int…

JDBC使用QreryRunner简化SQL查询注意事项

QreryRunner是Dbutils的核心类之一&#xff0c;它显著的简化了SQL查询&#xff0c;并与ResultSetHandler协同工作将使编码量大为减少。 注意事项 1. 使用QreryRunner必须保证实体类的变量名&#xff0c;和sql语句中要查找的字段名必须相同&#xff0c;否则查询 不到数据,会出…

视频号小店去哪里找货源?最全货源渠道分享!

大家好&#xff0c;我是电商糖果 视频号小店因为是这两年电商行业新出来的黑马&#xff0c;吸引着不少商家入驻。 入驻了商家中很多都没有自己的货源渠道。 他们基本都是从无货源开始起步&#xff0c;后期通过积累资源&#xff0c;慢慢搭建属于自己的货源渠道。 可是渐渐的…

FreeRTOS中断中释放信号量

串口接收&#xff1a;中断程序中逆序打印字符串 串口接收&#xff1a;逆序回环实验思路 注&#xff1a;任务优先级较高会自动的切换上下文进行运行 FreeRTOS中的顶半操作和底半操作 顶半操作和底半操作“这种叫法源自与Linux”在嵌入式开发中&#xff0c;为了和Linux操作系统做…

leetcode 1631. 最小体力消耗路径 二分+BFS、并查集、Dijkstra算法

最小体力消耗路径 题目与水位上升的泳池中游泳类似 二分查找BFS 首先&#xff0c;采用二分查找&#xff0c;确定一个体力值&#xff0c;再从左上角&#xff0c;进行BFS&#xff0c;查看能否到达右下角&#xff0c;如果不行&#xff0c;二分查找就往大的数字进行查找&#xff…

终端安全管理系统、天锐DLP(数据泄露防护系统)| 数据透明加密保护,防止外泄!

终端作为企业员工日常办公、数据处理和信息交流的关键工具&#xff0c;承载着企业运营的核心信息资产。一旦终端安全受到威胁&#xff0c;企业的敏感数据将面临泄露风险&#xff0c;业务流程可能遭受中断&#xff0c;甚至整个企业的运营稳定性都会受到严重影响。 因此&#xff…

Java——认识Java

一、介绍 1、起源 Java 是由 Sun Microsystems 于 1995 年推出的一种面向对象的编程语言和计算平台。由詹姆斯高斯林&#xff08;James Gosling&#xff0c;后来被称为Java之父&#xff09;和他的同事们共同研发。后来&#xff0c;Sun 公司被 Oracle&#xff08;甲骨文&#…