使用TensorBoard进行可视化

news2024/9/24 21:19:46

1. TensorBoard介绍

TensorBoard是TensorFlow推出的可视化工具,可以可视化模型结构、跟踪并以表格形式显示模型指标。

TensorBoard的使用包括两个步骤:

  1. 在代码中设置TensorBoard,在训练的过程中将会根据设置产生日志文件
  2. 在浏览器中可视化该日志文件,查看网络结构、loss的变化情况等

下面以 LeNet-5 为例,介绍如何在TensorFlow和PyTorch中配置TensorBoard。

2. 代码中设置TensorBoard

2.1 TensorFlow中设置

2.1.1 使用说明

在TensorFlow通过两个简单步骤即可使用TensorBoard(更多的功能可参考官方文档)。

  1. tensorboard_callback = keras.callbacks.TensorBoard(log_dir='logs/tf') :初始化日志文件。其中,log_dir 是日志文件的存放位置

  2. model.fit(callbacks=[tensorboard_callback]) :设置回调函数,在模型训练期间将会调用 tensorboard_callback 从而向日志文件中写入数据

    注意:

    日志文件的绝对路径中不能包含中文

2.1.2 代码实现

LeNet-5 的搭建可参考[TensorFlow搭建神经网络]https://blog.csdn.net/qq_41100617/article/details/132122966)

import time

import tensorflow as tf
from keras import datasets
from keras import layers
from tensorflow import keras


def my_model(input_shape):
    # 首先,创建一个输入节点
    inputs = keras.Input(input_shape)

    # 搭建神经网络
    x = layers.Conv2D(filters=6, kernel_size=(5, 5), strides=(1, 1), activation='relu')(inputs)
    x = layers.AveragePooling2D(pool_size=(2, 2), strides=(2, 2))(x)
    x = layers.Conv2D(filters=16, kernel_size=(5, 5), strides=(1, 1), activation='relu')(x)
    x = layers.AveragePooling2D(pool_size=(2, 2), strides=(2, 2))(x)

    x = layers.Flatten()(x)

    x = layers.Dense(units=16 * 4 * 4, activation='relu')(x)
    x = layers.Dense(units=120, activation='relu')(x)

    # 输出层
    outputs = layers.Dense(units=10, activation='softmax')(x)

    model = keras.Model(inputs=inputs, outputs=outputs)

    return model


(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()

train_images = tf.reshape(train_images, (train_images.shape[0], train_images.shape[1], train_images.shape[2], 1))
train_images = tf.cast(train_images, tf.float32)

test_images = tf.reshape(test_images, (test_images.shape[0], test_images.shape[1], test_images.shape[2], 1))
test_images = tf.cast(test_images, tf.float32)

model = my_model(train_images.shape[1:])

loss = keras.losses.SparseCategoricalCrossentropy()
optimizer = keras.optimizers.SGD(0.0001)
model.compile(loss=loss, optimizer=keras.optimizers.SGD(0.00001))

# 初始化日志文件
tensorboard_callback = keras.callbacks.TensorBoard(
    log_dir='logs/tf/' + time.strftime('%m-%d-%H-%M', time.localtime(time.time())))

# 在训练过程中设置回调
model.fit(train_images, train_labels, validation_split=0.3, epochs=1000, batch_size=20,
          callbacks=[tensorboard_callback])

pre_labels = model.predict(test_images)

2.2 PyTorch中设置

2.2.1 使用说明

PyTorch1.1之后添加了TensorBoard,在PyTorch中使用TensorBoard主要有3个步骤(具体信息可查看其官方文档):

  1. SummaryWriter :创建一个日志文件,主要有以下两个参数:
    • log_dir :保存日志文件的目录,默认值为 runs/当前日期时间存放的路径中不能包含中文!)
    • commentlog_dir 取默认值时添加到 log_dir 后面的后缀,当设置了 log_dir 时,该值无效
  2. add_XXX: 向创建好的日志文件里面添加数据,常用的有:
    • add_graph :添加GRAPHS,其中存放了网络结构。有以下两个常用参数:
      • model :模型
      • input_to_model :模型的输入数据
    • add_scalar : 添加SCALARS,其中可以存放折线图数据用于显示训练过程中的损失值变化情况。有以下三个常用参数:
      • tag :折线图的标签。指定tag时,加入 / 分割,可将多个折线图放在一个tag下,如 :tag=train/loss, tag=train/map,此时train下面会有loss和map两个折线图
      • scalar_value :折线图纵轴的值,一般是loss、准确率等数据
      • global_step :折线图的横轴,一般是epoch
    • add_scalars :在一张折线图中同时绘制多个数据,有以下三个常用参数:
      • main_tag, :与 add_scalar 中的 tag 用法一样
      • tag_scalar_dict :纵轴的值,因为是多个数据,需要以字典形式传入(具体看下面代码实现)
      • global_step :与 add_scalar 中的 global_setp 用法一样
  3. close:结束log写入,一般用在训练结束后

