PyTorch实战:模型训练中的特征图可视化技巧

news2024/11/24 17:21:48

1.特征图可视化,这种方法是最简单,输入一张照片,然后把网络中间某层的输出的特征图按通道作为图片进行可视化展示即可。

2.特征图可视化代码如下:

def featuremap_visual(feature, 
                      out_dir=None,  # 特征图保存路径文件
                      save_feature=True,  # 是否以图片形式保存特征图
                      show_feature=True,  # 是否使用plt显示特征图
                      feature_title=None,  # 特征图名字,默认以shape作为title
                      num_ch=-1,  # 显示特征图前几个通道,-1 or None 都显示
                      nrow=8,  # 每行显示多少个特征图通道
                      padding=10,  # 特征图之间间隔多少像素值
                      pad_value=1  # 特征图之间的间隔像素
                      ):
    import matplotlib.pylab as plt
    import torchvision
    import os
    # feature = feature.detach().cpu()
    b, c, h, w = feature.shape
    feature = feature[0]
    feature = feature.unsqueeze(1)

    if c > num_ch > 0:
        feature = feature[:num_ch]

    img = torchvision.utils.make_grid(feature, nrow=nrow, padding=padding, pad_value=pad_value)
    img = img.detach().cpu()
    img = img.numpy()
    images = img.transpose((1, 2, 0))

    # title = str(images.shape) if feature_title is None else str(feature_title)
    title = str('hwc-') + str(h) + '-' + str(w) + '-' + str(c) if feature_title is None else str(feature_title)

    plt.title(title)
    plt.imshow(images)
    if save_feature:
        # root=r'C:\Users\Administrator\Desktop\CODE_TJ\123'
        # plt.savefig(os.path.join(root,'1.jpg'))
        out_root = title + '.jpg' if out_dir == '' or out_dir is None else os.path.join(out_dir, title + '.jpg')
        plt.savefig(out_root)



    if show_feature:        plt.show()

3.结合resnet网络整体可视化(主要将其featuremap_visual函数插入forward中,即可),整体代码如下:

resnet网络结构在我博客:
残差网络ResNet(超详细代码解析) :你必须要知道backbone模块成员之一 - tangjunjun - 博客园

"""
@author: tangjun
@contact: 511026664@qq.com
@time: 2020/12/7 22:48
@desc: 残差ackbone改写,用于构建特征提取模块
"""

import torch.nn as nn
import torch
from collections import OrderedDict


