tensorboard与torchinfo的使用

news2025/1/12 17:37:24

目录

  • 1. tensorboard
    • 1.1 本地使用
    • 1.2 远程服务器使用
  • 2. torchinfo
  • Ref

1. tensorboard

1.1 本地使用

只需要掌握一个 torch.utils.tensorboard.writer.SummaryWriter 接口即可。

在初始化 SummaryWriter 的时候,通常需要指定log的存放路径。这个路径默认是 runs/CURRENT_DATETIME_HOSTNAME,其中

CURRENT_DATETIME_HOSTNAME = datetime.now().strftime("%b%d_%H-%M-%S") + "_" + socket.gethostname()

这里建议为每个实验单独开一个文件夹,例如 runs/exp1runs/exp2 等,所有的events文件都会存放在 log_dir

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='./runs/exp1')

writer中用到最多的接口自然是 add_scalar,接下来也将重点讲解该接口的使用。如还需参考其他接口,可以移步至官方文档。

顾名思义,add_scalar 用于逐步添加标量以形成一条曲线(例如loss、acc曲线等),具体用法如下

writer.add_scalar(tag: str, scalar_value: float, global_step: int)

tag 是该曲线的标签(例如它可以是loss、acc、lr等),scalar_value 是这一个点的纵坐标,global_step 是这一个点的横坐标。

我们可以调用 add_scalar 来绘制一条 y = 2 x y=2x y=2x 的曲线,如下

import time
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('./runs/exp1')
for i in range(20):
    writer.add_scalar('y=2x', 2 * i, i)
    time.sleep(2)
writer.close()

上述代码每隔2秒添加一个点,40秒之后整条曲线绘制完毕。

如需查看该曲线,打开本地终端,执行命令(既可以在绘制的过程中执行该命令也可以在绘制结束后执行该命令)

tensorboard --logdir=./runs/

📝 关于 --logdir 的官方解释:Directory where TensorBoard will look to find TensorFlow event files that it can display. TensorBoard will recursively walk the directory structure rooted at logdir, looking for .*tfevents.* files. A leading tilde will be expanded with the semantics of Python’s os.expanduser function.

然后在本地浏览器中打开 http://localhost:6006/ 即可,效果如下

我们还可以对 tag 进行分级,例如

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('./runs/exp2')
for i in range(20):
    writer.add_scalar('straight_line/y=2x', 2 * i, i)
    writer.add_scalar('straight_line/y=3x', 3 * i, i)
    writer.add_scalar('parabola/y=x^2', i * i, i)
    writer.add_scalar('parabola/y=3x^2', 3 * i * i, i)
writer.close()

效果如下

我们还可以在同一个 tag 上添加多条曲线

import random
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('./runs/exp3')
for i in range(5):
    slopes = [1, 2, 3, 4, 5]
    scope = random.randint(15, 40)
    for j in range(scope):
        writer.add_scalar('different_lines', slopes[i] * j, j)
writer.close()

效果如下

tensorboard的默认端口是6006,如果该端口事先被占用,则tensorboard无法正常启动。我们可以手动指定端口 --port=[指定的端口号] 也可以让系统自动寻找可用的端口 --port=0

1.2 远程服务器使用

由于远程服务器上通常是没有浏览器的,因此我们需要进行相应的配置才可以在本地的浏览器访问远程服务器上的tensorboard服务。

首先建立一个SSH隧道,在本地机器上执行

ssh -NfL 6006:localhost:6006 remote_username@remote_host

-N 代表不执行远程命令,-f 代表把SSH挂在后台。该命令将本地的6006端口转发至远程服务器的6006端口。之后,在远程服务器上启动tensorboard服务

tensorboard --logdir=/path/to/logs

最后,在本地机器上打开http://localhost:6006/即可。

注意,以上方式建立的SSH隧道会一直挂在后台,如果需要关闭隧道,可执行以下命令

pkill -f 'ssh -NfL'

如果不加 -Nf,则执行命令后会直接连接到远程服务器,当断开与远程服务器的连接后,SSH隧道也会随之关闭。

2. torchinfo

通常我们会使用如下语法来计算一个模型的总参数量

print(sum(p.numel() for p in model.parameters()))

但是这种方式并不能提供关于模型各层的详细信息。对于复杂的模型,我们可能更关心每层的参数量、输出形状、以及内存占用等信息。在这种情况下,我们可以使用 torchinfo

