PyTorch深度学习实战(14)——类激活图

news2024/12/24 3:35:11

PyTorch深度学习实战(14)——类激活图

    • 0. 前言
    • 1. 类激活图
      • 1.1 基本概念
      • 1.2 类激活图生成
    • 2 数据集分析
    • 3 使用 PyTorch 生成 CAM
    • 小结
    • 系列链接

0. 前言

我们已经能够构建性能优异的神经网络模型,但对我们而言,卷积神经网络的决策过程难以解释和理解。类激活图 (Class Activation Map, CAM) 是一种可视化技术,用于解释深度学习模型在图像分类任务中的决策过程。它能够显示出对于分类结果影响最显著的图像区域,从而提供对模型决策的可解释性。通过观察类激活图,可以理解模型在分类决策中注重的区域和特征,这有助于我们分析和解释模型的决策依据,以及验证模型是否关注了正确的特征。本节中,将介绍类激活图的基本概念,并使用训练好的模型生成图像的类激活图。

1. 类激活图

1.1 基本概念

类激活图 (Class Activation Map, CAM) 是一种用于可视化卷积神经网络 (Convolutional Neural Networks, CNN) 中每个类别的局部重要性的技术,使用 CAM 可以帮助我们理解 CNN 的决策过程,以及哪些特征对于某个类别的分类最为重要。一个示例 CAM 如下所示,其中左侧是输入图像,右侧突出显示了用于类别预测的像素:

类激活图
根据以上激活图可以看到高激活区域集中在最有助于模型作出类别预测的部位。接下来,我们将继续介绍如何在训练好模型后生成 CAM

1.2 类激活图生成

特征图是卷积操作后的中间激活,通常特征图的形状为 batch size x height x width,其中 batch size 表示批大小,height 表示特征图高度,width 表示特征图宽度。如果我们计算这些激活的平均值,它们就显示了图像中所有类别的热点区域。但是,如果我们只对某个特定类别(比如猫)重要的位置感兴趣,那么我们需要找出 n 个通道中只负责该类别的特征映射。对于生成这些特征映射的卷积层,我们可以计算其相对于猫类的梯度。请注意,只有负责预测猫的通道才会具有较高的梯度。这意味着我们可以使用梯度信息赋予 n 个通道中的每一个权重,并且得到一个专门用于猫的激活映射。
具体而言,我们可以使用以下过程生成 CAM,过程参考自 Grad-CAM: Gradient-weighted Class Activation Mapping:

  1. 确定要为哪个类别计算 CAM,以及要为神经网络中的哪个卷积层计算 CAM
  2. 计算卷积层产生的激活——假设卷积层的特征形状为 512 x 7 x 7
  3. 获取从该层产生的关于兴趣类别的梯度值,输出梯度形状为 256 x 512 x 3 x 3 (卷积张量的形状——即,batch size x height x width x kernel size,其中 kernel size 表示核尺寸)
  4. 计算每个输出通道内梯度的平均值,输出形状为 512
  5. 计算加权激活图——512 个梯度均值乘以 512 个激活通道,输出形状为 512 x 7 x 7
  6. 计算加权激活图的平均值(跨 512 个通道),获取形状为 7 x 7 的输出
  7. 调整(放大)加权激活图输出的大小,以得到与输入大小相同的图像,目的是得到与原始图像尺寸相同的激活图
  8. 将加权激活图叠加到输入图像上

整个过程的关键在于步骤 5,考虑以下两个方面:

  • 如果某个像素很重要,那么卷积神经网络 (Convolutional Neural Networks, CNN) 将在这些像素处得到较大的激活
  • 如果某个卷积通道对于兴趣类别很重要,则该通道的梯度将非常大

将这两者相乘后,将得到所有像素的重要性映射,重要性映射 (map of importance) 是指将神经网络中某个像素或特征的重要性表示为一个热力图或概率分布图,通常用于可视化神经网络输出中哪些部分对识别特定类别最为重要。

2 数据集分析

