torchnet简介

news2024/11/23 21:57:25

前言

最近项目开发过程中遇到了 t o r c h n e t . m e t e r torchnet.meter torchnet.meter来记录模型信息,搜了好多篇博客,都潦草草没有一点干货。于是参考了官方文档以及参考代码,根据自己的理解,在此做了一个其的使用教程:

torchnet简介

torchent是torch框架的一种,其提供了一套抽象的概念,旨在鼓励代码复用和模块化编程。提供了四个重要的类:
在这里插入图片描述
每个meter的子类都有三个方法:
在这里插入图片描述

Classification Meters

APMeter

APMeter计算每个类的AP,即平均精度average precision。

import torch
from torch.nn import functional as F
from torchnet import meter as tnt


seed = 1024
torch.manual_seed(seed)

# 128条数据, 10个类别
size = (128, 10)
output = torch.rand(size=size)
output = F.softmax(output, dim=1)
target = torch.randint(0, 2, size=size)

aper = tnt.APMeter()
aper.add(output, target)

"""
add(output, target, weight=None):
	output: 模型的输出, 是一个NxK的tensor, 表示每个类别的概率, N表示样本数目, K表示类别数目, 所有类别的概率总和应为1
	target: 样本的标签, 是一个NxK的二进制tensor, 即其值只能是0(负样本)或1(正样本)
	weight: 可选参数
"""

print('AP: ', aper.value().numpy())
print('mAP: ', aper.value().sum().numpy() / 10)
# AP:  [0.535214  0.6198798  0.59850764  0.527964  0.4984482  0.5188082  0.5916564  0.41430935  0.48577505  0.41956347]
# mAP:  0.5210125923156739

mAPMeter

mAPMeter计算所有类别的mAP,即平均AP。

import torch
from torch.nn import functional as F
from torchnet import meter as tnt


seed = 1024
torch.manual_seed(seed)

# 128条数据, 10个类别
size = (128, 10)
output = torch.rand(size=size)
# output /= output.sum(dim=1).unsqueeze(dim=1).expand(size=size)
output = F.softmax(output, dim=1)
target = torch.randint(0, 2, size=size)

maper = tnt.mAPMeter()
maper.add(output, target)

"""
add(output, target, weight=None):
	output: 模型的输出, 是一个NxK的tensor, 表示每个类别的概率, N表示样本数目, K表示类别数目, 所有类别的概率总和应为1
	target: 样本的标签, 是一个NxK的二进制tensor, 即其值只能是0(负样本)或1(正样本)
	weight: 可选参数
"""

print('mAP: ', maper.value().numpy())
# mAP:  0.5210126

ClassErrorMeter

计算模型的accuracy,即准确率。

import torch
from torch.nn import functional as F
from torchnet import meter as tnt


seed = 1024
torch.manual_seed(seed)

# 128条数据, 10个类别
size = (128, 10)
output = torch.randn(size=size)
output = F.softmax(output, dim=1)
target = torch.randint(0, 10, size=(128,))

# 计算acc1和acc5, 默认计算acc1, 即topk=[1]
classer = tnt.ClassErrorMeter(topk=[1, 5], accuracy=True)
classer.add(output, target)

"""
add(output, target):
	output: 模型的输出, 是一个NxK的tensor, 表示每个类别的概率, N表示样本数目, K表示类别数目, 所有类别的概率总和应为1
	target: 样本的标签, 是一个长度为N的tensor, 标签id从0开始
"""

print('acc1: {0}%, acc5: {1}%'.format(classer.value()[0], classer.value()[1]))
# acc1: 11.71875%, acc5: 56.25%

ConfusionMeter

ConfusionMeter计算多分类模型的confusion matrix,即混淆矩阵。不支持 multi-label和multi-class问题,对于这类问题可以使用 MultiLabelConfusionMeter

import torch
from torch.nn import functional as F
from torchnet import meter as tnt


seed = 1024
torch.manual_seed(seed)

size = (128, 10)
output = torch.randn(size=size)
output = F.softmax(output, dim=1)
target = torch.randint(0, 10, size=(128,))

# k表示类别的数目, normalized表示是否对混淆矩阵进行归一化, 默认False
confer = tnt.ConfusionMeter(k=10, normalized=False)
confer.add(output, target)