torchinfo 的前身是 torchsummary,安装方式如下

pip install torchinfo

同样地,只需要掌握一个 torchinfo.torchinfo.summary 接口即可。

查看一个模型的完整信息需要提供输入的形状(如果只想查看参数量则不必提供),且形状应当包含 batch_size 这个维度。例如,查看resnet18的信息需要提供形状为 ( N , C , H , W ) = ( N , 3 , 224 , 224 ) (N,C,H,W)=(N,3,224,224) (N,C,H,W)=(N,3,224,224) 的输入:

from torchinfo import summary
from torchvision.models import resnet18

model = resnet18()
batch_size = 32
summary(model, input_size=(batch_size, 3, 224, 224))

输出如下

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ResNet                                   [32, 1000]                --
├─Conv2d: 1-1                            [32, 64, 112, 112]        9,408
├─BatchNorm2d: 1-2                       [32, 64, 112, 112]        128
├─ReLU: 1-3                              [32, 64, 112, 112]        --
├─MaxPool2d: 1-4                         [32, 64, 56, 56]          --
├─Sequential: 1-5                        [32, 64, 56, 56]          --
│    └─BasicBlock: 2-1                   [32, 64, 56, 56]          --
│    │    └─Conv2d: 3-1                  [32, 64, 56, 56]          36,864
│    │    └─BatchNorm2d: 3-2             [32, 64, 56, 56]          128
│    │    └─ReLU: 3-3                    [32, 64, 56, 56]          --
│    │    └─Conv2d: 3-4                  [32, 64, 56, 56]          36,864
│    │    └─BatchNorm2d: 3-5             [32, 64, 56, 56]          128
│    │    └─ReLU: 3-6                    [32, 64, 56, 56]          --
│    └─BasicBlock: 2-2                   [32, 64, 56, 56]          --
│    │    └─Conv2d: 3-7                  [32, 64, 56, 56]          36,864
│    │    └─BatchNorm2d: 3-8             [32, 64, 56, 56]          128
│    │    └─ReLU: 3-9                    [32, 64, 56, 56]          --
│    │    └─Conv2d: 3-10                 [32, 64, 56, 56]          36,864
│    │    └─BatchNorm2d: 3-11            [32, 64, 56, 56]          128
│    │    └─ReLU: 3-12                   [32, 64, 56, 56]          --
├─Sequential: 1-6                        [32, 128, 28, 28]         --
│    └─BasicBlock: 2-3                   [32, 128, 28, 28]         --
│    │    └─Conv2d: 3-13                 [32, 128, 28, 28]         73,728
│    │    └─BatchNorm2d: 3-14            [32, 128, 28, 28]         256
│    │    └─ReLU: 3-15                   [32, 128, 28, 28]         --
│    │    └─Conv2d: 3-16                 [32, 128, 28, 28]         147,456
│    │    └─BatchNorm2d: 3-17            [32, 128, 28, 28]         256
│    │    └─Sequential: 3-18             [32, 128, 28, 28]         8,448
│    │    └─ReLU: 3-19                   [32, 128, 28, 28]         --
│    └─BasicBlock: 2-4                   [32, 128, 28, 28]         --
│    │    └─Conv2d: 3-20                 [32, 128, 28, 28]         147,456
│    │    └─BatchNorm2d: 3-21            [32, 128, 28, 28]         256
│    │    └─ReLU: 3-22                   [32, 128, 28, 28]         --
│    │    └─Conv2d: 3-23                 [32, 128, 28, 28]         147,456
│    │    └─BatchNorm2d: 3-24            [32, 128, 28, 28]         256
│    │    └─ReLU: 3-25                   [32, 128, 28, 28]         --
├─Sequential: 1-7                        [32, 256, 14, 14]         --
│    └─BasicBlock: 2-5                   [32, 256, 14, 14]         --
│    │    └─Conv2d: 3-26                 [32, 256, 14, 14]         294,912
│    │    └─BatchNorm2d: 3-27            [32, 256, 14, 14]         512
│    │    └─ReLU: 3-28                   [32, 256, 14, 14]         --
│    │    └─Conv2d: 3-29                 [32, 256, 14, 14]         589,824
│    │    └─BatchNorm2d: 3-30            [32, 256, 14, 14]         512
│    │    └─Sequential: 3-31             [32, 256, 14, 14]         33,280
│    │    └─ReLU: 3-32                   [32, 256, 14, 14]         --
│    └─BasicBlock: 2-6                   [32, 256, 14, 14]         --
│    │    └─Conv2d: 3-33                 [32, 256, 14, 14]         589,824
│    │    └─BatchNorm2d: 3-34            [32, 256, 14, 14]         512
│    │    └─ReLU: 3-35                   [32, 256, 14, 14]         --
│    │    └─Conv2d: 3-36                 [32, 256, 14, 14]         589,824
│    │    └─BatchNorm2d: 3-37            [32, 256, 14, 14]         512
│    │    └─ReLU: 3-38                   [32, 256, 14, 14]         --
├─Sequential: 1-8                        [32, 512, 7, 7]           --
│    └─BasicBlock: 2-7                   [32, 512, 7, 7]           --
│    │    └─Conv2d: 3-39                 [32, 512, 7, 7]           1,179,648
│    │    └─BatchNorm2d: 3-40            [32, 512, 7, 7]           1,024
│    │    └─ReLU: 3-41                   [32, 512, 7, 7]           --
│    │    └─Conv2d: 3-42                 [32, 512, 7, 7]           2,359,296
│    │    └─BatchNorm2d: 3-43            [32, 512, 7, 7]           1,024
│    │    └─Sequential: 3-44             [32, 512, 7, 7]           132,096
│    │    └─ReLU: 3-45                   [32, 512, 7, 7]           --
│    └─BasicBlock: 2-8                   [32, 512, 7, 7]           --
│    │    └─Conv2d: 3-46                 [32, 512, 7, 7]           2,359,296
│    │    └─BatchNorm2d: 3-47            [32, 512, 7, 7]           1,024
│    │    └─ReLU: 3-48                   [32, 512, 7, 7]           --
│    │    └─Conv2d: 3-49                 [32, 512, 7, 7]           2,359,296
│    │    └─BatchNorm2d: 3-50            [32, 512, 7, 7]           1,024
│    │    └─ReLU: 3-51                   [32, 512, 7, 7]           --
├─AdaptiveAvgPool2d: 1-9                 [32, 512, 1, 1]           --
├─Linear: 1-10                           [32, 1000]                513,000
==========================================================================================
Total params: 11,689,512
Trainable params: 11,689,512
Non-trainable params: 0
Total mult-adds (G): 58.05
==========================================================================================
Input size (MB): 19.27
Forward/backward pass size (MB): 1271.92
Params size (MB): 46.76
Estimated Total Size (MB): 1337.94
==========================================================================================

