pytorch 利用Tensorboar记录训练过程loss变化

news2024/10/6 14:35:56

文章目录

    • 1. LossHistory日志类定义
    • 2. LossHistory类的使用
      • 2.1 实例化LossHistory
      • 2.2 记录每个epoch的loss
      • 2.3 训练结束close掉SummaryWriter
    • 3. 利用Tensorboard 可视化
      • 3.1 显示可视化效果
    • 参考

利用Tensorboard记录训练过程中每个epoch的训练loss以及验证loss,便于及时了解网络的训练进展。

代码参考自 B导github仓库: https://github.com/bubbliiiing/deeplabv3-plus-pytorch

1. LossHistory日志类定义

import os
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import scipy.signal

import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
#from tensorboardX import SummaryWriter
class LossHistory():
    def __init__(self, log_dir, model, input_shape):
        self.log_dir    = log_dir
        self.losses     = []
        self.val_loss   = []
        
        os.makedirs(self.log_dir)
        self.writer     = SummaryWriter(self.log_dir)
        try:
            dummy_input     = torch.randn(2, 3, input_shape[0], input_shape[1])
            self.writer.add_graph(model, dummy_input)
        except:
            pass

    def append_loss(self, epoch, loss, val_loss):
        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

        self.losses.append(loss)
        self.val_loss.append(val_loss)

        with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
            f.write(str(loss))
            f.write("\n")
        with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
            f.write(str(val_loss))
            f.write("\n")

        self.writer.add_scalar('loss', loss, epoch)
        self.writer.add_scalar('val_loss', val_loss, epoch)
        self.loss_plot()

    def loss_plot(self):
        iters = range(len(self.losses))

        plt.figure()
        plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
        plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
        try:
            if len(self.losses) < 25:
                num = 5
            else:
                num = 15
            
            plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
            plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
        except:
            pass

        plt.grid(True)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend(loc="upper right")

        plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))

        plt.cla()
        plt.close("all")
  • (1) 首先利用LossHistory类的构造函数__init__, 实例化TensorboardSummaryWriter对象self.writer,并将网络结构图添加到self.writer中。其中__init__方法接收的参数包括,保存log的路径log_dir以及模型model和输入的shape
def __init__(self, log_dir, model, input_shape):
        self.log_dir    = log_dir
        self.losses     = []
        self.val_loss   = []
        
        os.makedirs(self.log_dir)
        self.writer     = SummaryWriter(self.log_dir)
        try:
            dummy_input     = torch.randn(2, 3, input_shape[0], input_shape[1])
            self.writer.add_graph(model, dummy_input)
        except:
            pass
  • (2) 记录每个epoch的训练损失loss以及验证val_loss,并保存到tensorboar中显示
self.writer.add_scalar('loss', loss, epoch)
self.writer.add_scalar('val_loss', val_loss, epoch)

同时将训练的loss以及验证val_loss逐行保存到.txt文件中

 with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
      f.write(str(loss))
      f.write("\n")

并且在每个epoch时,调用loss_plot绘制历史的loss曲线,并保存为epoch_loss.png, 由于每个epoch保存的图片都是重名的,因此在训练结束时,会保存最新的所有epoch绘制的loss曲线

2. LossHistory类的使用

2.1 实例化LossHistory

在训练开始前,实例化LossHistory类,调用__init__实例化时,会创建SummaryWriter对象,用于记录训练的过程中的数据,比如loss, graph以及图片信息等

local_rank  = int(os.environ["LOCAL_RANK"]) 
model   = DeepLab(num_classes=num_classes, backbone=backbone, downsample_factor=downsample_factor, pretrained=pretrained)
input_shape     = [512, 512]

if local_rank == 0:
      time_str        = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S')
      log_dir         = os.path.join(save_dir, "loss_" + str(time_str))
      loss_history    = LossHistory(log_dir, model, input_shape=input_shape)
  else:
      loss_history    = None
  • 对于多GPU训练时,只在主进程(local_rank == 0)记录训练的日志信息
  • log 保存的路径log_dir,利用loss_ + 当前时间的形式记录
log_dir         = os.path.join(save_dir, "loss_" + str(time_str))

2.2 记录每个epoch的loss

在每个epoch中,利用loss_history的append_loss方法,利用SummaryWriter对象保存loss:

for epoch in range(start_epoch, total_epoch):
	...
	loss_history.append_loss(epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val)
  • 记录了每个epoch的训练loss以及验证val_loss
  • 同时将最新的loss曲线,保存到本地epoch_loss.png
  • 并将历史的训练loss和val_loss保存为txt文件,方便查看