Malaria Cell Images Dataset 是一个常用的数据集,用于训练和评估计算机视觉模型在疟疾细胞图像分类任务上的性能。该数据集包含了被感染和未被感染的红血细胞图像,通常用于研究和开发自动检测和分类疟疾细胞的算法。通过使用这个数据集,可以构建深度学习模型来自动识别感染疟疾的红血细胞,从而帮助医生进行准确的诊断,在 Kaggle 官方网站下载 Malaria Cell Images Dataset

3 使用 PyTorch 生成 CAM

接下来,我们使用 PyTorch 实现 CAM 生成策略,以了解 CNN 模型能够预测图像可能出现疟疾事件的原因。

(1) 下载数据集,并导入相关库:

import os
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
from glob import glob
from random import randint
import cv2
from pathlib import Path
import torch.nn as nn
from torch import optim
from matplotlib import pyplot as plt

device = 'cuda' if torch.cuda.is_available() else 'cpu'

(2) 指定与输出类别对应的索引:

id2int = {'Parasitized': 0, 'Uninfected': 1}

(3) 执行图像转换操作:

from torchvision import transforms as T

trn_tfms = T.Compose([
    T.ToPILImage(),
    T.Resize(128),
    T.CenterCrop(128),
    T.ColorJitter(brightness=(0.95,1.05), 
                  contrast=(0.95,1.05), 
                  saturation=(0.95,1.05), 
                  hue=0.05),
    T.RandomAffine(5, translate=(0.01,0.1)),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], 
                std=[0.5, 0.5, 0.5]),
])

在以上代码中,对输入图像进行了一系列转换——首先将图像尺寸调整为 128 (最小边为 128),然后从图像中心进行裁剪。此外,我们还进行了随机颜色抖动和仿射变换,并使用 .ToTensor() 方法对图像进行缩放(使像素值位于0到1之间),最后对图像进行归一化处理。

对验证集图像执行转换:

val_tfms = T.Compose([
    T.ToPILImage(),
    T.Resize(128),
    T.CenterCrop(128),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], 
                std=[0.5, 0.5, 0.5]),
])

(4) 定义数据集类 MalariaImages

class MalariaImages(Dataset):
    def __init__(self, files, transform=None):
        self.files = files
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, ix):
        fpath = self.files[ix]
        clss = os.path.basename(Path(fpath).parent)
        img = cv2.imread(fpath)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return img, clss

    def choose(self):
        return self[randint(len(self))]

    def collate_fn(self, batch):
        _imgs, classes = list(zip(*batch))
        if self.transform:
            imgs = [self.transform(img)[None] for img in _imgs]
        classes = [torch.tensor([id2int[clss]]) for clss in classes]
        imgs, classes = [torch.cat(i).to(device) for i in [imgs, classes]]
        return imgs, classes, _imgs

(5) 获取训练、验证数据集和数据加载器:

all_files = glob('cell_images/*/*.png')
np.random.shuffle(all_files)

from sklearn.model_selection import train_test_split
trn_files, val_files = train_test_split(all_files, random_state=1)

trn_ds = MalariaImages(trn_files, transform=trn_tfms)
val_ds = MalariaImages(val_files, transform=val_tfms)
trn_dl = DataLoader(trn_ds, 32, shuffle=True, collate_fn=trn_ds.collate_fn)
val_dl = DataLoader(val_ds, 32, shuffle=False, collate_fn=val_ds.collate_fn)

(6) 定义模型 MalariaClassifier

def convBlock(ni, no):
    return nn.Sequential(
        nn.Dropout(0.2),
        nn.Conv2d(ni, no, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.BatchNorm2d(no),
        nn.MaxPool2d(2),
    )
    
class MalariaClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            convBlock(3, 64),
            convBlock(64, 64),
            convBlock(64, 128),
            convBlock(128, 256),
            convBlock(256, 512),
            convBlock(512, 64),
            nn.Flatten(),
            nn.Linear(256, 256),
            nn.Dropout(0.2),
            nn.ReLU(inplace=True),
            nn.Linear(256, len(id2int))
        )
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def compute_metrics(self, preds, targets):
        loss = self.loss_fn(preds, targets)
        acc = (torch.max(preds, 1)[1] == targets).float().mean()
        return loss, acc

