用TensorBoard可视化PyTorch

news2025/1/11 11:19:13

一、TensorBoard与PyTorch配合使用的基本步骤

PyTorch可以直接与TensorBoard进行集成,因为TensorBoard是一个独立于TensorFlow之外的可视化工具。TensorBoard被设计为支持机器学习实验的可视化,如训练的进度和结果等。PyTorch中的`torch.utils.tensorboard`模块允许PyTorch用户使用这个强大的可视化工具。
以下是将TensorBoard与PyTorch配合使用的基本步骤:
1. 在PyTorch中安装TensorBoard:   

pip install tensorboard

2. 在Python代码中导入TensorBoard的`SummaryWriter`:   

from torch.utils.tensorboard import SummaryWriter

3. 创建一个`SummaryWriter`实例,它将日志写入指定的目录:

writer = SummaryWriter('runs/your_experiment_name')

4. 将数据写入日志:

   # For example, log scalars
   writer.add_scalar('Loss/train', loss_value, epoch)

   # Log values and models
   writer.add_histogram('weights', model.weight, epoch)
   writer.add_graph(model, input_to_model)

   # Log images
   writer.add_image('input_image', img, epoch)

   # And many more...

   5. 当所有日志都写入后,在命令行启动TensorBoard,在浏览器中查看结果:

   tensorboard --logdir=runs

之后,就可以在TensorBoard的Web界面中看到各种图形和数据的可视化展现,这对于理解模型的学习过程、调试以及展示结果是非常有用的。
此外,社区也开发了一些其他可视化工具,比如`visdom`,但TensorBoard因其功能强大和易用性,在PyTorch社区中得到了广泛的应用。 

二、PyTorch与TensorBoard进行集成的完整示例

要将PyTorch与TensorBoard结合起来,可以使用`tensorboardX`库,这是一个提供了与TensorBoard兼容的API的库,使得可以从PyTorch中记录数据并在TensorBoard中查看。不过,从PyTorch 1.1.0起,官方直接内置了对TensorBoard的支持,称为`torch.utils.tensorboard`。以下是一个简单的例子,说明如何使用PyTorch训练一个模型并使用TensorBoard记录日志:
首先,确保已经安装PyTorch和TensorBoard:

pip install torch torchvision tensorboard

接下来,是一个简单的训练脚本示例,将会记录损失和精度:


import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# 创建一些数据进行演示
X, y = make_classification(n_samples=1000, n_features=20, n_informative=2, n_redundant=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# 转换为torch张量
tensor_x = torch.Tensor(X_train)
tensor_y = torch.Tensor(y_train).long()
tensor_x_test = torch.Tensor(X_test)
tensor_y_test = torch.Tensor(y_test).long()

# 创建数据加载器
train_dataset = TensorDataset(tensor_x, tensor_y)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

test_dataset = TensorDataset(tensor_x_test, tensor_y_test)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 创建一个简单的模型
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(20, 64)
        self.fc2 = nn.Linear(64, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = SimpleNet()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 设置TensorBoard
writer = SummaryWriter()

# 训练模型
for epoch in range(10):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # 将结果写入TensorBoard
    writer.add_scalar('training loss', running_loss / len(train_loader), epoch)
    
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    writer.add_scalar('accuracy', correct / total, epoch)

writer.close()
print('Finished Training')

# Now you can view the results by running in your terminal:
# tensorboard --logdir=runs

在这个脚本中,我们创建了一个简单的完全连接网络用于分类,并将训练过程中的损失和精度写入TensorBoard。要查看TensorBoard结果,保存并运行上面的脚本。然后,在终端中运行以下命令:

tensorboard --logdir=runs

打开显示的URL,将能够看见TensorBoard的仪表盘,反映出模型训练过程中记录的数据。

这段代码是一个使用PyTorch进行数据分类的示例,同时演示了如何将训练过程的信息记录到TensorBoard中。
1. 导入必要的库:
   这段代码首先导入了PyTorch相关的库,包括模型(layers)、优化(optimizer)、数据处理等组件。同时还导入了`SummaryWriter`用于向TensorBoard写入数据。
2. 生成和预处理数据:
   代码使用scikit-learn库中的`make_classification`函数生成了一个具有1000个样本、20个特征的合成分类数据集。然后,它使用`train_test_split`将数据集分割为训练集和测试集。标准化这些数据使得每个特征的分布均值为0,方差为1。
3. 准备PyTorch数据加载器:
   代码将处理好的数据转换为PyTorch张量,然后创建了`TensorDataset`数据集对象,最后通过`DataLoader`为训练和测试数据集创建迭代器,用于在训练过程中加载数据。
4. 定义简单的神经网络模型:
   定义了一个名为`SimpleNet`的神经网络类,它包含两个全连接层,第一个全连接层将20个特征映射到64个隐藏单元,接着是ReLU激活函数,最后一个全连接层将64个隐藏单元映射到2个输出(因为是二分类问题)。
5. 创建损失函数和优化器:
   使用交叉熵损失函数作为分类问题的损失函数,以及使用Adam优化器对模型的参数进行优化。
6. 设置TensorBoard日志记录器:
   初始化了`SummaryWriter`,这个对象将用于将训练过程中的信息写入日志文件,这些文件可以被TensorBoard读取并可视化。
7. 训练模型:
   在一个循环中,代码遍历了数据集多次(这里定义了10个epoch),在每个epoch中,设置模型为训练模式,并用数据加载器获取训练数据。对于每个批次的数据,执行前向传播,计算损失,执行反向传播和优化步骤。同时,汇总损失并在每个epoch结束时将平均损失记录到TensorBoard中。
8. 评估模型和记录准确率:
   在每个训练epoch之后,代码进入评估模式,并停止梯度计算,使用测试数据集计算模型预测的准确性,并将这个结果记录到TensorBoard中。
9. 关闭TensorBoard日志记录器并完成训练:
   训练结束后,会关闭`SummaryWriter`,此时训练生成的日志文件已经写入磁盘中`runs`目录下。打印完成训练的信息。
10. 查看TensorBoard中的结果:
    最后,通过命令行中运行`tensorboard --logdir=runs`来启动TensorBoard服务,并可以在浏览器中打开显示的URL来查看训练过程中记录的损失和准确率曲线。
另,with torch.no_grad():是用来停止PyTorch跟踪梯度信息。在测试模式下通常需要这么做,以减少内存消耗并加速计算。

三、可能出现的问题

由于`collections.Mapping`在Python 3.10及以后的版本已经被移除了,而应该使用`collections.abc.Mapping`。由于`tensorboard`的某些依赖库在较新版的Python中可能仍在使用已经废弃的模块路径,因此抛出了`ImportError`。
如果TensorBoard是独立于PyTorch环境外安装的,可能需要在一个PyTorch支持的Python环境中安装和运行TensorBoard。PyTorch目前支持的Python版本是3.6-3.9,如果环境中的Python版本是3.12,这有可能导致兼容性问题。
要解决这个问题,可以试图降低Python的版本,创建一个新的虚拟环境,安装一个TensorBoard版本,该版本与Python版本兼容,或等待或协助贡献TensorBoard对新Python版本的支持。下面是创建新虚拟环境并尝试安装TensorBoard的方法:
1. 创建新的Python环境(推荐使用Python 3.9)并激活它:

conda create -n new_env python=3.9
conda activate new_env

2. 在新环境中安装TensorBoard和PyTorch:

pip install tensorboard torch torchvision

3. 重新尝试启动TensorBoard:

tensorboard --logdir=runs

这种方式安装可能会避开遇到的兼容性问题。如果问题依旧存在,请考虑在相应的TensorBoard或相关依赖包的GitHub问题跟踪页面提交问题报告,以获取官方或社区的解决方案。在等待修复的同时,可以使用其他的Python版本,在那里TensorBoard是兼容的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1578730.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【数据结构】考研真题攻克与重点知识点剖析 - 第 6 篇:图

前言 本文基础知识部分来自于b站:分享笔记的好人儿的思维导图与王道考研课程,感谢大佬的开源精神,习题来自老师划的重点以及考研真题。此前我尝试了完全使用Python或是结合大语言模型对考研真题进行数据清洗与可视化分析,本人技术…

智慧安防系统EasyCVR视频汇聚平台接入大华设备无法语音对讲的原因排查与解决

安防视频监控/视频集中存储/云存储/磁盘阵列EasyCVR平台支持7*24小时实时高清视频监控,能同时播放多路监控视频流,视频画面1、4、9、16个可选,支持自定义视频轮播。EasyCVR平台可拓展性强、视频能力灵活、部署轻快,可支持的主流标…

数据可视化-ECharts Html项目实战(10)

在之前的文章中,我们学习了如何在ECharts中编写雷达图,实现特殊效果的插入运用,函数的插入,以及多图表雷达图。想了解的朋友可以查看这篇文章。同时,希望我的文章能帮助到你,如果觉得我的文章写的不错&…

甲方安全建设之研发安全-SCA

前言 大多数企业或多或少的会去采购第三方软件,或者研发同学在开发代码时,可能会去使用一些好用的软件包或者依赖包,但是如果这些包中存在恶意代码,又或者在安装包时不小心打错了字母安装了错误的软件包,则可能出现供…

shrine-攻防世界

题目 代码 import flask import os app flask.Flask(__name__) app.config[FLAG] os.environ.pop(FLAG) app.route(/) def index(): return open(__file__).read() app.route(/shrine/) def shrine(shrine): def safe_jinja(s): s s.replace((, ).replace(), ) …

算法之美:缓存数据淘汰算法分析及分解实现

在设计一个系统的时候,由于数据库的读取速度远小于内存的读取速度,那么为加快读取速度,需先将一部分数据加入到内存中(该动作称为缓存),但是内存容量又是有限的,当缓存的数据大于内存容量时&…

nodejs+python基于vue的羽毛球培训俱乐部管理系统django

语言:nodejs/php/python/java 框架:ssm/springboot/thinkphp/django/express 请解释Flask是什么以及他的主要用途 Flask是一个用Python编写的清凉web应用框架。它易于扩展且灵活,适用于小型的项目或者微服务,以及作为大型应用的一…

spring eureka 服务实例实现快速下线快速感知快速刷新配置解析

背景 默认的Spring Eureka服务器,服务提供者和服务调用者配置不够灵敏,总是服务提供者在停掉很久之后,服务调用者很长时间并没有感知到变化。或者是服务已经注册上去了,但是服务调用方很长时间还是调用不到,发现不了这…

【Mysql高可用集群-双主双活-myql+keeplived】

Mysql高可用集群-双主双活-myqlkeeplived 一、介绍二、准备工作1.两台centos7 linux服务器2.mysql安装包3.keepalived安装包 三、安装mysql1.在128、129两台服务器根据《linux安装mysql服务-两种安装方式教程》按方式一安装好mysql应用。2.修改128服务器/etc/my.cnf配置文件&am…

第8章 数据集成和互操作

思维导图 8.1 引言 数据集成和互操作(DII)描述了数据在不同数据存储、应用程序和组织这三者内部和之间进行移动和整合的相关过程。数据集成是将数据整合成物理的或虚拟的一致格式。数据互操作是多个系统之间进行通信的能力。数据集成和互操作的解决方案提供了大多数组织所依赖的…

携程旅行 abtest

声明: 本文章中所有内容仅供学习交流使用,不用于其他任何目的,抓包内容、敏感网址、数据接口等均已做脱敏处理,严禁用于商业用途和非法用途,否则由此产生的一切后果均与作者无关!wx a15018601872 本文章…

Java 基于微信小程序的助农扶贫小程序

博主介绍:✌Java徐师兄、7年大厂程序员经历。全网粉丝13w、csdn博客专家、掘金/华为云等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇🏻 不…

React - 你知道useffect函数内如何模拟生命周期吗

难度级别:中级及以上 提问概率:65% 很多前端开发人员习惯了Vue或者React的组件式开发,熟知组件的周期过程包含初始化、挂载完成、修改和卸载等阶段。但是当使用Hooks做业务开发的时候,看见一个个useEffect函数,却显得有些迷茫,因为在us…

Flutter之Flex组件布局

目录 Flex属性值 轴向:direction:Axis.horizontal 主轴方向:mainAxisAlignment:MainAxisAlignment.center 交叉轴方向:crossAxisAlignment:CrossAxisAlignment 主轴尺寸:mainAxisSize 文字方向:textDirection:TextDirection 竖直方向排序:verticalDirection:VerticalDir…

Java 线程池 参数

1、为什么要使用线程池 线程池能有效管控线程,统一分配任务,优化资源使用。 2、线程池的参数 创建线程池,在构造一个新的线程池时,必须满足下面的条件: corePoolSize(线程池基本大小)必须大于…

JVM流程图自我总结

JVM流程图总览 运行时数据区是否有GC、OOM图 从线程共享角度区别图

【深度学习】最强算法之:图神经网络(GNN)

图神经网络 1、引言2、图神经网络2.1 定义2.2 原理2.3 实现方式2.4 算法公式2.4.1 GNN2.4.2 GCN 2.5 代码示例 3、总结 1、引言 小屌丝:鱼哥,给俺讲一讲图神经网络啊 小鱼:你看,我这会在忙着呢 小屌丝:啊~ 小鱼&#…

如何在Rust中操作JSON

❝ 越努力,越幸运 ❞ 大家好,我是「柒八九」。一个「专注于前端开发技术/Rust及AI应用知识分享」的Coder。 前言 我们之前在Rust 赋能前端-开发一款属于你的前端脚手架中有过在Rust项目中如何操作JSON。 由于文章篇幅的原因,我们就没详细介绍…

java算法day48 | 动态规划part09 ● 198.打家劫舍 ● 213.打家劫舍II ● 337.打家劫舍III

198.打家劫舍 class Solution {public int rob(int[] nums) {if(nums.length0) return 0;if(nums.length1) return nums[0];int[] dpnew int[nums.length];dp[0]nums[0];dp[1]Math.max(nums[1],nums[0]);for(int i2;i<nums.length;i){dp[i]Math.max(dp[i-1],dp[i-2]nums[i])…

网络工程师笔记18(关于网络的一些基本知识)

网络的分类 介绍计算机网络的基本概念&#xff0c;这一章最主要的内容是计算机网络的体系结构-ISO 开放系统互连参考模型&#xff0c;其中的基本概念&#xff0c;例如协议实体、协议数据单元&#xff0c;服务数据单元、面向连接的服务和无连接的服务、服务原语、服务访问点、相…