2.3 训练结束close掉SummaryWriter

loss_history.writer.close()

3. 利用Tensorboard 可视化

  • Tensorboard最早是在Tensorflow中开发和应用的,pytorch 中也同样支持Tensorboard的使用,pytorch中的Tensorboard工具叫TensorboardX, 它需要依赖于tensorflow库中的一些组件支持。因此在安装Tensorboardx之前,需要先安装TensorFlow, 否则直接安装Tensorboardx运行会报错。
pip install tensorflow
pip install tensorboardX

3.1 显示可视化效果

训练结束后,cd到SummaryWriter中定义好日志保存目录log_dir下,执行如下指令

cd log_dir # log_dir为定义的日志保存目录
tensorboard  --logdir=./     --port 6006 

然后会显示出访问的链接地址,点击链接就可以查看Tensorboard可视化效果

  • Scalar模块展示训练过程中,每个epoch的train_loss、Accuracy、Learn_Rating的数值变化
    在这里插入图片描述
  • GRAPH模块展示的是模型的网络结构
    在这里插入图片描述
  • HISTOGRAMS模块展示添加到tensorboard中各层的权重分布情况
    在这里插入图片描述

参考

  • 1 https://github.com/bubbliiiing/deeplabv3-plus-pytorch
  • 2 pytorch中使用tensorboard实现训练过程可视化

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1435990.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Java数据结构】单向 不带头 非循环 链表实现

模拟实现LinkedList&#xff1a;下一篇文章 LinkedList底层是双向、不带头结点、非循环的链表 /*** LinkedList的模拟实现*单向 不带头 非循环链表实现*/ class SingleLinkedList {class ListNode {public int val;public ListNode next;public ListNode(int val) {this.val …

【多模态大模型】视觉大模型SAM:如何使模型能够处理任意图像的分割任务?

SAM&#xff1a;如何使模型能够处理任意图像的分割任务&#xff1f; 核心思想起始问题: 如何使模型能够处理任意图像的分割任务&#xff1f;5why分析5so分析 总结子问题1: 如何编码输入图像以适应分割任务&#xff1f;子问题2: 如何处理各种形式的分割提示&#xff1f;子问题3:…

c++之说_10|自定义类型 union 联合体

之前我们说了一些 struct 结构体 现在来了解新的自定义类型 union 联合体 语法 union ptr {void* fptr;CLassFunPtr p;FunPtr p2;ptr& operator(CLassFunPtr ptr){p ptr;return *this;}ptr& operator(FunPtr Fptr){p2 Fptr;return *this;} } FunPtr_; 我们看到了…

第 383 场 LeetCode 周赛题解

A 边界上的蚂蚁 模拟 class Solution { public:int returnToBoundaryCount(vector<int> &nums) {int s 0;int res 0;for (auto x: nums) {s x;if (s 0)res;}return res;} };B 将单词恢复初始状态所需的最短时间 I 枚举&#xff1a;若经过 i i i 秒后 w o r d w…

Leetcode刷题笔记题解(C++):257. 二叉树的所有路径

思路&#xff1a;深度优先搜索 /*** Definition for a binary tree node.* struct TreeNode {* int val;* TreeNode *left;* TreeNode *right;* TreeNode() : val(0), left(nullptr), right(nullptr) {}* TreeNode(int x) : val(x), left(nullptr), right…

leetcode 算法 67.二进制求和(python版)

需求 给你两个二进制字符串 a 和 b &#xff0c;以二进制字符串的形式返回它们的和。 示例 1&#xff1a; 输入:a “11”, b “1” 输出&#xff1a;“100” 示例 2&#xff1a; 输入&#xff1a;a “1010”, b “1011” 输出&#xff1a;“10101” 代码 class Solution…

这个门禁考勤技术,看了都说好!

在当今数字化时代&#xff0c;考勤管理对于企业、学校、机构等各类组织至关重要。随着科技的不断进步&#xff0c;传统的考勤方式逐渐显露出效率低、安全性差等问题。 因此&#xff0c;为了应对这些挑战&#xff0c;三维人脸考勤系统作为一项创新的解决方案应运而生。 客户案例…

C#,纽曼-尚克斯-威廉士素数(Newman Shanks Williams prime)的算法与源代码

1 NSW素数 素数是纽曼-尚克斯-威廉士素数&#xff08;Newman-Shanks-Williams prime&#xff0c;简写为NSW素数&#xff09;当且仅当它能写成以下的形式&#xff1a; 1981年M. Newman、D. Shanks和H. C. Williams在研究有限集合时&#xff0c;率先描述了NSW素数。 首几个NSW素…