(7) 定义使用批数据对模型进行训练和验证的函数:

def train_batch(model, data, optimizer, criterion):
    model.train()
    ims, labels, _ = data
    _preds = model(ims)
    optimizer.zero_grad()
    loss, acc = criterion(_preds, labels)
    loss.backward()
    optimizer.step()
    return loss.item(), acc.item()

@torch.no_grad()
def validate_batch(model, data, criterion):
    model.eval()
    ims, labels, _ = data
    _preds = model(ims)
    loss, acc = criterion(_preds, labels)
    return loss.item(), acc.item()

(8) 训练模型:

model = MalariaClassifier().to(device)
criterion = model.compute_metrics
optimizer = optim.Adam(model.parameters(), lr=1e-3)
n_epochs = 5

for ex in range(n_epochs):
    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    N = len(trn_dl)
    for bx, data in enumerate(trn_dl):
        loss, acc = train_batch(model, data, optimizer, criterion)
        train_loss.append(loss)
        train_acc.append(acc)
    N = len(val_dl)
    for bx, data in enumerate(val_dl):
        loss, acc = validate_batch(model, data, criterion)
        val_loss.append(loss)
        val_acc.append(acc)
    avg_train_loss = np.average(train_loss)
    avg_train_acc = np.average(train_acc)
    avg_val_loss = np.average(val_loss)
    avg_val_acc = np.average(val_acc)
    print(f"EPOCH: {ex}	trn_loss: {avg_train_loss}	trn_acc: {avg_train_acc}	val_loss: {avg_val_loss}	val_acc: {avg_val_acc}")

(9) 获取模型中第五个 convBlock 中的卷积层:

im2fmap = nn.Sequential(*(list(model.model[:5].children()) + list(model.model[5][:2].children())))

在以上代码中,获取模型的第四层以及 convBlock 中的前两层(均为 Conv2D 层)。

(10) 定义 im2gradCAM 函数,该函数接受输入图像并获取与图像激活对应的热力图:

def im2gradCAM(x):
    model.eval()
    logits = model(x)
    heatmaps = []
    activations = im2fmap(x)
    print(activations.shape)
    pred = logits.max(-1)[-1]
    # 获取模型预测
    model.zero_grad()
    # 计算相对于模型置信度最高的 logits 的梯度
    logits[0,pred].backward(retain_graph=True)
    # 获取所需特征图位置的梯度,并对每个特征图取平均梯度
    pooled_grads = model.model[-6][1].weight.grad.data.mean((1,2,3))
    # 将每个激活图与对应的梯度平均值相乘
    for i in range(activations.shape[1]):
        activations[:,i,:,:] *= pooled_grads[i]
    # 计算所有加权激活图的平均值
    heatmap = torch.mean(activations, dim=1)[0].cpu().detach()
    return heatmap, 'Uninfected' if pred.item() else 'Parasitized'

(11) 定义 upsampleHeatmap 函数将热图上采样为与图像形状对应的形状:

SZ = 120
def upsampleHeatmap(map, img):
    m,M = map.min(), map.max()
    map = 255 * ((map-m) / (M-m))
    map = np.uint8(map)
    map = cv2.resize(map, (SZ,SZ))
    map = cv2.applyColorMap(255-map, cv2.COLORMAP_JET)
    map = np.uint8(map)
    map = np.uint8(map*0.7 + img*0.3)
    return map

在前面的代码行中,我们对图像进行了反归一化,并将热图覆盖在图像之上。

(12) 使用一组测试图像调用上述函数:

N = 20
_val_dl = DataLoader(val_ds, batch_size=N, shuffle=True, collate_fn=val_ds.collate_fn)
x,y,z = next(iter(_val_dl))

for i in range(N):
    image = cv2.resize(z[i], (SZ, SZ))
    heatmap, pred = im2gradCAM(x[i:i+1])
    if(pred=='Uninfected'):
        continue
    heatmap = upsampleHeatmap(heatmap, image)
    plt.figure(figsize=(5,3))
    plt.subplot(121)
    plt.imshow(image)
    plt.subplot(122)
    plt.imshow(heatmap)
    plt.suptitle(pred)
    plt.show()