"""
add(output, target):
	output: 模型的输出, 是一个NxK的tensor, 表示每个类别的概率, N表示样本数目, K表示类别数目, 所有类别的概率总和应为1
	target: 样本的标签, 是一个长度为N的tensor, 标签id从0开始
"""

print('confusion matrix: \n', confer.value())
# confusion matrix: 
 [[1 1 2 1 0 0 0 1 1 0]
 [0 1 1 4 1 1 1 0 2 2]
 [0 2 2 1 0 1 0 2 1 2]
 [1 0 1 0 1 5 0 1 3 0]
 [3 3 3 1 1 1 3 1 0 3]
 [1 2 1 2 2 3 3 0 4 1]
 [1 1 1 0 1 1 0 2 0 2]
 [1 1 1 1 0 1 2 2 1 2]
 [2 0 0 0 1 0 1 2 2 3]
 [1 1 3 1 1 1 0 1 3 3]]

# normalized=True 
# confusion matrix: 
 [[0.14285715 0.14285715 0.2857143  0.14285715 0.         0.         0.         0.14285715 0.14285715 0.        ]
 [0.         0.07692308 0.07692308 0.30769232 0.07692308 0.07692308 0.07692308 0.         0.15384616 0.15384616]
 [0.         0.18181819 0.18181819 0.09090909 0.         0.09090909 0.         0.18181819 0.09090909 0.18181819]
 [0.08333334 0.         0.08333334 0.         0.08333334 0.41666666 0.         0.08333334 0.25       0.        ]
 [0.15789473 0.15789473 0.15789473 0.05263158 0.05263158 0.05263158 0.15789473 0.05263158 0.         0.15789473]
 [0.05263158 0.10526316 0.05263158 0.10526316 0.10526316 0.15789473 0.15789473 0.         0.21052632 0.05263158]
 [0.11111111 0.11111111 0.11111111 0.         0.11111111 0.11111111 0.         0.22222222 0.         0.22222222]
 [0.08333334 0.08333334 0.08333334 0.08333334 0.         0.08333334 0.16666667 0.16666667 0.08333334 0.16666667]
 [0.18181819 0.         0.         0.         0.09090909 0. 0.09090909 0.18181819 0.18181819 0.27272728]
 [0.06666667 0.06666667 0.2        0.06666667 0.06666667 0.06666667 0.         0.06666667 0.2        0.2       ]]

Regression/Loss Meters

AverageValueMeter

AverageValueMeter计算均值和标准差

from torchnet import meter as tnt


avger = tnt.AverageValueMeter()

for i in range(10):
    avger.add(i)

"""
add(value):
	value: 一个数值
"""

print('mean: {0}, std: {1}'.format(avger.value()[0], avger.value()[1]))
# mean: 4.5, std: 3.0276503540974917

AUCMeter

计算AUC,即ROC曲线下的面积,用于二分类。

import torch
from torch.nn import functional as F
from torchnet import meter as tnt


seed = 1024
torch.manual_seed(seed)

size = (128, )
output = torch.randn(size=size)
output = F.sigmoid(output)
target = torch.randint(0, 2, size=size)

aucer = tnt.AUCMeter()
aucer.add(output, target)

"""
add(output, target):
	output: 模型的输出分数, 是一个一维的tensor
	target: 样本的标签, 也是一个一维的tensor, 其值只能是0(负样本)或1(正样本)
"""

print('AUC: ', aucer.value()[0])
# AUC:  0.5208791208791209

MovingAverageValueMeter

计算当前状态前的windowsize个数的均值和标准差。即计算最后windowsize个数的均值和标准差

from torchnet import meter as tnt


# windowsize 需要计算的个数
mavger = tnt.MovingAverageValueMeter(windowsize=5)

for i in range(10):
    mavger.add(i)

"""
add(value):
	value: 一个数值
"""

print('mean: {0}, std: {1}'.format(mavger.value()[0].item(), mavger.value()[1]))
# mean: 7.0, std: 1.5811388300841898

MSEMeter

计算模型的MSE,即均方误差。

from torchnet import meter as tnt


seed = 1024
torch.manual_seed(seed)

size = (128, 10)
output = torch.randint(0, 10, size=size)
target = torch.randint(0, 10, size=size)

mser = tnt.MSEMeter(root=False)
mser.add(output, target)

