查看神经网络中间层特征矩阵及卷积核参数

news2024/9/30 17:28:01

可视化feature maps以及kernel weights,使用alexnet模型进行演示。

1. 查看中间层特征矩阵

alexnet模型,修改了向前传播

import torch
from torch import nn
from torch.nn import functional as F

# 对花图像数据进行分类
class AlexNet(nn.Module):
    def __init__(self,num_classes=1000,init_weights=False, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv1 = nn.Conv2d(3,48,11,4,2)
        self.pool1 = nn.MaxPool2d(3,2)
        self.conv2 = nn.Conv2d(48,128,5,padding=2)
        self.pool2 = nn.MaxPool2d(3,2)
        self.conv3 = nn.Conv2d(128,192,3,padding=1)
        self.conv4 = nn.Conv2d(192,192,3,padding=1)
        self.conv5 = nn.Conv2d(192,128,3,padding=1)
        self.pool3 = nn.MaxPool2d(3,2)

        self.fc1 = nn.Linear(128*6*6,2048)
        self.fc2 = nn.Linear(2048,2048)
        self.fc3 = nn.Linear(2048,num_classes)
        # 是否进行初始化
        # 其实我们并不需要对其进行初始化,因为在pytorch中,对我们对卷积及全连接层,自动使用了凯明初始化方法进行了初始化
        if init_weights:
            self._initialize_weights()

    def forward(self,x):
        outputs = []  # 定义一个列表,返回我们要查看的哪一层的输出特征矩阵
        x = self.conv1(x)
        outputs.append(x)
        x = self.pool1(F.relu(x,inplace=True))
        x = self.conv2(x)
        outputs.append(x)
        x = self.pool2(F.relu(x,inplace=True))
        x = self.conv3(x)
        outputs.append(x)
        x = F.relu(x,inplace=True)
        x = F.relu(self.conv4(x),inplace=True)
        x = self.pool3(F.relu(self.conv5(x),inplace=True))
        x = x.view(-1,128*6*6)
        x = F.dropout(x,p=0.5)
        x = F.relu(self.fc1(x),inplace=True)
        x = F.dropout(x,p=0.5)
        x = F.relu(self.fc2(x),inplace=True)
        x = self.fc3(x)

        # for name,module in self.named_children():
        #     x = module(x)
        #     if name == ["conv1","conv2","conv3"]:
        #         outputs.append(x)
        return outputs

    # 初始化权重
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m,nn.Conv2d):
                # 凯明初始化 - 何凯明
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m,nn.Linear):
                nn.init.normal_(m.weight, 0,0.01)  # 使用正态分布给权重赋值进行初始化
                nn.init.constant_(m.bias,0)

拿到向前传播的结果,对特征图进行可视化,这里,我们使用训练好的模型,直接加载模型参数。

注意,要使用与训练时相同的数据预处理。

import matplotlib.pyplot as plt
from torchvision import transforms
import alexnet_model
import torch
from PIL import Image
import numpy as np
from alexnet_model import AlexNet

# AlexNet 数据预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# 实例化模型
model = AlexNet(num_classes=5)
weights = torch.load("./alexnet_weight_20.pth", map_location="cpu")
model.load_state_dict(weights)

image = Image.open("./images/yjx.jpg")
image = transform(image)
image = image.unsqueeze(0)

with torch.no_grad():
    output = model(image)

for feature_map in output:
    # (N,C,W,H) -> (C,W,H)
    im = np.squeeze(feature_map.detach().numpy())
    # (C,W,H) -> (W,H,C)
    im = np.transpose(im,[1,2,0])
    plt.figure()
    # 展示当前层的前12个通道
    for i in range(12):
        ax = plt.subplot(3,4,i+1) # i+1: 每个图的索引
        plt.imshow(im[:,:,i],cmap='gray')
    plt.show()

结果:

在这里插入图片描述


2. 查看卷积核参数

import matplotlib.pyplot as plt
import numpy as np
import torch

from AlexNet.model import AlexNet

# 实例化模型
model = AlexNet(num_classes=5)
weights = torch.load("./alexnet_weight_20.pth", map_location="cpu")
model.load_state_dict(weights)