类似于操作文件,SummaryWriter 也可以和 with 一起使用,此时可以不需显示调用 close ,如:

writer = SummaryWriter()
# 写入数据……
writer.close()

#上面的代码等价于下面

with SummaryWriter() as w:
    # 写入数据……

2.2.2 代码实现

LeNet-5 的搭建可参考PyTorch搭建神经网络

import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from sklearn.metrics import classification_report
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm import tqdm


class LeNet(nn.Module):
    def __init__(self, in_channels):
        super(LeNet, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=6, kernel_size=5, stride=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = F.relu(x)
        x = self.pool2(x)

        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)

        return x


def main():
    batch_size = 8
    num_epochs = 10

    train_dataset = torchvision.datasets.MNIST(root="data/", train=True, transform=transforms.ToTensor(),
                                               download=True)
    val_dataset = torchvision.datasets.MNIST(root="data/", train=False, transform=transforms.ToTensor(), download=True)

    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                              shuffle=True)
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = LeNet(1).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())

    writer = SummaryWriter(
        'logs/pt/' + time.strftime('%m-%d-%H-%M', time.localtime(time.time())))  # 设置了存放位置,此时即使设置了 comment 也不起作用
    writer.add_graph(model, torch.randn(1, 1, 28, 28))  # 先写入模型结构

    for epoch in tqdm(range(num_epochs)):
        train_loss = 0
        val_loss = 0

        accuracy = 0
        macro_avg_f1 = 0
        weighted_avg_f1 = 0

        # 训练模型
        for batch_idx, (data, label) in enumerate(train_loader):
            data = data.to(device=device)
            label = label.to(device=device)

            pre = model(data)
            loss = criterion(pre, label)
            train_loss = (train_loss * batch_idx + loss) / (batch_idx + 1)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # 评估模型
        with torch.no_grad():
            model.eval()
            for batch_idx, (data, label) in enumerate(val_loader):
                data = data.to(device=device)
                label = label.to(device=device)

                pre = model(data)

                loss = criterion(pre, label).item()
                val_loss = (val_loss * batch_idx + loss) / (batch_idx + 1)

                pre = torch.argmax(pre, dim=1)
                score = classification_report(pre, label, output_dict=True)

                accuracy = (accuracy * batch_idx + score['accuracy']) / (batch_idx + 1)
                macro_avg_f1 = (macro_avg_f1 * batch_idx + score['macro avg']['f1-score']) / (batch_idx + 1)
                weighted_avg_f1 = (weighted_avg_f1 * batch_idx + score['weighted avg']['f1-score']) / (batch_idx + 1)

            model.train()

        writer.add_scalar('val/accuracy', accuracy, epoch)  # 在一个 tag 下面添加多个折线图
        writer.add_scalar('val/macro avg-f1', macro_avg_f1, epoch)
        writer.add_scalar('val/weighted avg-f1', weighted_avg_f1, epoch)

        writer.add_scalars('loss', {'train_loss': train_loss, 'val_loss': val_loss}, epoch)  # 一个折线图里面显示多个数据

    writer.close()  # 训练结束,不再写入数据,关闭writer


if __name__ == '__main__':
    main()

3. 可视化

注意,这里如果直接输入tensorboard使用的是系统python环境里面的tensorboard,因此,务必确保系统python环境中已安装了tensorboard

打开 log 所在的文件夹,在该文件夹下打开命令行窗口(1.在路径显示框中输入 cmd 然后回车即可打开 或者 2. 在文件夹中按住 Shift 键同时右击鼠标然后选择‘在此处打开Powershell窗口’),输入:

tensorboard --logdir=.\logs --port=6007

其中:

  • –logdir:存放日志文件的文件夹
  • –port:端口号,默认是6006

若正常,将显示以下信息:
请添加图片描述

打开浏览器,在地址栏中输入:localhost:6007(端口号要和上面保持一致)。即可看到可视化的效果,点击不同的标签可查看不同数据

请添加图片描述

4. 配合远程服务器使用

上面的代码和浏览器都在同一个电脑上,有时需要在服务器上跑代码,这时候要想在自己的电脑浏览器上可视化TensorBoard,可以按照下面的方式:

4.1 端口转发

在电脑命令行窗口输入:

ssh -L 6008:127.0.0.1:6007 服务器用户名@服务器IP地址 -p 22

这将在连接服务器的同时将服务器上的 6007 端口转发到本地的 6008 端口。

成功连接服务器后,按照上面讲的可视化的方法在服务器对应的目录下输入:

tensorboard --logdir=日志文件的存放路径 --port=6007

成功启动后,在电脑的浏览器中输入:localhost:6008 。这里用到了两个端口号,一个是服务器上tensorboard使用的端口号 6007 ,一个是本地转发以及浏览器使用的端口号 6008 ,注意区分。

4.2 Xshell建立隧道

在Xshell中选择需要设置的会话,然后右键选择属性,进行如下设置

请添加图片描述
请添加图片描述

设置完之后连接服务器,在服务器对应的目录下输入:

tensorboard --logdir=日志文件的存放路径 --port=6007

成功启动后,在电脑的浏览器中输入:localhost:6008 。这里用到了两个端口号,一个是服务器上tensorboard使用的端口号 6007 ,一个是浏览器使用的端口号 6008 ,注意区分。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/839088.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

启动RocketMQ报错

说明:启动RocketMQ消费者时,报以下错误:java.lang.IllegalStateException:Failed to start RocketMQ push consumer. 解决:看下所有的监听器类,检查是不是有相同的消费者组名,注释掉其中一个即可…

文心一言 VS 讯飞星火 VS chatgpt (71)-- 算法导论7.1 1题

参照图 7-1的方法,说明 PARTITION在数组 A(13,19,9,5,12,8,7,4,21,2,6,11)上的操作过程。 文心一言: PARTITION 是一种常…

C++笔记之while循环的帧率监测与控制

C笔记之while循环的帧率监测与控制 code review! 文章目录 C笔记之while循环的帧率监测与控制1.两个时间相关的类2.例程3.运行4.代码 1.两个时间相关的类 2.例程 3.运行 4.代码 // 帧率监测与控制程序 #include <chrono> #include <iostream> #include <thre…

python 连接oracle pandas以简化excel的编写和数据操作

python代码 Author: liukai 2810248865qq.com Date: 2022-08-18 04:28:52 LastEditors: liukai 2810248865qq.com LastEditTime: 2023-07-06 22:12:56 FilePath: \PythonProject02\pandas以简化excel的编写和数据操作.py Description: 这是默认设置,请设置customMade, 打开koro…

Python爬虫的学习day02 requests 模块post 函数, lmxl 模块的 etree 模块

1. requests 模块post 函数 1.1 post 函数的参数 &#xff08;简单版&#xff09; 参数1&#xff1a; url 网络地址 参数2&#xff1a; data 请求数据 &#xff08;一般数据是 账号&#xff0c;密码&#xff09; 参数3&#xff1a; headers 头请求 &#xff08…

概念解析 | 虚拟镜面:超越三次反射的非视线成像

虚拟镜面:超越三次反射的非视线成像 注1:本文系“概念解析”系列之一,致力于简洁清晰地解释、辨析复杂而专业的概念。本次辨析的概念是:虚拟镜面在非视线成像中的应用。 参考文献:Royo D, Sultan T, Muoz A, et al. Virtual Mirrors: Non-Line-of-Sight Imaging Beyond the Th…

MyBatis关联查询

文章目录 前言多对一关联 association一对多关联 collection 前言 提示&#xff1a;这里可以添加本文要记录的大概内容&#xff1a; 关联查询是指在一个查询中同时获取多个表中的数据&#xff0c;将它们结合在一起进行展示。 关联表需要两个及以上的表 数据库代码&#xff1…

初阶C语言——特别详细地介绍函数

系列文章目录 第一章 “C“浒传——初识C语言&#xff08;更适合初学者体质哦&#xff01;&#xff09; 第二章 详细认识分支语句和循环语句以及他们的易错点 第三章 初阶C语言——特别详细地介绍函数 目录 系列文章目录 前言 一、函数是个什么鬼东西&#xff1f; 二、C语…

springboot基于vue的高校迎新系统的设计与实现8jf9e

随着时代的发展&#xff0c;人们的生活方式得到巨大的改变&#xff0c;从而慢慢地产生了大量高校迎新信息&#xff0c;高校迎新信息需要一个现代化的管理系统&#xff0c;进行高校迎新信息的管理。 高校迎新系统的开发就是为了解决高校迎新管理的问题&#xff0c;系统开发是基于…

JavaScript |(六)DOM事件 | 尚硅谷JavaScript基础实战

