import torch
from torch.nn import functional as F
from torchvision import models, transforms
from PIL import Image
import os
os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'
# 加载经过训练的 ResNet 模型
model = models.resnet50(pretrained=True)
model.eval()
# 载入图像并进行预处理
image_path = 'airline.png'
image = Image.open(image_path).convert('RGB')
preprocess = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
input_tensor = preprocess(image).unsqueeze(0)
# 前向传播获取特征图
with torch.no_grad():
features = model.conv1(input_tensor)
features = model.layer1(features)
features = model.layer2(features)
features = model.layer3(features)
features = model.layer4(features)
# 获取模型的权重
weight = model.fc.weight
print(1)
# 假设 cam 和 resized_tensor 是 PyTorch 张量
# 将它们转换为 NumPy 数组
import cv2
bz, nc, h, w = features.shape
beforeDot = features.reshape((nc, h*w))
cam = torch.matmul(weight[1], beforeDot)#404
cam = cam.reshape(h, w)
size_upsample = (256, 256)
cam = cam - torch.min(cam)
cam_img = cam / torch.max(cam)
# cam_img = torch.uint8(255 * cam_img)
# import torch
import torch.nn.functional as F
# 使用 interpolate 函数将其调整为 [224, 224]
resized_tensor = F.interpolate(cam_img.unsqueeze(0).unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False)
# 现在 resized_tensor 是一个大小为 [1, 1, 224, 224] 的 PyTorch 张量
# 如果需要,你可以使用 .squeeze() 方法来移除不必要的维度
output_cam = resized_tensor.squeeze()
import numpy as np
cam_np = output_cam.detach().numpy()
# 假设 image 是你的图像数据
# cam_np = cam_np.astype(np.uint8)
resized_tensor_np = input_tensor.detach().numpy()
# 将 image 的形状调整为 (3, 224, 224)
image = resized_tensor_np.squeeze()
# 转换图像通道顺序,从 (3, 224, 224) 调整为 (224, 224, 3)
image = np.transpose(image, (1, 2, 0))
import matplotlib.pyplot as plt
# 创建一个新的图形
plt.figure(figsize=(8, 8))
# 绘制原始图像
plt.subplot(1, 2, 1)
plt.imshow(image)#, cmap='gray')
plt.title('Original Image')
# 绘制 CAM
plt.subplot(1, 2, 2)
plt.imshow(cam_np, cmap='jet') # 使用 'jet' 颜色映射以突出 CAM
plt.title('Class Activation Map (CAM)')
# 显示图形
plt.show()