【深度学习实验】前馈神经网络(八):模型评价(自定义支持分批进行评价的Accuracy类)

news2024/11/18 3:34:39

目录

一、实验介绍

 二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

 0. 导入必要的工具包

1. __init__(构造函数)

2. update函数(更新评价指标)

5. accumulate(计算准确率)

4. reset(重置评价指标)

5. 构造数据进行测试

6. 代码整合


一、实验介绍

       本文将实现一个辅助功能——计算预测的准确率。Accuracy支持对每一个回合中每批数据进行评价,并将结果累积,最终获得整批数据的评价结果。

  • 在训练或验证过程中迭代地调用update方法来更新评价指标;
  • 使用accumulate方法获取累计的准确率;
  • 通过reset方法重置评价指标,以便进行下一轮的计算。

 二、实验环境

    本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

ChatGPT:

        前馈神经网络(Feedforward Neural Network)是一种常见的人工神经网络模型,也被称为多层感知器(Multilayer Perceptron,MLP)。它是一种基于前向传播的模型,主要用于解决分类和回归问题。

        前馈神经网络由多个层组成,包括输入层、隐藏层和输出层。它的名称"前馈"源于信号在网络中只能向前流动,即从输入层经过隐藏层最终到达输出层,没有反馈连接。

以下是前馈神经网络的一般工作原理:

  1. 输入层:接收原始数据或特征向量作为网络的输入,每个输入被表示为网络的一个神经元。每个神经元将输入加权并通过激活函数进行转换,产生一个输出信号。

  2. 隐藏层:前馈神经网络可以包含一个或多个隐藏层,每个隐藏层由多个神经元组成。隐藏层的神经元接收来自上一层的输入,并将加权和经过激活函数转换后的信号传递给下一层。

  3. 输出层:最后一个隐藏层的输出被传递到输出层,输出层通常由一个或多个神经元组成。输出层的神经元根据要解决的问题类型(分类或回归)使用适当的激活函数(如Sigmoid、Softmax等)将最终结果输出。

  4. 前向传播:信号从输入层通过隐藏层传递到输出层的过程称为前向传播。在前向传播过程中,每个神经元将前一层的输出乘以相应的权重,并将结果传递给下一层。这样的计算通过网络中的每一层逐层进行,直到产生最终的输出。

  5. 损失函数和训练:前馈神经网络的训练过程通常涉及定义一个损失函数,用于衡量模型预测输出与真实标签之间的差异。常见的损失函数包括均方误差(Mean Squared Error)和交叉熵(Cross-Entropy)。通过使用反向传播算法(Backpropagation)和优化算法(如梯度下降),网络根据损失函数的梯度进行参数调整,以最小化损失函数的值。

        前馈神经网络的优点包括能够处理复杂的非线性关系,适用于各种问题类型,并且能够通过训练来自动学习特征表示。然而,它也存在一些挑战,如容易过拟合、对大规模数据和高维数据的处理较困难等。为了应对这些挑战,一些改进的网络结构和训练技术被提出,如卷积神经网络(Convolutional Neural Networks)和循环神经网络(Recurrent Neural Networks)等。

本系列为实验内容,对理论知识不进行详细阐释

(咳咳,其实是没时间整理,待有缘之时,回来填坑)

977468b5ae9843c6a88005e792817cb1.png

 0. 导入必要的工具包

import torch
from sklearn.datasets import load_iris
from torch.utils.data import Dataset, DataLoader
  • DatasetDataLoader类用于处理数据集和数据加载

这段代码定义了一个名为Accuracy的类,用于支持分批进行模型评价,特别是在分类任务中计算准确率。

1. __init__(构造函数)

class Accuracy:
    def __init__(self, is_logist=True):
        self.num_correct = 0
        self.num_count = 0
        self.is_logist = is_logist
  • 构造函数在创建Accuracy对象时被调用。它接受一个可选的参数is_logist,默认为True,用于指示是否为logist形式的预测值。
  • self.num_correct用于记录正确预测的样本个数。
  • self.num_count用于记录总样本个数。
  • self.is_logist指示是否为logist形式的预测值。

2. update函数(更新评价指标)

def update(self, outputs, labels):
    if outputs.shape[1] == 1:
        outputs = outputs.squeeze(-1)
        if self.is_logist:
            preds = (outputs >= 0).long()
        else:
            preds = (preds >= 0.5).long()
    else:
        preds = torch.argmax(outputs, dim=1).long()
        
    labels = labels.squeeze(-1)
    batch_correct = (preds==labels).float().sum()
    batch_count = len(labels)
    self.num_correct += batch_correct
    self.num_count += batch_count
  • update方法用于更新评价指标。它接受两个参数outputslabels,分别表示模型的预测输出和真实标签。
  • 根据outputs的形状判断任务类型。
    •  如果outputs是二维张量且第二维大小为1,那么表示是二分类任务。
      •   如果is_logist=True,则将outputs通过阈值(0)转换为预测值preds,并将其转换为整数类型。
      •   如果is_logist=False,则将outputs通过阈值(0.5)转换为预测值preds,并将其转换为整数类型。
    •  如果outputs是二维张量且第二维大小大于1,表示是多分类任务。此时,将outputs中概率最大的类别作为预测值preds
  • labels去除多余的维度,并计算本批数据中预测正确的样本个数batch_correct
  • 获取本批数据的样本个数batch_count
  • 更新num_correctnum_count,累积计算正确样本个数和总样本个数。