"""
add(output, target):
	output: 模型的输出类别, 是一个NxK的tensor
	target: 样本的标签, 也是一个NxK的tensor
"""

print('MSE: ', mser.value().item())
# MSE:  17.3515625

Miscellaneous Meters

TimeMeter

用来计算模型处理数据的时间。

from torchnet import meter as tnt


def my_model():
    tmp = 1
    for i in range(10000000):
        tmp *= 1024 * 10.24 * (i+1)

# unit=False, 统计总的消耗时间
# unit=True, 统计平均消耗时间
timer = tnt.TimeMeter(unit=False)
for epoch in range(10):
    my_model()
    # timer.value()

print('all time: ', timer.value())
# all time:  8.787968158721924

总结

慢慢的将这个库都会使用,以及会自己总结经验都行啦的额样子与打算。
在这里插入图片描述

  • 代码复用和模块化编程的框架都搞清楚

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/64883.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

flink1.13.2 Streaming File Sink产生大量orc小文件的问题解决方案

Orc小文件合并问题 Orc小文件合并问题 现象:hdfs中出现大量ORC小文件 1.1. 已经映射为hive表ORC小文件合并 1.1.1. 非分区表 alter table 表名 concatenate; 示例: alter table ods_lxy_demo concatenate; 注意:可多次重复执行,每执行一次就会做一次文件合并,执行多次最终…

gitee/github上传远程仓库错误usage: git remote add [<options>] <name> <url>

gitee/github上传远程仓库错误gitee/github上传远程仓库错误错误截图版本错误出现时间错误检查及解决1.网址中含有空格2.关闭翻译软件3.git bash自身问题gitee/github上传远程仓库错误 不知道大家最近有没有碰到这个错误usage: git remote add [<options>] <name>…

[附源码]计算机毕业设计JAVA疫情防控期间网上教学管理

[附源码]计算机毕业设计JAVA疫情防控期间网上教学管理 项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM…

Uniapp云开发(Uniapp入门)

前言&#xff1a; 今天这篇文章主要讲解的是Uniapp云开发基础&#xff0c;有了Uniapp云开发&#xff0c;我们就不用需要后端&#xff0c;前端自己就可以实现增删改查。还有就是案例很重要&#xff0c;一定要看&#xff0c;自己去尝试运行试试。 目录超详细一. 什么是Uniapp云开…

分布式配置中心Apollo

Apollo&#xff08;阿波罗&#xff09;是携程框架部门研发的分布式配置中心&#xff0c;能够集中化管理应用不同环境、不同集群的配置&#xff0c;配置修改后能够实时推送到应用端&#xff0c;并且具备规范的权限、流程治理等特性&#xff0c;适用于微服务配置管理场景。 服务…

【Python百日进阶-数据分析】Day326 - plotly.express.scatter_geo():地理散点图

