python 可视化解释模型

news2025/1/16 0:58:08

1. 自定义DataSet

MakeDataset.py
首先准备好一个数据集文件,这里以mydata文件夹存放图片数据, 实现自定义DataSet

class MyDataset(Dataset):
	def __init__(self,resize):
		super(MyDataset,self).__init__()
		self.resize = resize
	
	def __len__(self):
		return len(images)
	
	def __getitem__(self,idx):
		img = images[idx]
		tf = transforms.Compose([
			lambda x:Image.open(x).convert('RGB'),
			transforms.Resize((self.resize,self.resize)),
			transforms.ToTensor(),
			transforms.Normalize(mean = [0.485,0.456,0.406],
										std = [0.229,0.224,0.225])		
		])
		img_tensor = tf(image)
		# `mydata\\ICH\\1470718-1.JPG`
		label_tensor = torch.tensor(class_name_index[image.split(os.sep)[-2]])
		return img_tensor,label_tensor

2 模型定义及训练

2.1 模型

这里以一个玩具模型作为演示,模型的定义如下:
MyModle.py

class MyResNet(nn.Module):
	def __init__(self):
		super(MyResNet,self).__init__()
		general_features = 32
		
		# Initial convolution block
		self.conv0 = nn.Conv2d(3,general_features,3,1,padding=1)
		self.conv1 = nn.Conv2d(general_features,general_features,3,1,padding =1)
		self.relu1 = nn.ReLU()
		self.conv2 = nn.Conv2d(general_features,general_features,3,1,padding=1)
		self.relu2 = nn.ReLU()
		
		# Down sample 1/2
		self.downsample0 = nn.Maxpool2d(2,2)
		self.downsample1 = nn.Maxpool2d(2,2)
		self.downsample2 = nn.Maxpool2d(2,2)
		self.downsample3 = nn.Maxpool2d(2,2)
		
		self.fc0 = nn.Linear(32*8*8, 2)
	
	def forward(self,x):
		x = self.conv0(x)			#[1,32,128,128]
		x = self.downsample0(x)		#[1,32,64,64]
		x = self.downsample1(x)		#[1,32,32,32]
		
		x = self.relu1(self.conv1(x)) #[1,32,32,32]
		x = self.downsample2(x)		  # [1,32,16,16]
		
		x = self.relu2(self.conv2(x)) #[1,32,16,16]
		x = self.downsample3(x)		  # [1,32,8,8]
		
		x = x.view(x.shape[0],-1)     # Flatten
		
		x = x.softmax(self.fc0(x),dim=1)
		return x

# x = torch.randn(1,3,128,128)
# m = myResNet()
# summary(m,(3,128,128))
# print(m(x).shape)	

2.2 训练

训练train.py获得权重文件

import torch
from torch import optim,nn
from torch.utils.data import Dataloader
from MakeDataSet import MyDataset

from MyModel import MyResNet

train_db = MyDataset(resize = 128)
train_loader = DataLoader(train_db,batch_size=4,shuffle=True)
print('num_train:',len(train_loader.dataset))

model = MyResNet()

optimizer = optim.Adam(model.parameters(),lr =0.001)
criteon = nn.CrossEntropyLoss()

epochs = 5

for epoch in range(epochs):
	for step,(x,y) in enumerate(train_loader):
		model.train()
		logits = model(x)
		loss = criteon(logits,y)
		
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()
	print('Epochs:',epoch,'Loss:',loss)

torch.save(model.state_dict(),'weights_MyResNet.mdl')
print('Save Done')

3 利用SmoothGradCAMpp对特征图可视化

Visualize_featrue_map, 这里介绍smooth gradcampp用法

import torch
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
from torchcam.methods import SmoothGradCAMpp,CAM,GradCAM,GradCAMpp,XGradCAM,ScoreCAM
from torchcam.utils import overlay_mask
from MyModel import MyResNet
from PIL import image
import matplotlib.pyplot import plt

tf = transforms.Compose([
		lambda x:Image.open(x).convert('RGB')
		transforms.Resize(128,128)
		transforms.ToTensor(),
		transforms.Normalize(mean = [0.485,0.456,0.406],
							std = [0.229,0.224,0.225]
							)
])

img_ICH_test = tf('ICH_test.jpg').unsqueeze(dim=0)
#print(img_ICH_test.shape)

img_Normal_test = tf('Normal_test.jpg').unsqueeze(dim=0)

model = MyResNet()
model.load_state_dict(torch.load('weights_MyResNet.mdl'))
print('loaded from ckpt')
model.eval()