学习来源&#xff1a;尚硅谷JavaScript基础&实战丨JS入门到精通全套完整版 文章目录 &#x1f4da;事件对象&#x1f4da;事件的冒泡&#x1f4da;事件的委派&#x1f4da;事件的绑定&#x1f407;赋值绑定&#x1f407;addEventListener()&#x1f407;attachEvent()&…

认识FFMPEG框架

FFMPEG全称: Fast Forward Moving Picture Experts Group (MPEG:动态图像专家组) ffmpeg相关网站: git://source.ffmpeg.org/ffmpeg.git http://git.videolan.org/?pffmpeg.git https://github.com/FFmpeg/FFmpeg FFMPEG框架基本组件: AVFormat , AVCodec, AVDevice, AVFil…

【多线程学习6】synchronized关键字

【多线程学习6】synchronized关键字 一、synchronized关键字是什么&#xff1f;有什么作用&#xff1f; synchronized关键字是Java线程同步的关键字&#xff0c;其可以修饰方法或代码块&#xff0c;并可以保证其修饰的方法或代码块在任意时刻只能有一个线程执行。 synchroni…

关于HIVE的分区与分桶

1.分区 1.概念 Hive中的分区就是把一张大表的数据按照业务需要分散的存储到多个目录&#xff0c;每个目录就称为该表的一个分区。在查询时通过where子句中的表达式选择查询所需要的分区&#xff0c;这样的查询效率会提高很多 个人理解白话:按表中或者自定义的一个列,对数据进…

筛选给定范围内的日志

目录 1.时间戳 2.实例 1.首先创建ubuntu.log日志 2.写dem.awk创建规则 3.筛选 1.时间戳 一个能表示一份数据在某个特定时间之前已经存在的、 完整的、 可验证的数据,通常是一个字符序列&#xff0c;唯一地标识某一刻的时间。 awk提供了mktime()函数&#xff0c;它可以将时间…

macbook怎么卸载软件?2023年最新全新解析macbook电脑怎样删除软件

macbook怎么卸载软件&#xff1f;2023年最新全新解析macbook电脑怎样删除软件。关于Mac笔记本如何卸载软件_Mac笔记本卸载软件的四种方法的知识大家了解吗&#xff1f;以下就是小编整理的关于Mac笔记本如何卸载软件_Mac笔记本卸载软件的四种方法的介绍&#xff0c;希望可以给到…

像素画教程:立体感与“84渐变法“

像素画本身没有什么困难&#xff0c;是矢量图简笔画之外最简单、而又最容易产生美术效果的画风。 然而&#xff0c;细节难以描绘、立体感难表现、画面易单调成了像素画绘制过程中的常见困难。 这篇文章或许不能保证每个人都能熟练掌握、运用像素画&#xff0c;但至少可以抛砖引…

任务14、无缝衔接,MidJourney瓷砖(Tile)参数制作精良贴图

14.1 任务概述 在这个实验任务中,我们将深入探索《Midjourney Ai绘画》中的Tile技术和其在艺术创作中的具有挑战性的应用。此任务将通过理论学习与实践操作相结合的方式,让参与者更好地理解Tile的核心概念,熟练掌握如何在Midjourney平台上使用Tile参数,并实际运用到AI绘画…

【雕爷学编程】Arduino动手做(184)---快餐盒盖,极低成本搭建机器人实验平台

吃完快餐粥&#xff0c;除了粥的味道不错之外&#xff0c;我对个快餐盒的圆盖子产生了兴趣&#xff0c;能否做个极低成本的简易机器人呢&#xff1f;也许只需要二十元左右 知识点&#xff1a;轮子&#xff08;wheel&#xff09; 中国词语。是用不同材料制成的圆形滚动物体。简…

Dockerfile定制Tomcat镜像

Dockerfile中的打包命令 FROM &#xff1a; 以某个基础镜像作为此镜像的基础 RUN &#xff1a; RUN后面跟着linux常用命令&#xff0c;如RUN echo xxx >> xxx,注意&#xff0c;RUN 不能用于执行命令&#xff0c;因为每个RUN都是独立运行的&#xff0c;RUN 的cd对镜像中的…

Python求均值、方差、标准偏差SD、相对标准偏差RSD

均值 均值是统计学中最常用的统计量&#xff0c;用来表明资料中各观测值相对集中较多的中心位置。用于反映现象总体的一般水平&#xff0c;或分布的集中趋势。 import numpy as npa [2, 4, 6, 8]print(np.mean(a)) # 均值 print(np.average(a, weights[1, 2, 1, 1])) # 带…