第56步 深度学习图像识别:CNN梯度权重类激活映射(TensorFlow)

news2024/11/29 10:50:43

基于WIN10的64位系统演示

一、写在前面

类激活映射(Class Activation Mapping,CAM)和梯度权重类激活映射(Gradient-weighted Class Activation Mapping,Grad-CAM)是两种可视化深度学习模型决策过程的技术。他们都是为了理解模型的决策过程,特别是对于图像分类任务,它们可以生成一种热力图,这种图可以突出显示模型在做出预测时关注的图像区域。

CAM:CAM是一种可视化卷积神经网络(Convolutional Neural Networks, CNN)决策依据的技术。对于图像分类任务,它可以生成一种热力图,突出显示模型在做出预测时关注的图像区域。CAM需要模型在全局平均池化(Global Average Pooling, GAP)层和最终的全连接层(Fully Connected, FC)之间没有其他隐藏层,这是其使用的限制。

Grad-CAM:Grad-CAM是为了克服CAM的限制而提出的一种方法,它使用的是类别得分关于特定层输出的梯度信息。这种方法不仅可以应用于卷积层,还可以应用于任何层的输出。因此,Grad-CAM可以用于多种类型的深度学习模型,包括图像分类、图像生成、强化学习等各种模型。这使得Grad-CAM在可视化模型决策过程方面更加灵活和强大。

这一期主要介绍Grad-CAM,用的模型是Mobilenet_v2,以为够快!!

二、Grad-CAM可视化实战

继续使用胸片的数据集:肺结核病人和健康人的胸片的识别。其中,肺结核病人700张,健康人900张,分别存入单独的文件夹中。

(a)Mobilenet_v2建模

######################################导入包###################################
from tensorflow import keras
import tensorflow as tf
from tensorflow.python.keras.layers import Dense, Flatten, Conv2D, MaxPool2D, Dropout, Activation, Reshape, Softmax, GlobalAveragePooling2D, BatchNormalization
from tensorflow.python.keras.layers.convolutional import Convolution2D, MaxPooling2D
from tensorflow.python.keras import Sequential
from tensorflow.python.keras import Model
from tensorflow.python.keras.optimizers import adam_v2
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator, image_dataset_from_directory
from tensorflow.python.keras.layers.preprocessing.image_preprocessing import RandomFlip, RandomRotation, RandomContrast, RandomZoom, RandomTranslation
import os,PIL,pathlib
import warnings
#设置GPU
gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0],"GPU")
    
warnings.filterwarnings("ignore")             #忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False    # 用来正常显示负号

################################导入数据集#####################################
#1.导入数据
data_dir = "./MTB"
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:",image_count)

batch_size = 32
img_height = 100
img_width  = 100

train_ds = image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)

val_ds = image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)

class_names = train_ds.class_names
print(class_names)
print(train_ds)


#2.检查数据
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

#3.配置数据
AUTOTUNE = tf.data.AUTOTUNE

def train_preprocessing(image,label):
    return (image/255.0,label)

train_ds = (
    train_ds.cache()
    .shuffle(800)
    .map(train_preprocessing)    
    .prefetch(buffer_size=AUTOTUNE)
)

val_ds = (
    val_ds.cache()
    .map(train_preprocessing) 
    .prefetch(buffer_size=AUTOTUNE)
)

#4. 数据可视化
plt.figure(figsize=(10, 8))  # 图形的宽为10高为5
plt.suptitle("数据展示")

class_names = ["Tuberculosis","Normal"]

for images, labels in train_ds.take(1):
    for i in range(15):
        plt.subplot(4, 5, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)

        # 显示图片
        plt.imshow(images[i])
        # 显示标签
        plt.xlabel(class_names[labels[i]-1])

plt.show()

######################################数据增强函数################################

data_augmentation = Sequential([
  RandomFlip("horizontal_and_vertical"),
  RandomRotation(0.2),
  RandomContrast(1.0),
  RandomZoom(0.5,0.2),
  RandomTranslation(0.3,0.5),
])

def prepare(ds):
    ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y), num_parallel_calls=AUTOTUNE)
    return ds
train_ds = prepare(train_ds)

################################导入mobilenet_v2################################
#获取预训练模型对输入的预处理方法
from tensorflow.python.keras.applications import mobilenet_v2
from tensorflow.python.keras import Input, regularizers
IMG_SIZE = (img_height, img_width, 3)

# 创建输入张量
inputs = Input(shape=IMG_SIZE)
# 定义基础模型,并将 inputs 传入
base_model = mobilenet_v2.MobileNetV2(input_tensor=inputs,
                                      include_top=False, 
                                      weights='imagenet')