cam_extractor = SmoothGradCAMpp(model,input_shape=(3,128,128))
# cam_extractor = GradCAMpp(model,input_shape=(3,128,128))
# cam_extractor = XGradCAM(model,input_shape=(3,128,128))
# cam_extractor = ScoreCAM(model,input_shape=(3,128,128))
# cam_extractor = SSCAM(model,input_shape=(3,128,128))
# cam_extractor =ISCAM(model,input_shape=(3,128,128))
# cam_extractor = LayerCAM(model,input_shape=(3,128,128))
  • 载入测试图片Normal_test.jpg
    在这里插入图片描述
  • 加载预训练权重,实例化模型
output = model(img_Normal_test)
print(output)

activation_map = cam_extractor(output.sequeeze(0).argmax().item(),output)
print(activation_map[0],activation_map[0].min(),activation_map[0].max(),activation_map[0].shape)

#fused_map = cam_extractor.fuse_cams(activation_map)
#print(fused_map[0],fused_map[0].min(),fused_map[0].max(),fused_map[0].shape)


result = overlay_mask(to_pil_image(img_Normal_test[0]),
					to_pil_image(activation_map[0],mode='F'),alpha=0.3)
plt.imshow(result)
plt.show()
  • 将模型的输出预测类别索引送入到构建的cam_extractor对象中,由于activation_map输出的是一个tuple,通过索引0取值
  • 接着用overlay_mask进行可视化效果展示,传入原图和激活map,并利用alpha参数设置一定的透明度
  • 由于输出的result是PIL格式,所以可以直接用imshow显示
    在这里插入图片描述
    最热的区域就是模型主要依据这部分来判断类别,这里没有指定可视化feature map的哪一层的话,就默认是全连接测上一层feature map

这个包的主页在: https://pypi.org/project/torchcam/,感兴趣的可以看看

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/10265.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【软件分析第13讲-学习笔记】符号执行 Symbolic Execution

文章目录前言正文符号执行基于霍尔逻辑的符号执行谓词转换计算最弱前置条件动态符号执行符号执行:进一步探究小结参考文献前言 创作开始时间:2022年11月16日18:46:31 如题,学习一下符号执行 Symbolic Execution的相关知识。参考&#xff1a…

计算机毕业设计jsp家校互动系统Myeclipse开发mysql数据库web结构java编程计算机网页项目

一、源码特点 JSP 家校互动系统 是一套完善的web设计系统,对理解JSP java编程开发语言有帮助,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。开发环境为TOMCAT7.0,Myeclipse8.5开发,数据库为Mysql,使用jav…

Linux用户操作(22.9.21)

学习目标: 用户账号管理Linux用户操作Linux用户组操作(一)用户账号管理 1、用户与用户组文件 在Linux系统当中,默认情况下所有用户信息保存在 /etc/passwd文件内,用户密码信息保存在/etc/shadow文件内;所…

43、Spring AMQP TopicExchange

1、TopicExchange 2、案例 3、通过配置类实现 1、配置TopicConfig 2、添加Listener 3、测试结果 4、通过注解实现 1、配置Linstener 2、测试结果 5、总结分析 学到这里,关于RabbitMQ的五种消息模型就结束了。 1、第一种消息模型:单个队列&#xff0c…

相机模型总结

目录相机模型前言1. pinhole 针孔模型2. Omnidirectional Camera Model 全向相机模型2.1 Unified model for catadioptric cameras 反射式相机统一模型2.2 Extended Unified model for catadioptric cameras (EUCM)2.3 Omnidirectional Camera Model By Scaramuzza畸变模型1. E…

linux篇【10】:进程信号

