Variations-of-SFANet-for-Crowd-Counting可视化代码

news2024/11/15 7:34:09

前文对Variations-of-SFANet-for-Crowd-Counting做了一点基础梳理,链接如下:Variations-of-SFANet-for-Crowd-Counting记录-CSDN博客

本次对其中两个可视化代码进行梳理

1.Visualization_ShanghaiTech.ipynb

不太习惯用jupyter notebook, 这里改成了python代码测试,下面代码提到的测试数据都是项目自带的,权重自己下载一下吧,前文提到了一些需要下载的权重或者数据。

import warnings
warnings.filterwarnings('ignore')
import matplotlib.pyplot as plt
from matplotlib import cm as CM

import os
import numpy as np
from scipy.io import loadmat
from PIL import Image; import cv2
import torch
from torchvision import transforms
from models import M_SFANet
part = 'B'; index = 4
DATA_PATH = f"./ShanghaiTech_Crowd_Counting_Dataset/part_{part}_final/test_data/"
fname = os.path.join(DATA_PATH, "ground_truth", f"GT_IMG_{index}.mat")
img = Image.open(os.path.join(DATA_PATH, "images", f"IMG_{index}.jpg")).convert('RGB')
plt.imshow(img)
plt.gca().set_axis_off()
plt.show()
gt = loadmat(fname)["image_info"]
location = gt[0, 0][0, 0][0]
count = location.shape[0]
print(fname)
print('label:', count)
model = M_SFANet.Model()
model.load_state_dict(torch.load(f"./ShanghaitechWeights/checkpoint_best_MSFANet_{part}.pth", 
                                 map_location=torch.device('cpu'))["model"]);
trans = transforms.Compose([transforms.ToTensor(), 
                            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                           ])

height, width = img.size[1], img.size[0]
height = round(height / 16) * 16
width = round(width / 16) * 16
img = cv2.resize(np.array(img), (width,height), Image.BILINEAR)
img = trans(Image.fromarray(img))[None, :]
model.eval()
density_map, attention_map = model(img)
print('Estimated count:', torch.sum(density_map).item())
print("Visualize estimated density map")
plt.gca().set_axis_off()
plt.margins(0, 0)
plt.gca().xaxis.set_major_locator(plt.NullLocator())
plt.gca().yaxis.set_major_locator(plt.NullLocator())
plt.imshow(density_map[0][0].detach().numpy(), cmap = CM.jet)
# plt.savefig(fname=..., dpi=300)
plt.show()

运行结果如下,还有两张可视化的图

上面这样看是不是不太直观,下面这张图够直观

2.Visualization_UCF-QNRF.ipynb

同上改成了python代码测试

import torch
import os
import numpy as np
from datasets.crowd import Crowd
from models.vgg import vgg19
import argparse
from PIL import Image
import cv2
import sys
# sys.path.insert(0, '/home/pongpisit/CSRNet_keras/')
from models import M_SegNet_UCF_QNRF
from matplotlib import pyplot as plt
from matplotlib import cm as CM
datasets = Crowd(os.path.join('/home/pongpisit/CSRNet_keras/CSRNet-keras/wnet_playground/W-Net-Keras/data/UCF-QNRF_ECCV18/processed/', 'test'), 512, 8, is_gray=False, method='val')
dataloader = torch.utils.data.DataLoader(datasets, 1, shuffle=False,
                                         num_workers=8, pin_memory=False)
model = M_SegNet_UCF_QNRF.Model()
device = torch.device('cuda')
model.to(device)
# model.load_state_dict(torch.load(os.path.join('./u_logs/0331-111426/', 'best_model.pth'), device))
model.load_state_dict(torch.load(os.path.join('./seg_logs/0327-172121/', 'best_model.pth'), device))
model.eval()

epoch_minus = []
preds = []
gts = []

for inputs, count, name in dataloader:
    inputs = inputs.to(device)
    assert inputs.size(0) == 1, 'the batch size should equal to 1'
    with torch.set_grad_enabled(False):
        outputs = model(inputs)
        temp_minu = count[0].item() - (torch.sum(outputs).item())
        preds.append(torch.sum(outputs).item())
        gts.append(count[0].item())
        print(name, temp_minu, count[0].item(), torch.sum(outputs).item())
        epoch_minus.append(temp_minu)

