Pytorch 计算Monte-Carlo Dropout不确定度

news2024/10/21 15:16:56

为了实现Monte Carlo Dropout (MC Dropout),我们需要在模型评估阶段保留Dropout层的功能,而不是像通常那样在评估模式下关闭Dropout。这可以通过在预测过程中多次运行模型,并且每次运行时都启用Dropout来完成。下面是如何修改你的代码以实现MC Dropout的步骤:

参考文献: Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learnin

1设置模型为训练模式:

即使是在评估时,也需要将模型设置为train()模式,这样Dropout层才会工作。不过需要注意的是,这样做可能会导致Batch Normalization等层的行为发生变化,所以如果你的模型中使用了这些层,可能需要额外处理。

2多次预测:

对于每个样本,你需要多次通过模型进行前向传播,每次都会因为Dropout的影响产生不同的输出。

3计算均值和方差:

对于每个样本的所有预测结果,计算均值作为最终预测值,同时计算方差来估计模型的不确定性。

具体代码见以下的6、7节

import torch
from torch.utils.data import DataLoader, random_split
from dataset import split_dataset, find_bmp_files, BMPDataset
from model import  MobileNetV2
import pandas as pd
import numpy as np

# 1、设定随机种子
torch.manual_seed(40)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(40)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# 2、数据集初始化和分割
directory_path = './data/'
bmp_file_paths = find_bmp_files(directory_path)
train_ratio = 0
val_ratio = 1
test_ratio = 0.0
dataset = BMPDataset(bmp_file_paths)
total_length = len(dataset)
train_length = int(train_ratio * total_length)
val_length = int(val_ratio * total_length)
test_length = total_length - train_length - val_length
_, val_dataset, _ = random_split(dataset, [train_length, val_length, test_length])

print(len(val_dataset))
# 3、定义数据加载器
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

# 4、初始化模型、设备和优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MobileNetV2().to(device)
# 5、加载模型权重
state_dict = torch.load('model.pth', map_location=device)  # 直接加载到指定设备
model.load_state_dict(state_dict)

# 6、定义预测次数T
T = 10  # 可以调整这个数字来增加或减少预测次数

# 7、测试模型
all_predictions = []
all_predictions_variances = []
all_labels = []
all_image_names = []

model.train()  # 开启Dropout

with torch.no_grad():
    for images, labels, image_names in val_loader:
        predictions_list = []
        for t in range(T):
            predictions = model(images.to(device))
            predictions_list.append(predictions.cpu().numpy())
        
        # 计算预测的均值和方差
        predictions_array = np.array(predictions_list)
        mean_predictions = np.mean(predictions_array, axis=0)
        var_predictions = np.var(predictions_array, axis=0)
        
        all_predictions.extend(mean_predictions)
        all_predictions_variances.extend(var_predictions)
        all_labels.extend(labels.cpu().numpy())
        all_image_names.extend(image_names)

# 8、将预测结果、标签和图像名称合并到DataFrame中
results_df = pd.DataFrame({
    'Image Name': all_image_names,
    'Predicted S Mean': [pred[0] for pred in all_predictions],
    'Predicted T Mean': [pred[1] for pred in all_predictions],
    'Predicted S Variance': [var[0] for var in all_predictions_variances],
    'Predicted T Variance': [var[1] for var in all_predictions_variances],
    'Actual S': [label[0] for label in all_labels],
    'Actual T': [label[1] for label in all_labels],
})

# 9、保存结果到Excel文件
results_df.to_excel('MC_dropout.xlsx', index=False)

print("Test results with MC Dropout saved to 'MC_dropout.xlsx'")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2220109.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【记录】VSCode|自用设置项

文章目录 1 基础配置1.1 自动保存1.2 编辑区自动换行1.3 选项卡换行1.4 空格代替制表符1.5 开启滚轮缩放 2 进阶设置2.1 选项卡不自我覆盖2.2 选项卡限制宽度2.3 选项卡组限制高度2.4 字体设置2.5 字体加粗2.6 侧边栏2.7 沉浸式代码模式 Zen Mode2.8 设置 Zen 模式的选项卡组 3…

Vxe UI vue vxe-table grid 如何滚动、定位到指定行或列

Vxe UI vue vxe-table vxe-grid 在表格中有时候需要对数据会列进行操作。可以会定位到某一行或某一列,vxe-table 中提供了丰富的函数式 API,可以轻松对行与列进行各种的灵活的操作。 定位到指定行与列 通过调用 scrollColumn(columnOrField) 方法&…

【Python】图形用户界面

在Python中,开发图形用户界面(GUI)的工具包有许多种,常用的包括: Tkinter:Python的标准GUI库,作为Python内置的一部分,简单易用,适用于轻量级应用。PyQt/PySide&#xf…

探索面向对象编程的核心:类、对象与封装

探索面向对象编程的核心:类、对象与封装 在学习Java编程时,面向对象编程(OOP)是一个非常重要的核心概念。今天我们将深入探讨其中最基本、但却非常重要的组成部分:类和对象,以及它们是如何通过封装来实现数…