5. accumulate(计算准确率)

def accumulate(self):
    if self.num_count == 0:
        return 0
    return self.num_correct / self.num_count
  • accumulate方法用于计算准确率。
    •  如果num_count为0,表示没有进行过更新,返回0。
    • 否则,返回正确样本个数除以总样本个数的比例,即准确率

4. reset(重置评价指标)

def reset(self):
    self.num_correct = 0
    self.num_count = 0
  • reset方法用于重置评价指标,将num_correctnum_count重置为0,以便进行下一轮评价

5. 构造数据进行测试

y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
acc = Accuracy()
acc.update(y_hat, y)
acc.num_correct

6. 代码整合

import torch


# 支持分批进行模型评价的 Accuracy 类
class Accuracy:
    def __init__(self, is_logist=True):
        # 正确样本个数
        self.num_correct = 0
        # 样本总数
        self.num_count = 0
        self.is_logist = is_logist

    def update(self, outputs, labels):
        # 判断是否为二分类任务
        if outputs.shape[1] == 1:
            outputs = outputs.squeeze(-1)
            # 判断是否是logit形式的预测值
            if self.is_logist:
                preds = (outputs >= 0).long()
            else:
                preds = (preds >= 0.5).long()
        else:
            # 多分类任务时,计算最大元素索引作为类别
            preds = torch.argmax(outputs, dim=1).long()

        # 获取本批数据中预测正确的样本个数
        labels = labels.squeeze(-1)
        batch_correct = (preds == labels).float().sum()
        batch_count = len(labels)
        # 更新
        self.num_correct += batch_correct
        self.num_count += batch_count

    def accumulate(self):
        # 使用累计的数据,计算总的评价指标
        if self.num_count == 0:
            return 0
        return self.num_correct / self.num_count

    def reset(self):
        self.num_correct = 0
        self.num_count = 0


y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
acc = Accuracy()
acc.update(y_hat, y)
acc.num_correct

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1032646.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于Yolov8的工业小目标缺陷检测(5):大缺陷小缺陷一网打尽的轻量级目标检测器GiraffeDet,暴力提升工业缺陷检测能力

💡💡💡本文改进:大小缺陷一网打尽的GiraffeDet,提升处理低分辨率图像和小物体等更困难的检测能力。 GiraffeDet | 亲测在工业小目标缺陷涨点明显,原始mAP@0.5 0.679提升至0.727 收录专栏: 💡💡💡深度学习工业缺陷检测 :http://t.csdn.cn/fVSgs ✨✨✨提供…

【一】Spring Cloud 系列简介

Spring Cloud 系列简介 简介:从单体架构到分布式架构,再到微服务架构,一路经历走来spring框架也一直在与时俱进,回顾下来感觉做Java开发就是基于spring开发,spring也一路发展出了spring boot,在此基础上发…

阿里云服务器u1和经济型e系列性能差异?哪个比较好?

阿里云服务器经济型e实例和云服务器u1有什么区别?同CPU内存配置下云服务器u1性能更强,u1实例价格也要更贵一些。经济型e实例属于共享型云服务器,不同实例vCPU会争抢物理CPU资源,并导致高负载时计算性能波动不稳定,而云…

LLMs资源

一、ChatGPT 《中科院学术专业版 ChatGPT》: gpt_academic项目针对了中科院日常科研工作,基于 ChatGPT 专属定制了一整套实用性功能,用于优化学术研究以及开发日常工作流程。其中内置的工具,包括但不限于以下这些:学术…

软件测试的理论基础1

软件的生命周期 可行性研究和计划(立项) 需求分析 概要设计(测试计划) 详细设计(测试方案) 实现(开发阶段;包含单元测试) 组装测试(集成测试) 确…

十六)Stable Diffusion教程:出图流程化

今天说一个流程化出图的案例,适用很多方面。 1、得到线稿,自己画或者图生图加线稿lora出线稿;如果想sd出图调整参数不那么频繁细致,则线稿的素描关系、层次、精深要表现出来,表现清楚。 2、文生图,seed随机…

kafka的 ack 应答机制

目录 一 ack 应答机制 二 ISR 集合 一 ack 应答机制 kafka 为用户提供了三种应答级别: all,leader,0 acks :0 这一操作提供了一个最低的延迟,partition的leader接收到消息还没有写入磁盘就已经返回ack&#x…