epoch_minus = np.array(epoch_minus)
mse = np.sqrt(np.mean(np.square(epoch_minus)))
mae = np.mean(np.abs(epoch_minus))
log_str = 'Final Test: mae {}, mse {}'.format(mae, mse)
print(log_str)
met = []
for i in range(len(preds)):
    met.append(100 * np.abs(preds[i] - gts[i]) / gts[i])

idxs = []
for i in range(len(met)):
    idxs.append(np.argmin(met))
    if len(idxs) == 5: break
    met[np.argmin(met)] += 100000000
print(set(idxs))
def resize(density_map, image):
    density_map = 255*density_map/np.max(density_map)
    density_map= density_map[0][0]
    image= image[0]
    print(density_map.shape)
    result_img = np.zeros((density_map.shape[0]*2, density_map.shape[1]*2))
    for i in range(result_img.shape[0]):
        for j in range(result_img.shape[1]):
            result_img[i][j] = density_map[int(i / 2)][int(j / 2)] / 4
    result_img  = result_img.astype(np.uint8, copy=False)
    return result_img

def vis_densitymap(o, den, cc, img_path):
    fig=plt.figure()
    columns = 2
    rows = 1
#     X = np.transpose(o, (1, 2, 0))
    X = o
    summ = int(np.sum(den))
    
    den = resize(den, o)
    
    for i in range(1, columns*rows +1):
        # image plot
        if i == 1:
            img = X
            fig.add_subplot(rows, columns, i)
            plt.gca().set_axis_off()
            plt.margins(0,0)
            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0)
            plt.imshow(img)
            
        # Density plot
        if i == 2:
            img = den
            fig.add_subplot(rows, columns, i)
            plt.gca().set_axis_off()
            plt.margins(0,0)
            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0)
            plt.text(1, 80, 'M-SegNet* Est: '+str(summ)+', Gt:'+str(cc), fontsize=7, weight="bold", color = 'w')
            plt.imshow(img, cmap=CM.jet)
    
    filename = img_path.split('/')[-1]
    filename = filename.replace('.jpg', '_heatpmap.png')
    print('Save at', filename)
    plt.savefig('seg_'+filename, transparent=True, bbox_inches='tight', pad_inches=0.0, dpi=200)
    processed_dir = '/home/pongpisit/CSRNet_keras/CSRNet-keras/wnet_playground/W-Net-Keras/data/UCF-QNRF_ECCV18/processed/test/'
    model.eval()
    c = 0
    for inputs, count, name in dataloader:
        img_path = os.path.join(processed_dir, name[0]) + '.jpg'
        if c in set(idxs):
            inputs = inputs.to(device)
            with torch.set_grad_enabled(False):
                outputs = model(inputs)
                
                img = Image.open(img_path).convert('RGB')
                height, width = img.size[1], img.size[0]
                height = round(height / 16) * 16
                width = round(width / 16) * 16
                img = cv2.resize(np.array(img), (width,height), cv2.INTER_CUBIC)
                
                print('Do VIS')
                vis_densitymap(img, outputs.cpu().detach().numpy(), int(count.item()), img_path)
                c += 1        
        else:
            c += 1

但是该代码要用UCF-QNRF_ECCV18数据集,官网的太慢了,给个靠谱的链接:UCF-QNRF_数据集-阿里云天池

下载下来,然后利用bayesian_preprocess_sh.py这个代码处理一下就可以用于上述代码了,注意一下UCF-QNRF_ECCV18的mat文件中点坐标的读取代码有点问题,自己输出一下mat文件信息就看得出来了。输出文件夹中会有相应的jpg和npy文件。

运行可视化代码,这期间遇到了一个报错

ImportError: cannot import name 'COMMON_SAFE_ASCII_CHARACTERS' from 'charset_normalizer.constant' (C:\Anaconda3\lib\site-packages\charset_normalizer\constant.py)