可以看到,resnet18的大小为11M。一个浮点数占用四个字节,那么所有参数将占用 11689512 ⋅ 4 / 1 0 6 ≈ 46.76 11689512\cdot4/10^6\approx46.76 116895124/10646.76 MB的内存。输入占用19.27MB的内存,进行一次前向/反向传播占用1271.92MB的内存(因为要保存中间变量),所以在训练resnet18的过程中将总共占用1337.94MB的内存。

有些时候,我们不想立即让 summary 打印出模型的信息,因此可以设置 verbose=0,等到合适的时候进行打印

model_stats = summary(model, input_size=(batch_size, 3, 224, 224), verbose=0)
# ...
print(model_stats)

如果模型需要提供多个输入,我们可以为 input_size 提供 List[Tuple[int, ...]] 的格式,例如Transformer

from torchinfo import summary
import torch

model = torch.nn.Transformer()
batch_size = 128
src_len, tgt_len = 256, 256
embed_dim = 512
summary(model, input_size=[(src_len, batch_size, embed_dim), (tgt_len, batch_size, embed_dim)])

输出如下

====================================================================================================
Layer (type:depth-idx)                             Output Shape              Param #
====================================================================================================
Transformer                                        [256, 128, 512]           --
├─TransformerEncoder: 1-1                          [256, 128, 512]           --
│    └─ModuleList: 2-1                             --                        --
│    │    └─TransformerEncoderLayer: 3-1           [256, 128, 512]           3,152,384
│    │    └─TransformerEncoderLayer: 3-2           [256, 128, 512]           3,152,384
│    │    └─TransformerEncoderLayer: 3-3           [256, 128, 512]           3,152,384
│    │    └─TransformerEncoderLayer: 3-4           [256, 128, 512]           3,152,384
│    │    └─TransformerEncoderLayer: 3-5           [256, 128, 512]           3,152,384
│    │    └─TransformerEncoderLayer: 3-6           [256, 128, 512]           3,152,384
│    └─LayerNorm: 2-2                              [256, 128, 512]           1,024
├─TransformerDecoder: 1-2                          [256, 128, 512]           --
│    └─ModuleList: 2-3                             --                        --
│    │    └─TransformerDecoderLayer: 3-7           [256, 128, 512]           4,204,032
│    │    └─TransformerDecoderLayer: 3-8           [256, 128, 512]           4,204,032
│    │    └─TransformerDecoderLayer: 3-9           [256, 128, 512]           4,204,032
│    │    └─TransformerDecoderLayer: 3-10          [256, 128, 512]           4,204,032
│    │    └─TransformerDecoderLayer: 3-11          [256, 128, 512]           4,204,032
│    │    └─TransformerDecoderLayer: 3-12          [256, 128, 512]           4,204,032
│    └─LayerNorm: 2-4                              [256, 128, 512]           1,024
====================================================================================================
Total params: 44,140,544
Trainable params: 44,140,544
Non-trainable params: 0
Total mult-adds (G): 6.46
====================================================================================================
Input size (MB): 134.22
Forward/backward pass size (MB): 12348.03
Params size (MB): 100.92
Estimated Total Size (MB): 12583.17
====================================================================================================