#从基础模型中获取输出
x = base_model.output
#全局池化
x = GlobalAveragePooling2D()(x)
#BatchNormalization
x = BatchNormalization()(x)
#Dropout
x = Dropout(0.8)(x)
#Dense
x = Dense(128, kernel_regularizer=regularizers.l2(0.1))(x)  # 全连接层减少到128,添加 L2 正则化
#BatchNormalization
x = BatchNormalization()(x)
#激活函数
x = Activation('relu')(x)
#输出层
outputs = Dense(2, kernel_regularizer=regularizers.l2(0.1))(x)  # 添加 L2 正则化
#BatchNormalization
outputs = BatchNormalization()(outputs)
#激活函数
outputs = Activation('sigmoid')(outputs)
#整体封装
model = Model(inputs, outputs)
#打印模型结构
print(model.summary())

#############################编译模型#########################################
#定义优化器
from tensorflow.python.keras.optimizers import adam_v2, rmsprop_v2
optimizer = adam_v2.Adam()


#编译模型
model.compile(optimizer=optimizer,
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])

#训练模型
from tensorflow.python.keras.callbacks import ModelCheckpoint, Callback, EarlyStopping, ReduceLROnPlateau, LearningRateScheduler

NO_EPOCHS = 50
PATIENCE  = 10
VERBOSE   = 1

# 设置动态学习率
annealer = LearningRateScheduler(lambda x: 1e-5 * 0.99 ** (x+NO_EPOCHS))

# 设置早停
earlystopper = EarlyStopping(monitor='loss', patience=PATIENCE, verbose=VERBOSE)

# 
checkpointer = ModelCheckpoint('mtb_jet_best_model_mobilenetv3samll.h5',
                                monitor='val_accuracy',
                                verbose=VERBOSE,
                                save_best_only=True,
                                save_weights_only=True)

train_model  = model.fit(train_ds,
                  epochs=NO_EPOCHS,
                  verbose=1,
                  validation_data=val_ds,
                  callbacks=[earlystopper, checkpointer, annealer])

#保存模型
model.save('mtb_jet_best_model_mobilenet.h5')
print("The trained model has been saved.")

(b)Grad-CAM

import numpy as np
from PIL import Image, ImageOps
from tensorflow.python.keras.preprocessing import image
from tensorflow.python.keras.applications.mobilenet_v2 import preprocess_input
from tensorflow.python.keras.models import load_model
import tensorflow as tf
from tensorflow.python.keras import Model
import matplotlib.pyplot as plt

# 你的模型路径
model_path = 'mtb_jet_best_model_mobilenet.h5'

# 你的图像路径
image_path = './MTB/Tuberculosis/Tuberculosis-666.png'

# 加载你的模型
model = load_model(model_path)

def grad_cam(img_path, cls, model, layer_name='block_7_project'):
    # 加载图像并预处理
    img = image.load_img(img_path, target_size=(100, 100))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)

    # 获取预测类别
    preds = model.predict(x)
    pred_class = np.argmax(preds[0])

    # 使用 GradientTape 计算 Grad-CAM
    with tf.GradientTape() as tape:
        last_conv_layer = model.get_layer(layer_name)
        iterate = Model([model.inputs], [model.output, last_conv_layer.output])
        model_out, last_conv_layer = iterate(x)
        class_out = model_out[:, pred_class]

    # 得到的梯度
    grads = tape.gradient(class_out, last_conv_layer)
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    # 我们把梯度在每个特征图上进行平均
    heatmap = tf.reduce_mean(tf.multiply(pooled_grads, last_conv_layer), axis=-1)
    
    # 调整 heatmap 的形状和数值范围
    heatmap = tf.squeeze(heatmap)  # 去掉尺寸为1的维度
    heatmap = np.maximum(heatmap, 0)  # 去掉小于0的值
    max_heat = np.max(heatmap)
    if max_heat == 0:
        max_heat = 1e-10  # 防止除以0
    heatmap /= max_heat  # 归一化到0-1之间
    heatmap = np.uint8(255 * heatmap)  # 转换到0-255之间并转为uint8类型

    # 加载原始图像
    img = Image.open(img_path)

    # 将热力图转换为 PIL 图像并调整其尺寸
    heatmap = Image.fromarray(heatmap)
    heatmap = heatmap.resize((img.height, img.width))

    # 将单通道热力图转换为彩色(RGB)图像
    heatmap = ImageOps.colorize(heatmap, 'blue', 'red')

    # 将彩色热力图转换为带透明度的(RGBA)图像
    heatmap = heatmap.convert('RGBA')
    heatmap_with_alpha = Image.new('RGBA', heatmap.size)
    for x in range(heatmap.width):
        for y in range(heatmap.height):
            r, g, b, a = heatmap.getpixel((x, y))
            heatmap_with_alpha.putpixel((x, y), (r, g, b, int(a * 0.5)))

    # 将原始图像转换为 RGBA 图像
    img = img.convert('RGBA')

    # 叠加图像
    overlay = Image.alpha_composite(img, heatmap_with_alpha)

    # 将叠加后的图像转换为numpy数组
    overlay = np.array(overlay)

    # 使用matplotlib显示图像
    plt.imshow(overlay)
    plt.axis('off')  # 不显示坐标轴
    plt.show()
    
    print(pred_class)