邪门解决方案,安装一个chardet

pip install chardet -i https://pypi.tuna.tsinghua.edu.cn/simple

要是上述方法还不好使就换一个,更新一下charset_normalizer,或者卸载重装charset_normalizer

pip install --upgrade charset-normalizer

要是出现如下报错

RuntimeError:
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.

        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:

            if __name__ == '__main__':
                freeze_support()
                ...

        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.

把代码中的num_workers改成0,跑起来结果如下

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1153110.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

spring解决后端显示时区的问题

spring解决后端显示时区的问题 出现的问题: 数据库中的数据: 解决方法 spring:jackson:date-format: yyyy-MM-dd HH:mm:sstime-zone: Asia/Shanghai

vscode前端必备插件

安装插件的位置如下: 1、Chinese (Simplified) Language Pack 中文简体插件 2、Vetur Vue官方钦定插件,包括:语法高亮,智能提示,错误提示,格式化,自动补全等等 3、ESLint 语法检查工具&#…

客户端性能测试基础知识

目录 1、客户端性能 1.1、客户端性能基础知识 2、客户端性能工具介绍与环境搭建 2.1.1、perfdog的使用 2.1.2、renderdoc的使用 1、客户端性能 1.1、客户端性能基础知识 客户端性能知识这里对2D和3D类游戏进行展开进行,讲述的有内存、CPU、GPU、帧率这几个模块…

云栖大会十五年:开放创新,未来愿景

时光荏苒,转眼间云栖大会已经走过了十五个年头,这一场中国云计算行业的盛会已经成为业内不可或缺的一部分。在这个特殊的时刻,我想分享一些对未来云栖大会的期待与建议,希望这个盛会能够继续推动云计算领域的创新和发展。 云栖大会…

数据库深入浅出,数据库介绍,SQL介绍,DDL、DML、DQL、TCL介绍

一、基础知识: 1.数据库基础知识 数据(Data):文本信息(字母、数字、符号等)、音频、视频、图片等; 数据库(DataBase):存储数据的仓库,本质文件,以文件的形式将数据保存到电脑磁盘中 数据库管理系统(DBMS)&…

LSF 概览——了解 LSF 是如何满足您的作业要求,并找到最佳资源来运行该作业的

LSF 概览 了解 LSF 是如何满足您的作业要求,并找到最佳资源来运行该作业的。 IBM Spectrum LSF ("LSF", load sharing facility 的简称) 软件是行业领先的企业级软件。LSF 将工作分散在现有的各种 IT 资源中,以创建共享的,可扩展…

国内内卷太严重,还不考虑一下在海外接单?那这几个平台你知道吗?

作为一个程序员,在平台上接单赚点外快是再正常不过的事情了,但是现今国内各个平台都内卷比较严重,你是否考虑过去“外面的世界”看看? 如果想过,那么这几个外国的接单平台你都知道吗? 接下来就和我一起来看…

vmWare虚拟机扩容及pip国内镜像源

扩展虚拟机容量 打开虚拟机.sudo apt-get install gparted pip镜像源 pip install numpy -i https://pypi.tuna.tsinghua.edu.cn/simple-i https://pypi.douban.com/simple-i https:// mirrors.aliyun.com/pypi/simple

Linux CentOS7 shell

学好linux,首先要深入理解shell。 shell俗称壳,它包裹在内核的外面,是用户命令的翻译官。 作用:接收用户的命令,翻译后(处理一下)交给Linux内核处理。 用户执行命令 -> shell -> 内核 -> CPU -> 内核 -…

C/C++笔试易错与高频题型图解知识点(三)——数据结构部分(持续更新中)

目录 1. 排序 1.1 冒泡排序的改进 2. 二叉树 2.1 二叉树的性质 3. 栈 & 队列 3.1 循环队列 3.2 链式队列 4. 平衡二叉搜索树——AVL树、红黑树 5 优先级队列(堆) 1. 排序 1.1 冒泡排序的改进 下面的排序方法中,关键字比较次数与记录的初…

LeetCode 996.正方形数组的数目

