神经网络可视化:卷积核可视化

news2025/1/10 18:16:36

文章目录

  • 前言
      • 一般过程:
  • 一、代码示例
  • 二、卷积核和输入图片相乘可视化
  • 总结


前言

卷积核可视化是一种用于理解卷积神经网络 (CNN) 中卷积层的工作原理和特征提取能力的方法。通过可视化卷积核,我们可以观察卷积层学习到的特征模式,帮助我们理解网络如何对输入进行处理。

本文给出了一个具体的pytorch实现的例子。
本文还给了一个用权重核和直接卷积图像的例子。

一般过程:

导入必要的库和模型:首先,你需要导入相关的库,如 PyTorch、NumPy 和 Matplotlib,并加载已经训练好的 CNN 模型。

获取卷积层权重:从模型中获取卷积层的权重。这些权重通常存储在模型的卷积层参数中。

可视化卷积核:对于每个卷积层,获取对应的权重,并以图像形式展示。可以使用 Matplotlib 或其他图像处理库来显示卷积核。

可选:对于多通道的卷积核,你可以将每个通道的权重分别可视化,以更好地理解卷积核的组成。


一、代码示例

以resnet50为例进行第一层可视化

import os
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import models
from torchvision import transforms


def visualize_conv_filters():
    # 设置GPU设备
    torch.cuda.set_device(0) # 没有GPU可以删除

    model = models.resnet50(pretrained=True)
    model = model.cuda()  # 将模型移动到GPU上 # 没有GPU可以删除

    # 获取第一个卷积层的权重
    conv1_weights = model.conv1.weight.data.cpu().numpy()
    # 调整权重形状,从 [out_channels, in_channels, kernel_size, kernel_size] 变为 [out_channels, kernel_size, kernel_size, in_channels]
    conv1_weights = np.transpose(conv1_weights, (0, 2, 3, 1))

    # 可视化卷积核
    fig, axes = plt.subplots(nrows=8, ncols=8, figsize=(12, 12))
    for i, ax in enumerate(axes.flat):
        ax.imshow(conv1_weights[i])
        ax.axis('off')

    plt.show()


if __name__ == '__main__':
    visualize_conv_filters()

获得图像:
在这里插入图片描述


二、卷积核和输入图片相乘可视化

把卷积核(作为权重)和图片进行卷积操作

import os
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import models
from torchvision import transforms

def visualize_conv_filters(model_conv, img_file):
    # 设置GPU设备
    # torch.cuda.set_device(gpu_number)

    # model = models.resnet50(pretrained=True)
    # model = model.cuda()  # 将模型移动到GPU2上

    # 获取第一个卷积层的权重
    conv1_weights = model_conv.weight.data.cpu().numpy()
    # 调整权重形状,从 [out_channels, in_channels, kernel_size, kernel_size] 变为 [out_channels, kernel_size, kernel_size, in_channels]
    conv1_weights = np.transpose(conv1_weights, (0, 2, 3, 1))

    # 读入图像
    image = Image.open(img_file)
    image = transforms.ToTensor()(image)
    image = image.unsqueeze(0)  # 添加批次维度并将图像移动到GPU上

    # 对每个卷积核进行卷积操作并绘制图像
    fig, axes = plt.subplots(nrows=8, ncols=4, figsize=(12, 6),dpi=100)

    al_list = range(conv1_weights.shape[0])

    for i, ax in zip(al_list, axes.flat):
        conv1_weight = conv1_weights[i]
        conv1_weight = torch.from_numpy(conv1_weight).permute(2, 0, 1).cuda()  # 将当前卷积核移动到GPU上

        # 对图像的每个通道进行卷积操作
        output_channels = []
        for channel in range(conv1_weight.size(0)):
            conv1_weight_ = torch.unsqueeze(conv1_weight, 0)
            ## 重点代码,卷积操作
            output_channel = torch.nn.functional.conv2d(image[:,channel:channel+1,:,:],
                                                        conv1_weight_[:,channel:channel+1,:,:],
                                                        stride=(1, 1), padding=0)
            output_channels.append(output_channel)

        # 合并卷积后的通道
        output = torch.cat(output_channels, dim=1)

        # 转换为NumPy数组并绘制图像
        output = output.squeeze(0).cpu().detach().numpy()  # 移除批次维度,并将结果移动到CPU上
        output = np.transpose(output, (1, 2, 0))  # 调整形状为 [height, width, channels]
        ax.imshow(output)
        ax.axis('off')
        # input('>>')

    if not os.path.exists('./img/'):
        os.mkdir('./img/')
    plt.savefig('./img/abc.png')