PyCharm 远程debug 快速上手

一、方法 1. 配置远程解释器(简单高效,强烈推荐!!!) 要求: 通过 SSH 从本地机器访问远程服务器,使用任何预定义的端口从远程服务器访问本地机器,最好关掉vpn等网络代理服务。 常见…

中通快递一键查询,轻松掌握物流信息

在如今的快速发展的电商时代,快递已成为人们生活中不可或缺的一部分。随着快递业务的繁荣,快递公司也纷纷推出了各种查询方式,方便顾客随时掌握自己包裹的物流信息。在这其中,中通快递无疑是许多人选择的首选。下面,我…

合并两个升序链表,哨兵位的理解

开始时也要判断是否有一个链表本来就是空,如果是,直接返回另外一个链表 代码: struct ListNode* mergeTwoLists(struct ListNode* list1, struct ListNode* list2){if(list1NULL){return list2;}if(list2NULL){return list1;} struct ListN…

仿互站资源商城平台系统源码多款应用模版

首先安装好环境,推荐用Linux宝塔 请示:安装前请先别开防火墙,和跨站篡改 第1步上传程序到服务器, 第2步修改数据库文件,config/config.php 第3步,导入数据,根目录的数据库文件夹里面 数据.s…

SpringCloud Alibaba-Seata

接上文 SpringCloud Alibaba - Sentinel 1.简介(Seata与分布式事务) Seata官方网址https://seata.io/zh-cn/docs/overview/what-is-seata.html 2.环境搭建 首先对之前的图书借阅系统进行升级: 编写对应的服务接口。 (1&#…

操作系统:体系结构

1.内核的划分 1.术语解释 时钟管理:利用时钟断实现计时功能。原语是一种特殊的程序,具有原子性。也就是说,这段程序的运行必须一气呵成,不可被“中断”Ubuntu、Centos的开发团队,其主要工作是实现非内核功能,而内核都是用了Linux内核。 内核…

el-table-column默认选中一个复选框和只能单选事件

表格代码 <el-table ref"contractTable" v-loading"loading" :data"contractList" selection-change"contractSelectionChange" style"margin-top: 10%;"><el-table-column type"selection" width"…

【Linux】系统编程线程互斥与同步(C++)

目录 【1】线程互斥 【1.1】进程线程间的互斥相关背景概念 【1.2】互斥量mutex 【1.3】互斥量实现原理探究 【1.4】RAII的加锁风格 【2】可重入VS线程安全 【2.1】概念 【2.2】常见的线程不安全的情况 【2.3】常见的线程安全的情况 【2.4】常见不可重入的情况 【2.5…

【golang】深入理解GMP调度模型

Goroutine Go中&#xff0c;协程被称为goroutine&#xff0c;它非常轻量&#xff0c;一个goroutine只占几KB&#xff0c;并且这几KB就足够goroutine运行完&#xff0c;这就能在有限的内存空间内支持大量goroutine&#xff0c;支持了更多的并发&#xff0c;虽然一个goroutine的…

基于YOLOv8模型的条形码二维码检测系统(PyTorch+Pyside6+YOLOv8模型)

摘要&#xff1a;基于YOLOv8模型的条形码二维码检测系统可用于日常生活中检测与定位条形码与二维码目标&#xff0c;利用深度学习算法可实现图片、视频、摄像头等方式的目标检测&#xff0c;另外本系统还支持图片、视频等格式的结果可视化与结果导出。本系统采用YOLOv8目标检测…

Ubuntu 12.04增加右键命令:在终端中打开增加打开文件

Ubuntu 12.04增加右键命令&#xff1a;在终端中打开 软件中心&#xff1a;搜索nautilus-open-terminal安装 用快捷键CtrlT打开命令行输入&#xff1a; sudo apt-get install nautilus-open-terminal 重新加载文件管理器 nautilus -q 或注销再登录即要使用

一文弄懂基于采样的路径规划-RRT系列(python代码)

基于采样的路径规划算法-RRT系列 VX关注晓理紫并回复rrt获取代码 [晓理紫] 1、基于采样的路径规划算法 基于抽样的规划方法&#xff08;或称概率方法&#xff09;通过在连续 C 空间中逐步或批量抽样&#xff0c;构建由离散 C 空间样本连接的树或图&#xff0c;从而捕捉解空间的…

飞书应用配置+蓝鲸流水线+jump server

开发者后台创建应用 配置应用基础信息&#xff0c;权限&#xff0c;安全等 管理后台 设置应用在工作台的可见范围和其他设置 Linux 常用命令&#xff1a;Linux 常用命令, ll 文件夹下文件&#xff0c;ls 文件&#xff0c;cd进入目录&#xff0c; cat 查看文件&#xff0c; v…