# 绘制热力图
grad_cam(image_path, 0, model)

这个代码需要调整的参数就只有“layer_name”,也就是使用哪一层的信息来可视化。当然,首先我们得先知道每一层的名称:

#查看 Keras 模型每一层的名称
for layer in model.layers:
    print(layer.name)

输出如下:

然后,用哪一层呢?

其实吧,选择哪一层用于Grad-CAM的计算并没有一条明确的规则,这完全取决于你的模型结构以及你的具体需求。

一般来说,Convolutional Neural Networks(CNN,卷积神经网络)的前面几层往往捕捉到的是图像的低级特征,比如边缘、色彩和纹理等,而后面的层则可以捕捉到更为高级的特征,比如物体的部分或者整体。所以,如果你想要看到模型在判断图像时,主要关注了图像中的哪些部分或者物体,你可能需要选择离输出层更近一些的卷积层。

但是这也不是绝对的。在实际应用中,你可能需要尝试不同的层,看看哪一层生成的Grad-CAM热力图最能满足你的需求。

比如我试了试:'block_1_project':

 'block_7_project':

 'block_10_project':

 'block_2_add':

 综上,似乎一切随缘,太抽象了!!!

三、写在最后

略~

四、数据

链接:https://pan.baidu.com/s/15vSVhz1rQBtqNkNp2GQyVw?pwd=x3jf

提取码:x3jf

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/843419.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一文教你看懂Golang协程调度【GMP设计思想】

一文教你看懂Golang协程调度【GMP设计思想】 1 Golang调度器的由来 1.1 单进程的问题:进程阻塞、CPU浪费时间 单一执行程序、计算机只能一个任务一个任务来进行处理进程阻塞所带来的CPU浪费时间 1.2 多进程、多线程问题:设计复杂、高内存、CPU占用 设计…

面试热题(倒数第k个结点)

输入一个链表,输出该链表中倒数第k个节点。为了符合大多数人的习惯,本题从1开始计数,即链表的尾节点是倒数第1个节点。 例如,一个链表有 6 个节点,从头节点开始,它们的值依次是 1、2、3、4、5、6。这个链表…

通过cpolar内网穿透发布网页测试

通过内网穿透发布网页测试 文章目录 通过内网穿透发布网页测试 对于网站开发者来说,对完成的网页进行测试十分必要,同时还要在测试过程中充分采纳委托制作方的意见,及时根据甲方意见进行修改,但在传统的测试方式中,必须…

Scrum是什么意思,Scrum敏捷项目管理工具有哪些?

一、什么是Scrum? Scrum是一种敏捷项目管理方法,旨在帮助团队高效地开展软件开发和项目管理工作。 Scrum强调迭代和增量开发,通过将项目分解为多个短期的开发周期(称为Sprint),团队可以更好地应对需求变…

【CSS3】CSS3 2D 转换 - scale 缩放 ③ ( 使用 scale 设置制作可缩放的按钮案例 )

文章目录 一、需求分析二、代码分析三、代码示例四、执行结果 一、需求分析 设置一个 按钮 , 默认状态下显示的样式如下 : 按钮 外部 有 圆形的外边框 ;按钮 中的文本 , 水平居中对齐 , 垂直居中对齐 ; 当鼠标移动到 按钮 上之后 , 鼠标 变为 小手 样式 , 并且 按钮 以 中心位…

实战项目——多功能电子时钟

一,项目要求 二,理论原理 通过按键来控制状态机的状态,在将状态值传送到各个模块进行驱动,在空闲状态下,数码管显示基础时钟,基础时钟是由7个计数器组合而成,当在ADJUST状态下可以调整时间&…

五、PC远程控制ESP32 LED灯

1. 整体思路 2. 代码 # 整体流程 # 1. 链接wifi # 2. 启动网络功能(UDP) # 3. 接收网络数据 # 4. 处理接收的数据import socket import time import network import machinedef do_connect():wlan = network.WLAN(network.STA_IF)wlan.active(True)if not wlan.isconnected(…