【经典例子】Java实现2048小游戏(附带源码)

一、游戏回顾 2048游戏是一款数字益智游戏&#xff0c;目标是通过合并相同数字的方块来达到2048这个目标。游戏在一个4x4的方格上进行&#xff0c;每个方格上都有一个数字&#xff08;初始时为2或4&#xff09;。玩家可以通过滑动方向键&#xff08;上、下、左、右&#xff09;…

Java实现用户画像活动推荐系统 JAVA+Vue+SpringBoot+MySQL

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 兴趣标签模块2.3 活动档案模块2.4 活动报名模块2.5 活动留言模块 三、系统设计3.1 用例设计3.2 业务流程设计3.3 数据流程设计3.4 E-R图设计 四、系统展示五、核心代码5.1 查询兴趣标签5.2 查询活动推荐…

升级GPT4保姆级教程

前言&#xff1a; 2024-01-26开通了GPT4之后至今已经使用了两周&#xff0c;体验下来是真的强&#xff0c;各种GPTs使用起来也很丝滑&#xff0c;不需要自己额外调试。之前看版本计划&#xff0c;2024年会发布GPT5&#xff0c;如果你还没有用上GPT4的话快快来升级体验一下吧&a…

C语言之自定义类型:联合和枚举

目录 1. 联合体类型的声明2. 联合体的特点3. 联合体大小的计算联合的一个练习 4. 枚举类型的声明5. 枚举类型的优点6. 枚举类型的使用 1. 联合体类型的声明 像结构体一样&#xff0c;联合体也是由一个或者多个成员构成&#xff0c;这些成员可以不同的类型 但是编译器只为最大…

机器学习系列4-特征工程

机器学习系列4-特征工程 学习内容来自&#xff1a;谷歌ai学习 https://developers.google.cn/machine-learning/crash-course/framing/check-your-understanding?hlzh-cn 本文作为学习记录自己归纳整理的思维导图 这里写目录标题 机器学习系列4-特征工程一级目录二级目录三…

Mac利用brew安装mysql并设置初始密码

前言 之前一直是在windows上开发后段程序&#xff0c;所以只在windows上装mysql。(我记得linux只需要适应yum之类的命令即可) 另外, linux的移步 linux安装mysql (详细步骤,初次初始化,sql小例子,可视化操作客户端推荐) 好家伙&#xff0c;我佛了&#xff0c;写完当天网上发…

集群clickhouse使用和clickhouse索引的使用

ClickHouse支持多种索引类型&#xff0c;包括普通索引、范围索引、哈希索引、倒排索引等。使用索引可以加快查询速度和提高查询效率。下面是ClickHouse索引的一些使用方法&#xff1a; 1 普通索引 可以使用普通索引来加速查询特定的列&#xff0c;例如&#xff1a; CREATE TA…

MES生产制造管理:汽车零部件生产MES解决方案

某某汽车部件科技有限公司是一家铝合金零部件研发、压铸和精加工为一体的高新技术企业,拥有先进压铸、机加、检测等设备,并配套自动化生产线。为解决发动机支架等产品的全程生产质量追溯和实现机台设备联网,梅施科技提供了车间级的MES解决方案,如图所示&#xff1a; 梅施科技采…

Axios设置token到请求头的三种方式

1、为什么要携带token? 用户登录时&#xff0c;后端会返回一个token&#xff0c;并且保存到浏览器的localstorage中&#xff0c;可以根据localstorage中的token判断用户是否登录&#xff0c;登录后才有权限访问相关的页面&#xff0c;所以当发送请求时&#xff0c;都要携带to…

每日一题——LeetCode1408.数组中的字符串匹配

方法一 暴力枚举&#xff1a; 对每个单词循环判断是否是其他单词的子字符串 var stringMatching function(words) {const ret [];for (let i 0; i < words.length; i) {for (let j 0; j < words.length; j) {if (i ! j && words[j].search(words[i]) ! -1)…

探索Gin框架:Golang Gin框架请求参数的获取

前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。点击跳转到网站https://www.captainbed.cn/kitie。 前言 我们在专栏的前面几篇文章内讲解了Gin框架的路由配置&#xff0c;服务启动等内容。 专栏地址&…

TS项目实战二:网页计算器

使用ts实现网页计算器工具&#xff0c;实现计算器相关功能&#xff0c;使用tsify进行项目编译&#xff0c;引入Browserify实现web界面中直接使用模块加载服务。   源码下载&#xff1a;点击下载 讲解视频 TS实战项目四&#xff1a;计算器项目创建 TS实战项目五&#xff1a;B…