if __name__ == '__main__':
    model = models.resnet50(pretrained=True)
    torch.cuda.set_device(0)
    img_file = 'img_1.png'
    visualize_conv_filters(model.conv1, img_file)
    

原图:
在这里插入图片描述
乘以卷积核之后的可视化结果:
在这里插入图片描述


总结

以上就是今天要讲的内容,本文仅仅简单介绍了卷积核可视化的计算

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1194495.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

煤矿企业如何选择合适的设备健康管理系统

在煤矿开采的过程中,机电设备发挥着重要的作用。但大量的机电设备的使用也给煤矿企业设备管理提出了一定的要求。随着工业领域数字化的深入应用,煤矿机电设备的自动化、智能化管理已经成为煤矿企业发展的重要手段。保障机电设备的正常运行,减…

跨境电商源码搭建:开启你的全球贸易新纪元

随着全球电子商务的快速发展,跨境电商已经成为越来越多企业的必然选择。通过跨境电商平台,企业可以拓展海外市场,扩大销售范围,提升品牌影响力。而要实现这一目标,源码搭建是不可或缺的一环。本文将为你揭示跨境电商源…

【紫光同创国产FPGA教程】——【PGL22G第十章】DDR3读写实验例程

本原创教程由深圳市小眼睛科技有限公司创作,版权归本公司所有,如需转载,需授权并注www.meyesemi.com) 适用于板卡型号: 紫光同创PGL22G开发平台(盘古22K) 一:盘古22K开发板(紫光…

ChatGPT 宕机?OpenAI 将中断归咎于 DDoS 攻击

您的 ChatGPT 已关闭吗?您是否遇到 ChatGPT 问题,例如连接问题或遇到“长响应时出现网络错误”?– ChatGPT 遭受了一系列 DDoS 攻击,显然是由匿名苏丹组织策划的。 OpenAI 的 ChatGPT 是一款流行的人工智能聊天机器人,…

centos7安装Nexus(Maven私服)与配置使用教程

之前有位大佬问我,他说有个第三方的Jar包,在idea导出库中使用,现在要部署上线测试,要如何导进去打包。 我说,不用那么麻烦,搞个Nexus私服,将Jar上传上去,然后配置Maven的setting文件…

PHP的curl会话

介绍: Curl(Client for URLs)在PHP中是一个强大而灵活的工具,用于进行各种网络请求。PHP中的Curl库允许开发者通过代码模拟HTTP请求、与API交互、进行数据传输等。在这里,我们将详细解析PHP中Curl会话的各个方面,涵盖…

【博士每天一篇文献-算法】Modular state space of echo state network

阅读时间:2023-11-2 1 介绍 年份:2013 作者:陈卫彪,华南理工大学计算机科学与工程学院, 期刊:Neurocomputing 引用量:17 本文介绍了一种改进回声状态网络(ESN)预测性能的新方法。该…

@CreateCache:深度解析其功能与优势

1. CreateCache前言 在现代Web应用程序开发中,缓存是提高性能和响应速度的重要手段之一。CreateCache注解是JetCache框架中用于创建缓存的注解。本文将介绍CreateCache注解以及它在缓存管理中的作用。 2. CreateCache使用示例 以下是使用CreateCache注解的一个简…

影刀掌握手头,仿佛自由人--更符合中国宝宝体质的自动化工具

以前,影刀是一个邂逅的初见小工具,新奇在里头,踌躇在外头; 现在,影刀是一个稳定的职场贾维斯,高效在里头,悠闲在外头; 以后,影刀是一个潜力的知己老司机,有序…

Pow(x, n)

题目链接 Pow(x, n) 题目描述 注意点 n 是一个整数要么 x 不为零&#xff0c;要么 n > 0-100.0 < x < 100.0 解答思路 完成x的n次方的功能 代码 class Solution {public double myPow(double x, int n) {long N n;return N > 0 ? quickMul(x, N) : 1.0 / …

java项目之网上跳蚤市场(ssm框架)

项目简介 网上跳蚤市场实现了以下功能&#xff1a; 管理员功能需求 管理员登陆后&#xff0c;主要模块包括首页&#xff0c;个人中心&#xff0c;会员管理&#xff0c;商品分类管理&#xff0c;商品信息管理&#xff0c;求购信息管理&#xff0c;留言板管理&#xff0c;系统管…

安卓手机搭建博客网站发布公网访问:Termux+Hexo结合内网穿透工具轻松实现

文章目录 前言 1.安装 Hexo2.安装cpolar3.远程访问4.固定公网地址 前言 Hexo 是一个用 Nodejs 编写的快速、简洁且高效的博客框架。Hexo 使用 Markdown 解析文章&#xff0c;在几秒内&#xff0c;即可利用靓丽的主题生成静态网页。 下面介绍在Termux中安装个人hexo博客并结合…

Python爬虫——入门爬取网页数据

目录 前言 一、Python爬虫入门 二、使用代理IP 三、反爬虫技术 1. 间隔时间 2. 随机UA 3. 使用Cookies 四、总结 前言 本文介绍Python爬虫入门教程&#xff0c;主要讲解如何使用Python爬取网页数据&#xff0c;包括基本的网页数据抓取、使用代理IP和反爬虫技术。 一、…

如何开发你的第一个Flutter App?

Flutter这些年发展的很快&#xff0c;特别是在 Google 持续的加持下&#xff0c;Flutter SDK 的版本号已经来到了 3开头&#xff0c;也正式开始对 Windows、macOS 和 Linux 桌面环境提供支持。如果从 Flutter 特有的优势来看&#xff0c;我个人认为主要是它已经几乎和原生的性能…

6.2.1 邻接矩阵

邻接矩阵 表示方法&#xff1a;优点&#xff1a;缺点&#xff1a;适用情况&#xff1a;案例代码 邻接矩阵是一种常见的图的存储结构&#xff0c;用于表示图中顶点之间的连接关系。它是一个二维数组&#xff0c;其中行和列分别表示图中的顶点&#xff0c;而数组中的值表示连接顶…

工商银行卡安全码怎么看

工商银行的安全码&#xff0c;作为一项至关重要的安全措施&#xff0c;旨在保护用户的银行账户和交易安全。为了查看工商银行的安全码用户需要按照以下步骤操作&#xff1a; 首先&#xff0c;用户需要使用电脑或手机访问工商银行的网上银行平台。在平台首页&#xff0c;用户需要…

创建一个事务级临时表或者会话级临时表继续测试,在什么情况下临时表里的数据会消失

目录 一、测试事务级临时表 1、创建事务级临时表 2、插入测试数据 3、查看表中的数据 4、提交事务 5、再次查看表中数据 二、测试会话级临时表 1、创建会话级临时表 2、插入测试数据 3、查看表中的数据 4、提交事务再次查看数据 5、关闭当前会话 6、再次进入数据库…

Android发热监控实践

一、背景 相信移动端高度普及的现在&#xff0c;大家或多或少都会存在电量焦虑&#xff0c;拥有过手机发热发烫的糟糕体验。而发热问题是一个长时间、多场景的指标存在&#xff0c;且涉及到端侧应用层、手机 ROM 厂商系统、外界环境等多方面的影响。如何有效衡量发热场景、定位…

【GUI软件开发】小红书评论采集:自动采集1w多条,含二级评论!

文章目录 一、爬取目标1.1 效果截图1.2 演示视频1.3 软件说明 二、代码讲解2.1 爬虫采集模块2.2 软件界面模块2.3 日志模块 三、附完整源码及软件 一、爬取目标 您好&#xff01;我是马哥python说 &#xff0c;一名10年程序猿。 我用python开发了一个爬虫采集软件&#xff0c…

8年经验之谈 —— 性能压测工具选型对比!

本文致力于给出性能压测的概念与背景介绍&#xff0c;同时针对市场上的一些性能压测工具&#xff0c;给出相应的对比&#xff0c;从而帮助大家更好地针对自身需求实现性能压测。 为什么要做性能压测 在介绍性能压测概念与背景之前&#xff0c;首先解释下为什么要做性能压测。…