可以看到,transformer的大小为44M,训练时将占用12GB的内存。

Ref

[1] https://stackoverflow.com/questions/37987839/how-can-i-run-tensorboard-on-a-remote-server
[2] https://github.com/TylerYep/torchinfo

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/735862.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python脚本小工具之文件与内容搜索

目录 一、前言 二、代码 三、结果 一、前言 ​日常工作中,经常需要在指定路径下查找指定类型的文件,或者是指定内容的查找,在window环境中,即可以使用一些工具如notepad或everything,也可以使用python脚本。但在l…

【C++进阶】bitset位图介绍以及模拟实现

文章目录 位图介绍一、位图的引入二、位图的概念 位图模拟实现一、构造函数二、set,reset,test函数三、代码测试四、完整代码 位图介绍 一、位图的引入 先来看下边一道面试题: 给40亿个不重复的无符号整数,没排过序。给一个无符…

SAR ADC version2 ——ADC背景介绍

目录: ADC常用指标分类 静态性能:微分非线性:DNL 积分非线性:INL 仿真测试DNL:(码密度法)(code density&…

OpenCV 入门教程:像素访问和修改

OpenCV 入门教程:像素访问和修改 导语一、像素访问1.1 获取图像的大小1.2 访问图像的像素值1.3 修改图像的像素值 二、示例应用2.1 图像反转2.2 阈值化操作 三、总结 导语 在图像处理和计算机视觉领域,像素级操作是非常重要和常见的任务之一。通过像素访…

Python——将F12得到的请求头转换成其对应json格式

问题引入 最近在鼓捣爬虫准备爬爬学校网站,用到pthon的requests库发送get请求时需要提供headers. 需要将请求头转换成json格式的数据。json格式如下所示 headers{"Path":"xxx","User-Agent":"xxx" } 但是从网页上f12复…

21-注册中心与配置中心Nacos

已经使用过了Spring cloud提供的Geteway、openFeign。 1、注册中心与配置中心 1.1、注册中心 相当于通讯录,让应用之间相互认识。 用途: 实例的健康检查。 路由转发:为了控制成本,会对机器做动态扩容,此时IP就不固定了。 远程调用。 1.2、配置中心 动态修改线上的配…

深入解析MySQL视图、索引、数据导入导出:优化查询和提高效率

目录 1. 视图(View): 什么是视图? 为什么要使用视图? 视图的优缺点 1) 定制用户数据,聚焦特定的数据 2) 简化数据操作 3) 提高数据的安全性 4) 共享所需数据 5) 更改数据格式 6) 重用 SQL 语句 …

十一.Redis发布订阅

Redis发布订阅(pub/sub)是一种消息通信模式:发布者(pub)发送消息,订阅者(sub)接受消息。此种模式下,消息发布者和订阅者不进行直接通信,发布者客户端向指定的频道(channel) 发布消息,订阅该频道…

【MQTT】Esp32数据上传采集:最新mqtt插件(支持掉线、真机调试错误等问题)

前言 这是我在Dcloud发布的插件-最完整Mqtt示例代码(解决掉线、真机调试错误等问题),经过整改优化和替换Mqtt的js文件使一些市场上出现的问题得以解决,至于跨端出问题,可能原因有很多,例如,合法…