测试图像热力图
从图中可以看出,预测结果是由红色高亮区域的内容决定的(这部分区域具有最高的 CAM 值)。学习了如何使用训练好的模型生成图像的类激活热力图,我们就可以解释是什么原因导致了模型产生某个分类结果。

小结

类激活图的生成方法主要基于卷积神经网络模型的最后一个卷积层和全局平均池化层。首先,通过前向传播将图像输入到卷积神经网络中,然后获取最后一个卷积层的特征图。接着,利用全局平均池化对每个特征通道的权重进行计算,得到每个通道对分类结果的重要性。之后,将原始图像与特征图的权重相乘,并进行叠加,最终得到一个类激活图。类激活图中像素值较高的区域表示对于分类结果贡献较大的图像区域,而像素值较低的区域则表示贡献较小或无关的区域。

系列链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/982522.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

嵌入式Linux驱动开发(LCD屏幕专题)(一)

一、LCD简介 总的分辨率是 yres*xres。 1.1、像素颜色的表示 以下三种方式表示颜色 1.2、如何将颜色数据发送给屏幕 每个屏幕都有一个内存(framebuffer)如下图,内存中每块数据对用屏幕上的一个像素点,设置好LCD后&#xff…

西门子PLC与三菱PLC之间能否实现无线MODUBS通讯

对一个大型工厂,由于生产线的不断改造、新老流程的不断更新,这些PLC系统往往是由不同的制造商提供的。那么在智慧工厂的实现中,常会遇到不同品牌PLC之间需要进行相互通讯的情况。由于场地和生产能效的原因,在后期的系统改造中&…

提升效率:PostgreSQL准确且快速的数据对比方法

作为一款强大而广受欢迎的开源关系型数据库管理系统,PostgreSQL 在数据库领域拥有显著的市场份额。其出色的可扩展性、稳定性使其成为众多企业和项目的首选数据库。而在很多场景下(开发|生产环境同步、备份恢复验证、数据迁移、数据合并等)&a…

基于Python开发的学生信息管理系统控制台程序(源码+可执行程序exe文件+程序配置说明书+程序使用说明书)

一、项目简介 本项目是一套基于Python开发的学生信息管理系统控制台程序,主要针对计算机相关专业的正在做毕设的学生与需要项目实战练习的Python学习者。 包含:项目源码、项目文档、数据库脚本等,该项目附带全部源码可作为毕设使用。 项目都…

《2023人工智能发展白皮书》发布(118页)

导读 本白皮书由七大部分组成。第一章人工智能产业链分析,描绘人工智能产业链全景图,并对产业链各环节进行深入分析;第二章人工智能行业环境,明确中国人工智能行业生命周期和竞争结构;第三章人工智能发展概况,阐述国内外人工智能…

4.3.3 【MySQL】Redundant行格式

现在我们把表demo 的行格式修改为 Redundant : 为了方便大家理解和节省篇幅,我们直接把表 demo 在Redundant 行格式下的两条记录的真实存储数据提供出来,之后我们着重分析两种行格式的不同即可。 下边我们从各个方面看一下 Redundant 行格式有…

十二、集合(4)

本章概要 集合 Set映射 Map队列 Queue 优先级队列 PriorityQueue 集合与迭代器 集合Set Set 不保存重复的元素。 如果试图将相同对象的多个实例添加到 Set 中,那么它会阻止这种重复行为。 Set 最常见的用途是测试归属性,可以很轻松地询问某个对象是否…

SQL Server 2012下载和安装配置详细教程手册

SQL Server 2012 下载和安装详细教程 目录 SQL Server 2012 下载和安装详细教程1、软件下载2、软件安装3、软件验证 1、软件下载 (1)官网地址 https://www.microsoft.com/zh-cn/sql-server/sql-server-downloads (可能不太行) &a…

简明SQL别名指南:掌握AS实现列名更名