目录 一.信号入门 1.信号是操作系统内一个内置机制 2.前后台进程的几条命令与ctrlc 3.信号分类 4.信号产生是异步的 5.进程是如何记住这个信号 (3)存储方式:位图 二.signal ——对某信号设置自定义行为(捕捉)的函数 (1&a…

【Linux】进程间通信之消息队列

系列目录 进程间通信——共享内存 进程间通信——信号量 文章目录 一、概念 二、消息队列函数 1.msgget 2.magsnd 3.msgrcv 4.msgctl 三、掌握消息队列操作 一、概念 提供了一种从另一种进程发送一个数据块的方法。而且每个数据块都被认为含有一个类型,接…

Python3《机器学习实战》学习笔记(十):ANN人工神经网络代码详解(数字识别案例以及人脸识别案例)

文章目录一、构建基本代码结构1.1预处理数据的工具包1.2 初始化参数1.3工具类sigmoid1.4工具类矩阵变换1.5初始化theta1.6正向传播1.7反向传播1.8梯度下降1.9训练模块二、MNIST数字识别三、人脸识别四、总结一、构建基本代码结构 1.1预处理数据的工具包 """Dat…

2021年认证杯SPSSPRO杯数学建模C题(第一阶段)破局共享汽车求解全过程文档及程序

2021年认证杯SPSSPRO杯数学建模 C题 破局共享汽车 原题再现: 自 2015 年以来,共享汽车行业曾经“百花齐放”,多个项目获得巨额融资。但因为模式过重、运营成本过高、无法盈利等问题,陆续有共享汽车公司因为资金链断裂而倒闭。据…

RocketMQ存储设计的奥妙

RocketMQ作为一款基于磁盘存储的中间件,具有无限积压能力,并提供高吞吐、低延迟的服务能力,其最核心的部分必然是它优雅的存储设计。 1、存储概述 RocketMQ存储的文件主要包括Commitlog文件、ConsumeQueue文件、Index文件。 RocketMQ将所有…

温振传感器有几种传输方式?

在现代化社会中,各种机器无时无刻参与着我们的日常生活,承担在我们的周围承担起重要作用,轴承、电机、泵体等也成为工业文明中关键存在,它们的温度和状态影响着整个工业自动化系统运行的健康和效率。 长期以来,传感器技…

数字集成电路设计(四、Verilog HDL数字逻辑设计方法)(一)

文章目录1.Verilog语言的设计思想和可综合特性2. 组合电路的设计2.1 数字加法器2.2 数据比较器2.3 数据选择器2.4 数字编码器2.4.1 3位二进制8线-3线编码器2.4.2 8线-3线优先编码器2.4.3 二进制转化十进制8421BCD编码器(重要)2.4.4 8421BCD十进制余3编码…

ue4使用Niagara粒子实现下雨效果,使用蓝图调节雨量

一、使用Niagara粒子系统实现下雨效果 1. 首先创建一个雨水的材质 新建 — 材质 2. 创建Niagara系统 新建 新建 — FX — Niagara系统 — 来自所选发射器的新系统 — 下一步 — 选择Fountain — 点击号,点击完成 删除下面的“Add Velocity in Cone” 添加“…

矩池云如何自定义端口,访问自己的web项目

本文将向您介绍如何在矩池云租用服务器的时候自定义端口,并将您的 web 项目部署到自定义端口,最后实现在本地通过自定义端口对应链接访问服务。 上传代码和数据 首先,您需要将本地的项目代码和数据上传到矩池云网盘。这里为了方便您测试使用…

类似ps的python工具lama cleaner

Lama Cleaner是个类似ps图片的工具,可以把图片中不想要的部分p掉,或者填补图片中丢失的部分。用下来感觉还蛮靠谱,对于不会ps的人是福音,记录一下。 相关介绍:https://github.com/Sanster/lama-cleaner 1.安装 安装…

react 中 ref 管理列表

背景 最近在看 react 新的官方文档 的时候,看到这么一个标题,How to manage a list of refs using a ref callback,就是一个图片的列表,类似这样 然后点击按钮的时候,通过 scrollIntoView 这个 api 来让他滚动&#…

python生成模拟微信气泡图片

0. 起因 众所周知,借刀杀人最为致命,聊天也是如此。 最近我的群聊画风逐渐变味: 当然,这种图片的生产成本很低,只需在设置页关闭昵称显示,把聊天背景重置为灰色,然后利用截图工具截图&#xf…

【金融项目】尚融宝项目(十三)

25、充值 25.1、需求介绍 25.1.1、投资人充值 **1、需求描述 ** 标的产生后,平台展示标的,投资人就可以在平台投资标的,获取收益;投资人投资标的必须满足以下条件: 充值过程与绑定过程一致,也是在平台发…

Delphi 11.2 Alexandria程序集代码

Delphi 11.2 Alexandria程序集代码 高DPI VCL设计器-VCL设计器现在在设计时使用类似Microsoft Windows的样式,这意味着除非禁用此功能,否则设计器中的控件始终使用此样式绘制。此样式与Windows当前使用的浅色或深色主题相匹配。 编辑器选项卡-在版本11.2…

【3D目标检测】Frustum PointNets for 3D Object Detection from RGB-D Data

目录概述细节网络结构视锥候选框3D实例分割边界框参数回归损失函数概述 首先本文是基于图像和点云的,属于早期的模态融合的成果,是串行的算法,而非并行的,更多的是考虑如何根据图像和点云这两个模态的数据进行3D目标检测。 提出动…