weights_keys = model.state_dict().keys()
for key in weights_keys:
    if "num_batches_tracked" in key:
        continue
    weight_t = model.state_dict()[key].numpy()
    weight_mean = weight_t.mean()
    weight_std = weight_t.std(ddof=1)
    weight_min = weight_t.min()
    weight_max = weight_t.max()
    print("mean is {}, std is {}, min is {}, max is {}".format(weight_mean, weight_std, weight_min, weight_max))

    weight_vec = np.reshape(weight_t,[-1])
    plt.hist(weight_vec,bins=50)
    plt.title(key)
    plt.show()

结果:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1395695.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Elasticsearch8 集群搭建(二)配置篇:(3)安全配置

此篇记录Elasticsearch 8.x传输层的安全配置。 传输层节点间: 如果集群有多个节点,必须在节点间配置TLS。生产模式下,如果不启用TLS,集群将无法启动。 图片来源:Set up basic security for the Elastic Stack | Elas…

2018年认证杯SPSSPRO杯数学建模D题(第二阶段)投篮的最佳出手点全过程文档及程序

2018年认证杯SPSSPRO杯数学建模 D题 投篮的最佳出手点 原题再现: 影响投篮命中率的因素不仅仅有出手角度、球感、出手速度,还有出手点的选择。规范的投篮动作包含两膝微屈、重心落在两脚掌上、下肢蹬地发力、身体随之向前上方伸展、同时抬肘向投篮方向…

vue2使用mapbox