和上一道状压的区别在于我们要去重一下~ 思路都是和上一篇博客是一样的&#xff0c;感兴趣的同学可以看一下 const int N 15; int dp[1<<N][N]; int n; vector<int>nums1;bool check(int x){int tem sqrt(x);if(tem*temx)return 1;return 0; }int dfs(int u,in…

比较Excel中的两列目录编号是否一致

使用java代码比较excel中两列是否有包含关系&#xff0c;若有包含关系&#xff0c;核对编号是否一致。 excel数据样例如下&#xff1a; package com.itownet.hg;import org.apache.poi.xssf.usermodel.XSSFSheet; import org.apache.poi.xssf.usermodel.XSSFWorkbook;import j…

网站如何改成HTTPS访问

在今天的互联网环境中&#xff0c;将网站更改成HTTPS访问已经成为了一种标准做法。HTTPS不仅有助于提高网站的安全性&#xff0c;还可以提高搜索引擎排名&#xff0c;并增强用户信任。因此&#xff0c;转换为HTTPS是一个重要的举措&#xff0c;无论您拥有个人博客、电子商务网站…

如何将你的PC电脑数据迁移到Mac电脑?使用“迁移助理”从 PC 传输到 Mac的具体操作教程

有的小伙伴因为某一项工作或者其它原因由Windows电脑换成了Mac电脑&#xff0c;但是数据和文件都在原先的Windows电脑上&#xff0c;不知道怎么传输。接下来小编就为大家介绍使用“迁移助理”将你的通讯录、日历、电子邮件帐户等内容从 Windows PC 传输到 Mac 上的相应位置。 在…

PicoDiagnostics (NVH设备软件)-Mongoose识别不了VIN码

如果Mongoose J2534诊断线识别不到车辆的VIN码&#xff0c;通常在PD软件中会像下图那样提示。 遇到这种情况&#xff0c;首先确保你的电脑是否已经安装J2534驱动&#xff1a;打开【设备管理器】&#xff0c;如果你将示波器和Mongoose J2534诊断线连接到电脑&#xff0c;【设备管…

EtherCAT FP介绍系列文章—RAS

RAS扩展功能包是acontis公司在EC-Master EtherCAT主站基础上提供的一套基于TCP/IP的客户端/服务器架构的Remote API。Remote API旨在远程API提供了一个接口&#xff0c;解决在操作系统中当第二个进程&#xff08;例如OPC服务器&#xff09;可能访问EtherCAT总线的数据或在Ether…

oracle 校验左括号和有括号是否对称匹配

校验数据比如名称字段的左括号和有括号是否匹配。不匹配情况有&#xff1a; 左括号是英文的&#xff0c;右括号是中文的&#xff1b;右括号是中文的&#xff0c;左括号是英文的&#xff1b; 通过正则表达式对名称进行校验&#xff0c;校验脚本如下&#xff1a; SELECT NAMEFR…

超越YOLOv8?基于Gold YOLO的自定义数据集训练

Gold-YOLO的出色性能是对Noahs Ark Lab团队的奉献和专业知识的证明。它不仅超越了其前身YOLOv8&#xff0c;还为实时目标检测设定了新标准。凭借其闪电般快速的处理能力和出色的准确性&#xff0c;Gold-YOLO承诺革命化一系列应用&#xff0c;从自动驾驶车辆到监视系统等等。 我…

影响产品开发决策的认知偏见

认知偏见存在于每个人的内心&#xff0c;并在不断影响人们的工作和生活。认识并承认自己有偏见&#xff0c;并寻求相应的解决方案&#xff0c;可以帮助我们更好的做出产品决策、团队建设和架构设计。原文: The cognitive biases that influence product development decisions …

某国产中间件企业:提升研发安全能力,助力数字化建设安全发展

​某国产中间件企业是我国中间件领导者&#xff0c;国内领先的大安全及行业信息化解决方案提供商&#xff0c;为各个行业领域近万家企业客户提供先进的中间件、信息安全及行业数字化产品、解决方案及服务支撑&#xff0c;致力于构建安全科学的数字世界&#xff0c;帮助客户实现…