全网免费的文献调研方法以及获取外网最新论文、代码和翻译pdf论文的方法(适用于硕士、博士、科研)

1. 文献调研 学术搜索引擎(十分推荐前三个,超有用):使用 Google Scholar(https://scholar.google.com/)(https://scholar.google.com.tw/)(巨人学术搜索‬‬)、(三个都可以,镜像网站) arXiv(https://arxiv.org/)、&am…

企业架构系列(21)ArchiMate建模ADM阶段A:架构愿景

从本篇开始,将通过6篇文章逐一介绍如何使用 ArchiMate 的特定视角来创建与 TOGAF 架构开发方法相关的图形化模型或图表(即,ADM中的图形制品,Graphical Artifacts)。这些制品让利益相关者以可视化的方式来理解架构内容&…

vscode中每个打开的文件都显示在一个单独的标签页中

版本:1.94 实现步骤: 1、打开设置 File-》Preferences-》Settings 2、具体设置 2.1、在配置中搜索 workbench.editor.showTabs 设置为multiple。 2.2、在配置中搜索 workbench.editor.enablePreview 取消勾选。 根据这个功能的说明,在…

Java项目-基于springcloud框架的分布式架构网上商城系统项目实战(附源码+文档)

作者:计算机学长阿伟 开发技术:SpringBoot、SSM、Vue、MySQL、ElementUI等,“文末源码”。 开发运行环境 开发语言:Java数据库:MySQL技术:SpringBoot、Vue、Mybaits Plus、ELementUI工具:IDEA/…

背包九讲——完全背包问题

目录 完全背包问题 问题定义 动态规划解法 状态转移方程 初始化 遍历顺序 三种解法: 朴素版——枚举k 进阶版——dp正推(一维滚动数组) 背包问题第三讲——完全背包问题 背包问题是一类经典的组合优化问题,通常涉及在限定…

PCB走线线径与电流关系

转载自一个实验搞明白PCB走线应该画多宽_哔哩哔哩_bilibili

2011年国赛高教杯数学建模A题城市表层土壤重金属污染分析解题全过程文档及程序

2011年国赛高教杯数学建模 A题 城市表层土壤重金属污染分析 随着城市经济的快速发展和城市人口的不断增加,人类活动对城市环境质量的影响日显突出。对城市土壤地质环境异常的查证,以及如何应用查证获得的海量数据资料开展城市环境质量评价,研…

什么是智能电网?

智能电网(Smart Grid)被认为是当今电力行业发展的重要方向之一。它是传统电网与现代信息技术、通信技术和自动化技术深度融合的产物,旨在提高电力系统的效率、可靠性和可持续性。智能电网不仅仅是一个技术创新的名词,更是一个系统…

全域推广什么意思?如何搭建高效优质的全域推广服务商系统?

当前,全域推广一词的热度日渐升高,越来越多的人开始关注和计划入局这一全新项目,希望能够吃到第一波红利。不过,由于这一项目刚刚兴起,相关资料尚不完善,因此,绝大多数有意向入局的人都对该项目…

创客项目秀 | 基于使用 XIAO BLE Sense 和 Edge Impulse 的宠物活动跟踪器

今天为大家带来的是来自美国的创作者米顿-达斯的作品:宠物活动跟踪器.这个装置主要是为宠物主人提供关于宠物日常活动量的详尽数据,还能够根据宠物的独特需求,提供个性化的健康建议和活动指导。 项目背景 为了全面促进宠物的健康与活力,采用…

来可电子CAN数据记录仪通过智诊小助手TF卡记录文件导出

若想将TF卡中记录的数据文件导出可按以下的流程进行配置: 点击主界面中的导出选项即可进入到下图中TF卡应用界面 点击TF卡应用界面中“查看记录文件”的选项,进入导出文件界面。 点击“选择”进入勾选文件的界面 点击“导出”后,点击“确定”…

Vulnhub打靶-napping

基本信息 靶机下载:https://download.vulnhub.com/napping/napping-1.0.1.ova 攻击机器:192.168.20.128(Windows操作系统)& 192.168.20.138(kali) 提示信息:甚至管理员也可以在工作中睡…

统信UOS与Windows11传输数据

原文连接:统信UOS与Windows11相互传输数据 hello,大家好啊,今天给大家带来一篇统信UOS与Windows11之间通过共享文件夹传输数据的方法,首先在Windows11上创建共享文件夹,然后通过smb协议在UOS上进行连接访问&#xff0c…

彻底解决IDEA SpringBoot项目yml文件没有小树叶,读取配置文件失败问题

报错说没有配置dubbo:application:name,其实是配置了的,就是读不到,那有没有可能是yml文件不是绿叶的问题?网上查了很多文章配置小绿叶,最后还是报这个错,而且网上的文章配置小绿叶也太过于繁琐,其实就一招…

【Java后端】之 ThreadLocal 详解

想象一下,你有一个工具箱,里面放着各种工具。在多人共用这个工具箱的时候,很容易出现混乱,比如有人拿走了你的锤子,或者你找不到合适的螺丝刀。为了避免这种情况,最好的办法就是每个人都有自己独立的工具箱…