1.安装mapbox 这里安装的是"mapbox-gl": "^3.0.1", npm install --save mapbox-gl 安装mapbox 2.安装worker-loader npm install worker-loader --save-dev 安装worker-loader 配置vue.config.js const { defineConfig } require(vue/cli-servic…

MFC 序列化机制

目录 文件操作相关类 序列化机制相关类 序列化机制使用 序列化机制执行过程 序列化类对象 文件操作相关类 CFile:文件操作类,封装了关于文件读写等操作,常见的方法: CFile::Open:打开或者创建文件CFile::Write/…

AI图片物体移除器:高效、便捷的AI照片物体擦除工具

在我们的日常生活中,照片是一种重要的记录和表达方式。然而,有时候我们会遇到需要将照片中的某些物体和元素去除的情况。这时候,传统的图像处理软件可能过于复杂,让人望而却步。为了解决这个问题,AI图片物体移除器的软…

目标检测--02(Two Stage目标检测算法1)

Two Stage目标检测算法 R-CNN R-CNN有哪些创新点? 使用CNN(ConvNet)对 region proposals 计算 feature vectors。从经验驱动特征(SIFT、HOG)到数据驱动特征(CNN feature map),提高特…

游泳耳机有什么好处?四款适合水下听歌的优质游泳耳机分享

游泳是一项健康有益的运动,而搭配一副高质量的游泳耳机,更能在游泳过程中享受音乐的陪伴。本文将介绍游泳耳机的好处,并为大家推荐四款适合水下听歌的游泳耳机,让大家在游泳中拥有更加丰富的体验。 接下来跟我一起看看游泳耳机的好…

GAN在图像数据增强中的应用

在图像数据增强领域,生成对抗网络(GAN)的应用主要集中在通过生成新的图像数据来扩展现有数据集的规模和多样性。这种方法特别适用于训练数据有限的情况,可以通过增加数据的多样性来提高机器学习模型的性能和泛化能力。 以下是GAN在…

Java如何做到无感知刷新token含示例代码(值得珍藏)

1. 前言 在系统页面进行业务操作时,有时会突然遇到应用闪退,并被重定向至登录页面,要求重新登录。此问题的出现,通常与系统中用于存储用户ID和token信息的Redis缓存有关。具体来说,这可能是由于token过期所导致的身份…

VScode远程开发

VScode远程开发 在SSH远程连接一文中,我么介绍了如何使用ssh远程连接Jetson nano端,但是也存在诸多不便,比如:编辑文件内容时,需要使用vi编辑器,且在一个终端内,无法同时编辑多个文件。本节将介绍一较为实用…

mybatisPlus注解将List集合插入到数据库

1.maven引入依赖&#xff08;特别注意版本&#xff0c;3.1以下不支持&#xff09; <dependency><groupId>com.baomidou</groupId><artifactId>mybatis-plus-boot-starter</artifactId><version>3.4.3.1</version></dependency&g…

Docker 安装 MongoDb4

Docker 安装mongoDb 获取mongodb安装问题汇总参考 获取mongodb 注意&#xff1a; WARNING: MongoDB 5.0 requires a CPU with AVX support, and your current system does not appear to have that! **hub官网&#xff08;需要梯子&#xff09;&#xff1a;**https://hub.dock…

数据分析案例-图书书籍数据可视化分析(文末送书)

&#x1f935;‍♂️ 个人主页&#xff1a;艾派森的个人主页 ✍&#x1f3fb;作者简介&#xff1a;Python学习者 &#x1f40b; 希望大家多多支持&#xff0c;我们一起进步&#xff01;&#x1f604; 如果文章对你有帮助的话&#xff0c; 欢迎评论 &#x1f4ac;点赞&#x1f4…

iPhone解锁工具---AnyMP4 iPhone Unlocker 中文

AnyMP4 iPhone Unlocker是一款功能强大的iPhone解锁软件&#xff0c;旨在帮助用户轻松解锁iPhone&#xff0c;从而在电脑上进行数据备份、传输和编辑。该软件支持多种iPhone型号&#xff0c;包括最新的iPhone 14系列&#xff0c;并支持多种解锁模式&#xff0c;如屏幕密码解锁、…

PyTorch各种损失函数解析:深度学习模型优化的关键(2)

目录 详解pytorch中各种Loss functions mse_loss 用途 用法 使用技巧 注意事项 参数 数学理论公式 代码演示 margin_ranking_loss 用途 用法 使用技巧 注意事项 参数 数学理论公式 代码演示 multilabel_margin_loss 用途 用法 使用技巧 注意事项 参数 …

最新开源付费小剧场短剧小程序源码/影视小程序源码/带支付收益+运营代理推广等功能【搭建教程】

源码介绍&#xff1a; 最新开源付费小剧场短剧小程序源码、影视小程序源码&#xff0c;它有带支付收益、运营代理推广等功能&#xff0c;另有搭建教程好测试上手。仿抖音滑动小短剧影视带支付收益等模式的微信小程序源码。 这是一款功能强大的全开源付费短剧小程序源码&#…

MySQL存储函数与存储过程习题

创建表并插入数据&#xff1a; 字段名 数据类型 主键 外键 非空 唯一 自增 id INT 是 否 是 是 否 name VARCHAR(50) 否 否 是 否 否 glass VARCHAR(50) 否 否 是 否 否 ​ ​ sch 表内容 id name glass 1 xiaommg glass 1 2 xiaojun glass 2 1、创建一个可以统计表格内记录…

protobuf学习日记 | 认识protobuf中的类型

目录 前言 一、标量数据类型 二、protobuf中的 “数组” 三、特殊类型 1、枚举类型 &#xff08;1&#xff09;类型讲解 &#xff08;2&#xff09;升级通讯录 2、Any类型 &#xff08;1&#xff09;类型讲解 &#xff08;2&#xff09;升级通讯录 3、oneof类型 …

【Linux修行路】基本指令

目录 推荐 前言 1、重新认识操作系统 1.1 操作系统是什么? 1.2操作系统的作用 1.3 我们在计算机上的所有操作 1.4 Linux操作的特点 2、Linux基本指令 2.1 ls 指令 2.2 pwd 命令 2.3 cd 指令 2.3.1 Linux中的目录结构 2.3.2 绝对路径和相对路径 2.3.3 cd 指令 …

C++、QT 数字合成游戏

一、项目介绍 数字合成游戏 基本要求&#xff1a; 1&#xff09;要求游戏界面简洁美观&#xff0c;且符合扫雷的游戏风格。 2&#xff09;需要有游戏操作或者规则说明&#xff0c;方便玩家上手。 3&#xff09;需具有开始游戏&#xff0c;暂停游戏&#xff0c;结束游戏等方便玩…