LVS集群

目录 1、lvs简介: 2、lvs架构图: 3、 lvs的工作模式: 1) VS/NAT: 即(Virtual Server via Network Address Translation) 2)VS/TUN :即(Virtual Server v…

手写SpringCloud系列-一分钟理解微服务注册中心(Nacos)原理。

手写SpringCLoud项目地址,求个star github:https://github.com/huangjianguo2000/spring-cloud-lightweight gitee:https://gitee.com/huangjianguo2000/spring-cloud-lightweigh 一:什么是注册中心 1. 总结服务注册中心 我们可以理解注册中心就是一个…

LeetCode 热题 100JavaScript--2. 两数相加

给你两个 非空 的链表,表示两个非负的整数。它们每位数字都是按照 逆序 的方式存储的,并且每个节点只能存储 一位 数字。 请你将两个数相加,并以相同形式返回一个表示和的链表。 你可以假设除了数字 0 之外,这两个数都不会以 0 …

手机上的照片怎么压缩?推荐这几种压缩方法

手机上的照片怎么压缩?如果你需要通过电子邮件或短信发送照片,则可能需要将其压缩为较小的文件大小以便于发送。另外,如果您你的手机存储空间有限,可以通过压缩照片来节省空间。下面就给大家介绍几种压缩手机照片的方法。 1、使用…

Spring5.2.x 源码使用Gradle成功构建

一 前置准备 1 Spring5.2.x下载 1.1 Spring5.2.x Git下载地址 https://gitcode.net/mirrors/spring-projects/spring-framework.git 1.2 Spring5.2.x zip源码包下载,解压后倒入idea https://gitcode.net/mirrors/spring-projects/spring-framework/-/…

地球人口承载力估计 解析和C++代码

Description 假设地球上的新生资源按恒定速度增长。照此测算,地球上现有资源加上新生资源可供x亿人生活a年,或供y亿人生活b年。 为了能够实现可持续发展,避免资源枯竭,地球最多能够养活多少亿人? Input 一行&#xf…

共治、公开、透明!龙蜥社区 7 月技术委员会会议顺利召开!

2023 年 7 月 14 日上午 10 点召开了龙蜥社区7月技术委员会线上会议,共计 39 人参会,本次会议由浪潮信息苏志远博士主持,开放原子 TOC 导师陈阳、霍海涛、徐亮、余杰共同参会,技术委员们来自 Arm、阿里云、飞腾、海光、红旗软件、…

springcloud:对象存储组件MinIO(十六)

0. 引言 在实际开发中,我们经常会面临需要存储文档、存储图片等文件存储需求,并且在分布式架构下,文件又需要实现各节点共享,类似于共享文件夹类的需求,在分布式服务器中创建共享文件夹成本较大,甚至当需要…

Java课题笔记~ 不使用 AOP 的开发方式(理解)

Step1:项目 aop_leadin1 先定义好接口与一个实现类,该实现类中除了要实现接口中的方法外,还要再写两个非业务方法。非业务方法也称为交叉业务逻辑: doTransaction():用于事务处理 doLog():用于日志处理 …

第一天 什么是CSRF ?

✅作者简介:大家好,我是Cisyam,热爱Java后端开发者,一个想要与大家共同进步的男人😉😉 🍎个人主页:Cisyam-Shark的博客 💞当前专栏: 每天一个知识点 ✨特色专…

【小沐学C++】C++ 基于CMake构建工程项目(Windows、Linux)

文章目录 1、简介2、下载cmake3、安装cmake4、测试cmake4.1 单个源文件4.2 同一目录下多个源文件4.3 不同目录下多个源文件4.4 标准组织结构4.5 动态库和静态库的编译4.6 对库进行链接4.7 添加编译选项4.8 添加控制选项 5、构建最小项目5.1 新建代码文件5.2 新建CMakeLists.txt…

neo4j查询语言Cypher详解(二)--Pattern和类型

Patterns 图形模式匹配是Cypher的核心。它是一种用于通过应用声明性模式从图中导航、描述和提取数据的机制。在MATCH子句中,可以使用图模式定义要搜索的数据和要返回的数据。图模式匹配也可以在不使用MATCH子句的情况下在EXISTS、COUNT和COLLECT子查询中使用。 图…

Java Map集合详解 :HashMap类

Map 是一种键-值对(key-value)集合,Map 集合中的每一个元素都包含一个键(key)对象和一个值(value)对象。用于保存具有映射关系的数据。 Map 集合里保存着两组值,一组值用于保存 Map …