在 SQL 查询中,使用 {原始字段名} as {别名} 的语法来为查询结果的列赋予更直观的名称,以提高查询结果的可读性和可理解性。 以下是用到的表。 用AS更名 例如,查询表1的name字段,并将其更名为"名字",同时查…

多语言开发(vant

参考:https://blog.csdn.net/qq_44649801/article/details/131878128?spm1001.2014.3001.5506 一、抛出字段对象A export default { } 二、引入汇总文件,(主要的是 模块分割 汇总,对A 等的处理 export default { A,B,…

设计模式之桥接模式、组合模式与享元模式

目录 桥接模式 简介 优缺点 结构 实现 运用场景 组合模式 简介 优缺点 结构 实现 运用场景 享元模式 简介 优缺点 结构 实现 运用场景 桥接模式 简介 将抽象与实现分离,使它们可以独立变化。它是用组合关系代替继承关系来实现,从而降…

【rar密码】如何修改rar压缩包密码?

rar压缩包设置了密码,想要修改密码,有方法吗?目前看来,还是没有,我们只输入密码将压缩包中的文件解压出来。 然后将解压出来的文件再重新压缩,重新设置一个密码进行压缩即可达到修改密码的目的了 想要修改密…

leetcode 925. 长按键入

2023.9.7 我的基本思路是两数组字符逐一对比,遇到不同的字符,判断一下typed与上一字符是否相同,不相同返回false,相同则继续对比。 最后要分别判断name和typed分别先遍历完时的情况。直接看代码: class Solution { p…

B094-人力资源项目-微服务授权Oauth2

目录 背景OAUTH2总体流程Oauth2授权码模式oauth2的三方授权流程图案例演示代码讲解-整合oauth2springsecurityjwt先创建一个用于认证的服务增加AuthorizationServerConfig 模块总结认证服务整合zuul获取令牌加入网关后的变化代码详解测试 背景 微服务架构下应用散步在不同的服…

DC电源模块对电磁干扰的影响

BOSHIDA DC电源模块对电磁干扰的影响 DC电源模块是一种常用的电源转换设备,可以将交流电转换成直流电,并通过电路电子元件对电压、电流等参数进行调整,以满足外部设备对电源的需求。然而,由于DC电源模块自身的工作特性&#xff…

Deep Java Library(四)使用DJL Serving部署JAVA模型 For Windows

1.下载Windows版DJL Serving Windows版DJL Serving下载地址: https://publish.djl.ai/djl-serving/serving-0.23.0.zip 下载下来是一个zip压缩包,大约50M左右,目前最新版本为0.23.0 2.安装DJL Serving 解压serving-0.23.0.zip后目录如下 …

网络技术七:命令行基础

命令行操作基础 命令类型 常见设备管理命令 H3C路由交换产品连接方法 使用console线本地连接 协议Serial,接口com口,波特率9600 适用于设备的初次调试 使用Telnet远程访问 适用于设备上架配置好后的维护管理 使用SSH远程访问 数据传输过程加密&…

设计模式系列-外观模式

一、上篇回顾 上篇我们主要讲述了创建型模式中的最后一个模式-原型模式,我们主要讲述了原型模式的几类实现方案,和原型模式的应用的场景和特点,原型模式 适合在哪些场景下使用呢?我们先来回顾一下我们上篇讲述的3个常用的场景。 1…

聚观早报|全球AI领域投资总额暴降;哈啰租车开启假期早鸟预定

【聚观365】9月7日消息 全球AI领域投资总额暴降 哈啰租车开启假期早鸟预定 微信上线“腾讯混元助手”小程序 腾讯“QQ群恢复”功能下线 一嗨租车1.1万个直营网点接入滴滴APP 全球AI领域投资总额暴降 媒体报道称,科技巨头Meta(原Facebook&#xff…

TCP/IP基础

前言: TCP/IP协议是计算机网络领域中最基本的协议之一,它被广泛应用于互联网和局域网中,实现了不同类型、不同厂家、运行不同操作系统的计算机之间的相互通信。本文将介绍TCP/IP协议栈的层次结构、各层功能以及数据封装过程,帮助您…