def Conv(in_planes, out_planes, **kwargs):
    "3x3 convolution with padding"
    padding = kwargs.get('padding', 1)
    bias = kwargs.get('bias', False)
    stride = kwargs.get('stride', 1)
    kernel_size = kwargs.get('kernel_size', 3)
    out = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
    return out


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = Conv(inplanes, planes, stride=stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = Conv(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Resnet(nn.Module):
    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth=50,
                 in_channels=None,
                 pretrained=None,
                 frozen_stages=-1
                 # num_classes=None
                 ):
        super(Resnet, self).__init__()
        self.inplanes = 64
        self.inchannels = in_channels if in_channels is not None else 3  # 输入通道
        # self.num_classes=num_classes
        self.block, layers = self.arch_settings[depth]
        self.frozen_stages = frozen_stages
        self.conv1 = nn.Conv2d(self.inchannels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(self.block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(self.block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(self.block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(self.block, 512, layers[3], stride=2)

        # self.avgpool = nn.AvgPool2d(7)
        # self.fc = nn.Linear(512 * self.block.expansion, self.num_classes)
        self._freeze_stages()  # 冻结函数
        if pretrained is not None:
            self.init_weights(pretrained=pretrained)

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.norm1.eval()
            for m in [self.conv1, self.norm1]:
                for param in m.parameters():
                    param.requires_grad = False
        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, 'layer{}'.format(i))
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            self.load_checkpoint(pretrained)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out', nonlinearity='relu')
                    if hasattr(m, 'bias') and m.bias is not None:  # m包含该属性且m.bias非None # hasattr(对象,属性)表示对象是否包含该属性
                        nn.init.constant_(m.bias, 0)

                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()

    def load_checkpoint(self, pretrained):

        checkpoint = torch.load(pretrained)
        if isinstance(checkpoint, OrderedDict):
            state_dict = checkpoint
        elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']

        if list(state_dict.keys())[0].startswith('module.'):
            state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items()}

        unexpected_keys = []  # 保存checkpoint不在module中的key
        model_state = self.state_dict()  # 模型变量

        for name, param in state_dict.items():  # 循环遍历pretrained的权重
            if name not in model_state:
                unexpected_keys.append(name)
                continue
            if isinstance(param, torch.nn.Parameter):
                # backwards compatibility for serialized parameters
                param = param.data

            try:
                model_state[name].copy_(param)  # 试图赋值给模型
            except Exception:
                raise RuntimeError(
                    'While copying the parameter named {}, '
                    'whose dimensions in the model are {} not equal '
                    'whose dimensions in the checkpoint are {}.'.format(
                        name, model_state[name].size(), param.size()))
        missing_keys = set(model_state.keys()) - set(state_dict.keys())
        print('missing_keys:', missing_keys)

    def _make_layer(self, block, planes, num_blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, num_blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        outs = []
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        outs.append(x)
        featuremap_visual(x)

        x = self.layer2(x)
        outs.append(x)

        featuremap_visual(x)

        x = self.layer3(x)
        outs.append(x)
        featuremap_visual(x)

        x = self.layer4(x)
        outs.append(x)

        # x = self.avgpool(x)
        # x = x.view(x.size(0), -1)
        # x = self.fc(x)

        return tuple(outs)


def featuremap_visual(feature,
                      out_dir=None,  # 特征图保存路径文件
                      save_feature=True,  # 是否以图片形式保存特征图
                      show_feature=True,  # 是否使用plt显示特征图
                      feature_title=None,  # 特征图名字,默认以shape作为title
                      num_ch=-1,  # 显示特征图前几个通道,-1 or None 都显示
                      nrow=8,  # 每行显示多少个特征图通道
                      padding=10,  # 特征图之间间隔多少像素值
                      pad_value=1  # 特征图之间的间隔像素
                      ):
    import matplotlib.pylab as plt
    import torchvision
    import os
    # feature = feature.detach().cpu()
    b, c, h, w = feature.shape
    feature = feature[0]
    feature = feature.unsqueeze(1)

    if c > num_ch > 0:
        feature = feature[:num_ch]

    img = torchvision.utils.make_grid(feature, nrow=nrow, padding=padding, pad_value=pad_value)
    img = img.detach().cpu()
    img = img.numpy()
    images = img.transpose((1, 2, 0))

    # title = str(images.shape) if feature_title is None else str(feature_title)
    title = str('hwc-') + str(h) + '-' + str(w) + '-' + str(c) if feature_title is None else str(feature_title)

    plt.title(title)
    plt.imshow(images)
    if save_feature:
        # root=r'C:\Users\Administrator\Desktop\CODE_TJ\123'
        # plt.savefig(os.path.join(root,'1.jpg'))
        out_root = title + '.jpg' if out_dir == '' or out_dir is None else os.path.join(out_dir, title + '.jpg')
        plt.savefig(out_root)



    if show_feature:        plt.show()


import cv2
import numpy as np


def imnormalize(img,
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=True
                ):
    if to_rgb:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.astype(np.float32)
    return (img - mean) / std


if __name__ == '__main__':
    import matplotlib.pylab as plt



    img = cv2.imread('1.jpg')  # 读取图片

    img = imnormalize(img)
    img = torch.from_numpy(img)

    img = torch.unsqueeze(img, 0)
    img = img.permute(0, 3, 1, 2)
    img = torch.tensor(img, dtype=torch.float32)
    img = img.to('cuda:0')

    model = Resnet(depth=50)
    model.init_weights(pretrained='./resnet50.pth')  # 可以使用,也可以注释
    model = model.cuda()
    out = model(img)

运行结果

参考:

PyTorch模型训练特征图可视化 - tangjunjun - 博客园 (cnblogs.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1847123.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

算法基础精选题单 动态规划(dp)(区间dp)(个人题解)

目录 前言: 正文: 题单:【237题】算法基础精选题单_ACM竞赛_ACM/CSP/ICPC/CCPC/比赛经验/题解/资讯_牛客竞赛OJ_牛客网 (nowcoder.com) NC50493 石子合并: NC50500 凸多边形的划分: NC235246 田忌赛马&#xff1a…

数据库浅识及MySQL的二进制安装

数据库基础概念与MySQL二进制安装与初始化 使用数据库的必要性 数据库可以结构化储存大量数据信息,方便用户进行有效的检索访问 有效的保持数据信息的一致性,完整性,降低数据冗余 可以满足应用的共享和安全方面的要求 数据库基本概念 数据…

[dataworks]从mysql导入数据、将结果导入到mysql、处理写错表名问题、创建依赖任务

一、从mysql导入数据 在ods的数据集成下点新建-->离线同步 1、起名imp_t_ods_uc_cst_terminal_dtl_df 前缀imp是import的缩写 t代表trade即MySQL的交易库(trade)的简写 ods即导入到ods层 uc_cst_terminal_dt为MySQL对应的表名 df为日全量导入(di为日增量导…

真实还原汽车引擎声浪——WT2003Hx语音芯片方案

PART.01 产品市场 WT2003Hx是一款高性能的MP3音频解码芯片,具有成本效益、低功耗和高可靠性等特点,适用于多种场景,包括但不限于汽车娱乐系统、玩具、教育设备以及专业音响设备等。在模拟汽车引擎声的应用中,这一芯片的特性被特…

推荐一个十分好用的AI工具

推荐一个很好用的ai工具 链接在最下面 **介绍** ChatGPT 是由OpenAI开发的先进语言模型,旨在通过自然而流畅的对话方式与用户交互。无论是解决问题、提供建议,还是进行创意灵感的激发,ChatGPT都能为用户提供帮助。 **特点与优势** 1. **广泛…

JMeter的基本使用与性能测试,完整入门篇保姆式教程

Jmeter 的简介 JMeter是一个纯Java编写的开源软件,主要用于进行性能测试和功能测试。它支持测试的应用/服务/协议包括Web (HTTP, HTTPS)、SOAP/REST Webservices、FTP、Database via JDBC等。我们最常使用的是HTTP和HTTPS协议。 Jmeter主要组件 线程组&#xff08…

【C++进阶学习】第三弹——菱形继承和虚拟继承——菱形继承的二义性和数据冗余问题

继承(上):【C进阶学习】第一弹——继承(上)——探索代码复用的乐趣-CSDN博客 继承(下):【C进阶学习】第二弹——继承(下)——挖掘继承深处的奥秘-CSDN博客 …

雷池社区版自动SSL

正常安装雷池,并配置站点,暂时不配置ssl 不使用雷池自带的证书申请。 安装(acme.sh),使用域名验证方式生成证书 先安装git yum install git 或者 apt-get install git 安装完成后使用 git clone https://gitee.com/n…

(项目实战)RocketMQ5.0延迟消息在聚合支付系统中的应用

1 基于业务场景掌握RocketMQ5.0 本篇文章主要结合聚合支付系统中的业务场景来落地RocketMQ中间件的应用,聚合支付系统主要在支付系统超时订单和商户支付结果异步通知场景中会使用到RocketMQ消息中间件。本文使用到了RocketMQ中的延迟消息知识点,RocketM…

SD-WAN为什么适合小企业

SD-WAN(软件定义广域网)是一种革新性的网络技术,通过软件智能管理,实现灵活和高效的网络连接。在数字化转型浪潮中,企业对网络稳定性和性能的要求不断提升,SD-WAN因此受到了广泛关注。对于资源有限的小型企…

laravel中如何向字段标签添加工具提示

首先,您可以使用 轻松自定义字段标签->label()。我相信您知道这一点。但您知道吗……标签输出未转义?这意味着您也可以在标签中包含 HTML。 为了尽快实现上述目标,我只是采取了一个快速而粗糙的解决方案: CRUD::field(nickna…

扭转引伸计技术资料YYJ-10 6-N

一、 工作原理 利用专门设计的扭转引伸计夹持系统,可靠地装夹在试样上,采用应变片夹式引伸计进行机械量与电信号的转换,使之完成扭转应变的自动测试。 二、技术指标 1、扭转引伸计的标距:该装置分别配置50mm、100mm标距联接延伸横…

一键制作,打造高质量的数字刊物

随着数字化时代的到来,数字刊物已经成为信息传播的重要载体。它以便捷、环保、互动性强等特点,受到了越来越多人的青睐。然而,如何快速、高效地制作出高质量的数字刊物,成为许多创作者面临的难题。今天,教大家一个制作…

浅析MySQL-基础02

目录 MySQL一行记录是怎么存储的? MySQL的数据存放在哪? 表空间文件的结构是怎么样的? InnoDB行格式有哪些? Compact行格式是啥样的? 记录的额外信息 1、变长字段长度列表 2、NULL值列表 3、记录头信息 记录…

LeetCode题练习与总结:克隆图--133

一、题目描述 给你无向 连通 图中一个节点的引用,请你返回该图的 深拷贝(克隆)。 图中的每个节点都包含它的值 val(int) 和其邻居的列表(list[Node])。 class Node {public int val;public L…

【EndNote】EndNote进行文献管理可能遇到的问题和解决方案

一、安装GB/T7714-2015(numberic)文献style windows:https://blog.csdn.net/qq_36235935/article/details/115629694 mac os:Mac版Endnote 20导入中文参考格式Chinese Std GBT7714 (numeric)-CSDN博客 安装完之后需要调整Author Name格式:…

Linux内核学习——linux内核体系结构(1)

1 Linux内核模式 学习的是Linux 0.11内核,采用的是单内核模式。单内核模式的主要优点是内核代码结构紧凑、执行速度快,但是层次结构性不强。 操作系统如何提供的服务流程? 应用主程序使用指定的参数值执行系统调用指令(int x80)&#xff0…

用进程和线程完成TCP进行通信操作及广播和组播的通信

进程 代码 #include <stdio.h>#include <sys/types.h>#include <sys/socket.h>#include <netinet/in.h>#include <arpa/inet.h>#include <string.h>#include <unistd.h>#include <stdlib.h>#include <signal.h>#includ…

如何使用idea连接Oracle数据库?

idea版本&#xff1a;2021.3.3 Oracle版本&#xff1a;10.2.0.1.0&#xff08;在虚拟机Windows sever 2003 远程连接数据库&#xff09; 数据库管理系统&#xff1a;PLSQL Developer 在idea里面找到database&#xff0c;在idea侧面 选择左上角加号&#xff0c;新建&#xff…

消息队列kafka中间件详解:案例解析(第10天)

系列文章目录 1- 消息队列&#xff08;熟悉&#xff09;2- Kafka的基本介绍&#xff08;掌握架构&#xff0c;其他了解&#xff09;3- Kafka的相关使用&#xff08;掌握kafka常用shell命令&#xff09;4- Kafka的Python API的操作&#xff08;熟悉&#xff09; 文章目录 系列文…