MySQL基础篇第3章(基本的SELECT语句)

文章目录 1、SQL概述1.1 SQL背景知识1.2 SQL分类 2、SQL语言的规则与规范2.1 基本规则2.2 SQL大小写规范 (建议遵守)2.3 注释2.4 命名规则2.5 数据导入指令 3、基本的SELECT语句3.0 SELECT...3.1 SELECT...FROM3.2 列的别名3.3 去除重复行3.4 空置参与运…

营销人累了看看这5部影片吧!保你再燃激情

市场瞬息万变,做营销需不断学习充电,除了看书听课之外看电影也是学习营销的有效方式。今天小马识途营销顾问给大家推荐5部市场营销人员必看的高评分电影,相信看完之后,会对你今后的发展影响深远!话不多说直接上干货&am…

并发编程 - Event Driven 设计模式(EDA)

文章目录 EDA 概述初体验EventEvent HandlersEvent Loop 如何设计一个Event-Driven框架同步EDA框架设计MessageChannelDynamic RouterEventEventDispatcher测试同步EDA架构类图 异步EDA框架设计抽象基类 AsyncChannelAsyncEventDispatcher 并发分发消息测试 EDA 概述 EDA&…

【计算机网络】第 2 课 - 计算机网络的性能指标

欢迎来到博主 Apeiron 的博客,祝您旅程愉快 ! 时止则止,时行则行。动静不失其时,其道光明。 目录 1、缘起 2、性能指标 2.1、速率 2.2、带宽 2.3、吞吐量 2.4、时延 2.5、时延带宽积 2.6、往返时间 2.7、利用率 2.8、丢…

【Cartopy学习系列】Cartopy中的投影类型总结

一、PlateCarree(圆柱投影) PlateCarree 是Cartopy的默认投影,投影将地物投影到圆柱面上再展开,常用来绘制世界地图。该投影具有经线或纬线方向等度数的特点,亦称等经纬度投影。 class cartopy.crs.PlateCarree(cent…

【Kafka】Kafka消费者

【Kafka】Kafka消费者 文章目录 【Kafka】Kafka消费者1. 消费方式1.1 消费者工作流程1.2 消费者组原理1.3 消费者组初始化流程1.4 消费者组详细消费流程1.5 消费者重要参数 2. 消费者API2.1 独立消费者案例2.2 订阅分区2.3 消费者组案例 1. 消费方式 pull(拉)模式:…

Linux上查看外接USB设备类型

最近遇到一个问题,需要在shell脚本中识别当前显示器的USB触屏线是否插入,并读取显示器名称,以确定是否是想要的。 解决思路: lsusb命令可以列出所有的外接USB设备: 其中 “Atmel Corp. Atmel maXTouch Digitizer” 即为…

rabbitmq使用springboot实现direct模式

一、 Direct模式 类型&#xff1a;direct特点&#xff1a;Direct模式是fanout模式上的一种叠加&#xff0c;增加了路由RoutingKey的模式。 二、coding Ⅰ 生产者 1、引入相应的pom文件 pom.xml <?xml version"1.0" encoding"UTF-8"?> <pro…

Linux 学习记录48(QT篇待完成)

Linux 学习记录48(QT篇) 本文目录 Linux 学习记录48(QT篇)一、1.2. 二、三、四、练习1. 自制文本编辑器(0. main.cpp(1. txt_window.h(2. txt_window.cpp 2. 登录界面完善 一、 1. 2. 二、 三、 四、 练习 1. 自制文本编辑器 (0. main.cpp #include "txt_window.h…

JavaWeb 笔记——5

JavaWeb 笔记——5 一、Filter1.1、概述1.2、Filter快速入门1.3、Filter执行流程1.4、Filter使用细节1.5、Filter-案例-登陆验证 二、Listener2.1、Listener概述与分类2.2、ServletContextListener使用 三、AJAX3.1、AJAX概述3.2、AJAX快速入门3.3、使用Ajax验证用户名是否存在…

《阿里大数据之路》研读笔记(3)事实表

不理解可以先看看这个例子 例子里的start_time可以看成下单时间 end看成确认收货时间 这个例子中累计快照事实表和拉链表类似 图解HIVE累积型快照事实表_累积快照事实表_小基基o_O的博客-CSDN博客 累计快照事实表 我的理解是 根据上面的例子 就是一行代表多个业务过程 每个…