文章目录一、scatter_geo语法二、参数三、返回类型四、实例4.1 常规地理散点图4.2自定义地理散点图4.3GeoPandas 的基本示例一、scatter_geo语法 plotly.express.scatter_geo(data_frame None ,lat None ,lon None ,locations None ,locationmode None ,geojson None , …

MuziDB数据库-0.项目描述

前言 该项目写完也有一段时间了&#xff0c;为了避免以后忘记该项目的一些实现的原理&#xff0c;所以写下这篇博客来记录一下该项目的设计等 项目整体 MuziDB分为前端与后端&#xff0c;前后端交互通过socket进行交互&#xff0c;前端的作用就是读取用户输入并发送到后端进…

mybatis开发要点-insert主键ID获取和多参数传递

1.2、代码示例 二、查询如何传入多个参数 1、使用map传递参数&#xff1b; 2、使用注解传递参数&#xff1b; 3、使用Java Bean的方式传递参数&#xff1b; 一、插入数据主键ID获取 一般我们在做业务开发时&#xff0c;经常会遇到插入一条数据并使用到插入数据的ID情况。如…

网络安全事件应急演练方案

文章目录1 总则1.1 应急演练定义1.2 应急演练目的1.3 应急演练原则1.4 应急演练分类1.4.1 按组织形式划分1.4.2 按内容划分1.4.3 按目的与作用划分1.4.4 按组织范围划分1.5 应急演练规划2 应急演练组织机构2.1 组织单位2.1.1 领导小组2.1.2 策划小组2.1.3 保障小组2.1.4 评估小…

jvm参数造成http请求Read time out

问题描述 线上部署的代码&#xff0c;部署在测试环境突然抛出接口请求Read time out的异常。查看线上日志&#xff0c;接口请求正常。重新启动&#xff0c;部署测试环境代码都没有效果&#xff0c;接口还是必现Read time out。 原因分析&#xff1a; 1. 排查网络原因 直接在…

内核开发-同步场景与概念

进程上下文执行环境还有中断上下文执行环境&#xff0c;并且中断上下文优先级比较高&#xff0c;可以随时打断进程的执行&#xff0c;因此情况更加复杂。内核当中提供了不同的同步机制。比如说信号量&#xff0c;自旋锁&#xff0c;rcu&#xff0c;原子变量等等。他们各自都有自…

《计算机视觉技术与应用》-----第六章 直方图

系列文章目录 提示&#xff1a;这里可以添加系列文章的所有文章的目录&#xff0c;目录需要自己手动添加 例如&#xff1a;第一章 Python 机器学习入门之pandas的使用 提示&#xff1a;写完文章后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目…

软件过程模型

软件过程软件过程:获得高质量软件的一系列任务框架瀑布模型:特点:顺序,依赖,推迟实现,质量保证优点:规范方法,规定文档,阶段质量验证缺点:开发初期困难,需求验证困难,难以维护快速原型优点:满足需求,线性过程缺点:设计困难,原型理解不同,不利于创新增量模型:优点:短时间可完成部…

[附源码]Python计算机毕业设计SSM健身房管理系统设计(程序+LW)

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

[附源码]Python计算机毕业设计Django体育馆场地预约管理系统

项目运行 环境配置&#xff1a; Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术&#xff1a; django python Vue 等等组成&#xff0c;B/S模式 pychram管理等等。 环境需要 1.运行环境&#xff1a;最好是python3.7.7&#xff0c;…

VSCode 配置C语言环境 全程记录 ,配置成功

目录 1.vscode介绍&#xff1a; 1.1 卸载干净VSCode 1.2安装VSCode 1.2.1 下载安装 1.2.2 vscode 小插件安装 2. 配置vscode 编译器 2.1 下载编译器资源文件&#xff1a; 2.2 配置环境变量 2.3 vscode项目文件配置 1. 首先新建一个.c文件&#xff0c;命名为英文哦 2. 然后…

含有DBCO和马来酰亚胺基团Mal-PEG2-DBCO,2698339-31-8,DBCO-PEG2-Maleimide

中英文别名&#xff1a; CAS号&#xff1a;2698339-31-8 | 英文名&#xff1a;DBCO-PEG2-Maleimide&#xff0c;Mal-PEG2-DBCO |中文名&#xff1a;二苯并环辛炔-二聚乙二醇-马来酰亚胺物理参数&#xff1a; CASNumber&#xff1a;2698339-31-8 Molecular formula&#xff1a;C…

工业和信息化部公布45个国家先进制造业集群名单

近日&#xff0c;工业和信息化部正式公布45个国家先进制造业集群的名单。 45个国家级集群2021年主导产业产值达19万亿元&#xff0c;布局建设了18家国家制造业创新中心&#xff0c;占全部国家级创新中心数量的70%&#xff0c;拥有国家级技术创新载体1700余家&#xff0c;培育创…

【苹果相册推iMessage】软件安装Websocket可以在浏览器顶用于支持两个通讯并使用它

推荐内容IMESSGAE相关 作者推荐内容iMessage苹果推软件 *** 点击即可查看作者要求内容信息作者推荐内容1.家庭推内容 *** 点击即可查看作者要求内容信息作者推荐内容2.相册推 *** 点击即可查看作者要求内容信息作者推荐内容3.日历推 *** 点击即可查看作者要求内容信息作者推荐…

Python3,9行代码,对比两个Excel数据差异,并把差异结果重新保存。

Excel数据差异对比1、引言2、代码实战3、总结1、引言 小屌丝&#xff1a;鱼哥&#xff0c;还记得上次写的把数据库的查询结果写入到excel这个脚本不。 小鱼&#xff1a;嗯… 可以说不记得吗 小屌丝&#xff1a;我猜你就记得。 小鱼&#xff1a;你…说…啥&#xff1f;&#xf…