使用python的plot绘制loss、acc曲线,并存储成图片

news2024/11/16 21:21:28

使用 python的plot 绘制网络训练过程中的的 loss 曲线以及准确率变化曲线,这里的主要思想就时先把想要的损失值以及准确率值保存下来,保存到 .txt 文件中,待网络训练结束,我们再拿这存储的数据绘制各种曲线。

其大致步骤为:数据读取与存储 - > loss曲线绘制 - > 准确率曲线绘制

一、数据读取与存储部分

我们首先要得到训练时的数据,以损失值为例,网络每迭代一次都会产生相应的 loss,那么我们就把每一次的损失值都存储下来,存储到列表,保存到 .txt 文件中。

1.3817585706710815, 
1.8422836065292358, 
1.1619832515716553, 
0.5217241644859314, 
0.5221078991889954, 
1.3544578552246094, 
1.3334463834762573, 
1.3866571187973022, 
0.7603049278259277

上图为部分损失值,根据迭代次数而异,要是迭代了1万次,这里就会有1万个损失值。
而准确率值是每一个 epoch 产生一个值,要是训练100个epoch,就有100个准确率值。

这里的损失值是怎么保存到文件中的呢?首先,找到网络训练代码,就是项目中的 main.py,或者 train.py ,在文件里先找到训练部分,里面经常会有这样一行代码:

for epoch in range(resume_epoch, num_epochs):   # 就是这一行
	####
	...
	loss = criterion(outputs, labels.long())              # 损失样例
	...
    epoch_acc = running_corrects.double() / trainval_sizes[phase]    # 准确率样例
    ...
    ###

从这一行开始就是训练部分了,往下会找到类似的这两句代码,就是损失值和准确率值了。

这时候将以下代码加入源代码就可以了:

train_loss = []
train_acc = []
for epoch in range(resume_epoch, num_epochs):          # 就是这一行
	###
	...
	loss = criterion(outputs, labels.long())           # 损失样例
	train_loss.append(loss.item())                     # 损失加入到列表中
	...
	epoch_acc = running_corrects.double() / trainval_sizes[phase]    # 准确率样例
	train_acc.append(epoch_acc.item())                 # 准确率加入到列表中
	... 
with open("./train_loss.txt", 'w') as train_los:
    train_los.write(str(train_loss))

with open("./train_acc.txt", 'w') as train_ac:
     train_ac.write(str(train_acc))

这样就算完成了损失值和准确率值的数据存储了!

二、绘制 loss 曲线

主要需要 numpy 库和 matplotlib 库。

pip install numpy malplotlib

首先,将 .txt 文件中的存储的数据读取进来,以下是读取函数:

import numpy as np

# 读取存储为txt文件的数据
def data_read(dir_path):
    with open(dir_path, "r") as f:
        raw_data = f.read()
        data = raw_data[1:-1].split(", ")   # [-1:1]是为了去除文件中的前后中括号"[]"

    return np.asfarray(data, float)

然后,就是绘制 loss 曲线部分:

if __name__ == "__main__":

	train_loss_path = r"/train_loss.txt"   # 存储文件路径
	
	y_train_loss = data_read(train_loss_path)        # loss值,即y轴
	x_train_loss = range(len(y_train_loss))			 # loss的数量,即x轴

	plt.figure()

    # 去除顶部和右边框框
    ax = plt.axes()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    plt.xlabel('iters')    # x轴标签
    plt.ylabel('loss')     # y轴标签
	
	# 以x_train_loss为横坐标,y_train_loss为纵坐标,曲线宽度为1,实线,增加标签,训练损失,
	# 默认颜色,如果想更改颜色,可以增加参数color='red',这是红色。
    plt.plot(x_train_loss, y_train_loss, linewidth=1, linestyle="solid", label="train loss")
    plt.legend()
    plt.title('Loss curve')
    plt.show()
    pit.savefig("loss.png")

这样就算把损失图像画出来了!如下:
在这里插入图片描述

三、绘制准确率曲线
有了上面的基础,这就简单很多了。
只是有一点要记住,上面的x轴是迭代次数,这里的是训练轮次 epoch。

if __name__ == "__main__":

	train_acc_path = r"/train_acc.txt"   # 存储文件路径
	
	y_train_acc = data_read(train_acc_path)       # 训练准确率值,即y轴
	x_train_acc = range(len(y_train_acc))			 # 训练阶段准确率的数量,即x轴

	plt.figure()

    # 去除顶部和右边框框
    ax = plt.axes()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    plt.xlabel('epochs')    # x轴标签
    plt.ylabel('accuracy')     # y轴标签
	
	# 以x_train_acc为横坐标,y_train_acc为纵坐标,曲线宽度为1,实线,增加标签,训练损失,
	# 增加参数color='red',这是红色。
    plt.plot(x_train_acc, y_train_acc, color='red',linewidth=1, linestyle="solid", label="train acc")
    plt.legend()
    plt.title('Accuracy curve')
    plt.show()
    pit.savefig("acc.png")

这样就把准确率变化曲线画出来了!如下:
在这里插入图片描述

以下是完整代码,以绘制准确率曲线为例,并且将x轴换成了iters,和损失曲线保持一致,供参考:

import numpy as np
import matplotlib.pyplot as plt


# 读取存储为txt文件的数据
def data_read(dir_path):
    with open(dir_path, "r") as f:
        raw_data = f.read()
        data = raw_data[1:-1].split(", ")

    return np.asfarray(data, float)


# 不同长度数据,统一为一个标准,倍乘x轴
def multiple_equal(x, y):
    x_len = len(x)
    y_len = len(y)
    times = x_len/y_len
    y_times = [i * times for i in y]
    return y_times


if __name__ == "__main__":

    train_loss_path = r"/train_loss.txt"
    train_acc_path = r"/train_acc.txt"

    y_train_loss = data_read(train_loss_path)
    y_train_acc = data_read(train_acc_path)

    x_train_loss = range(len(y_train_loss))
    x_train_acc = multiple_equal(x_train_loss, range(len(y_train_acc)))

    plt.figure()

    # 去除顶部和右边框框
    ax = plt.axes()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    plt.xlabel('iters')
    plt.ylabel('accuracy')

    # plt.plot(x_train_loss, y_train_loss, linewidth=1, linestyle="solid", label="train loss")
    plt.plot(x_train_acc, y_train_acc,  color='red', linestyle="solid", label="train accuracy")
    plt.legend()

    plt.title('Accuracy curve')
    plt.show()
    pit.savefig("acc.png")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/605129.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

代码自动生成工具——TableGo(实例演示)

一、常用的代码生成器工具介绍 在SpringBoot项目开发中,为了提高开发效率,我们经常需要使用代码自动生成工具来生成一些重复性的代码,比如实体类、DAO、Service、Controller等等。下面介绍几个常用的代码自动生成工具: ①、MyBat…

如何在Linux 启用组播

第一章: 前言 多播技术,也被称为“组播”,是一种网络通信机制,它允许一个节点(发送者)向一组特定的节点(接收者)发送信息。这种方式在网络编程中非常有用,因为它可以大大提高效率和…

深度学习(Pytorch):Softmax回归

Softmax简介 Softmax回归是一个用于多类分类问题的线性模型,它是从Logistic回归模型演变而来的。Softmax回归与Logistic回归类似,但是输出不再是二元的,而是多类的。Softmax回归引入了softmax激活函数来将输出转换为合理的概率分布。与线性回…

HCIE-Cloud Computing LAB备考--第五题:规划--Type13练习--记忆技巧+默写

对LLD表,交换机接口表,ensp配置进行练习,如下图,设置答案和空白表,进行默写,汇总自己的容易犯的错误 LLD表默写思路 交换机接口配置表默写思路 以Type3为例,同颜色复制即可,共用ST.P0是A25,ST.P2是A21,FS是ST.P0是A21,ST.P2是A21。 ensp配置默写思路 特点: 所…

一步一步学习 Stable Diffusion

一步一步学习 Stable Diffusion 0. 背景1. 安装2. 汉化3. 安装 sd-webui-controlnet 插件4. 安装 sd-webui-segment-anything 插件5. 安装 ultimate-upscale 插件6. 安装 SadTalker 插件7. 下载和配置 VAE 模型8. 使用 ChilloutMix 模型99. 未完待续 0. 背景 网上看了很多 Sta…

priority_queue(优先级队列)

priority_queue 1. priority_queue的介绍及使用1.1 priority_queue的介绍1.2 priority_queue的使用1.2.1 constructor(构造)1.2.2 empty1.2.3 size1.2.4 top1.2.5 emplace1.2.6 push、pop、swap 1.3 数组中第K个大的元素 2.priority_queue的深度剖析及模拟实现 1. priority_que…

Makerbase SimpleFOC ESP32例程4 双电机闭环速度测试

Makerbase SimpleFOC ESP32例程4 双电机闭环速度测试 第一部分 硬件介绍 1.1 硬件清单 序号品名数量1ESP32 FOC V1.0 主板12YT2804电机2312V电源适配器14USB 线156pin杜邦线2 注意:YT2804是改装的云台无刷电机,带有AS5600编码器,可实现360连续运转。…

柔性作业车间调度

1柔性车间作业调度 个工件 要在 台机器 上加工。每个工件包含一道或多道工序,工序顺序是预先确定的,每道工序可以在多台不同加工机器上进行加工,工序的加工时间随加工机器的不同而不同。调度目标是为每道工序选择最合适的机器、确定每台机器…

【C语言】语言篇——数组和字符串

C站的小伙伴们,大家好呀😝😝!我最近在阅读学习刘汝佳老师的《算法竞赛入门经典》,今天将整理本书的第三章——数组和字符串的一些习题,本章习题较多,下选取部分习题进行练习总结,在这…

200道面试题(附答案)

最近有不少小伙伴跑来咨询: 想找网络安全工作,应该要怎么进行技术面试准备?工作不到 2 年,想跳槽看下机会,有没有相关的面试题呢? 为了更好地帮助大家高薪就业,今天就给大家分享两份网络安全工…

ubuntu20.04 ffmpeg mp4转AES加密的m3u8分片视频

样本视频(时长2分35秒): 大雄兔_百度百科 大雄兔_百度百科不知大家否看过世界上第一部开源电影:Elephants Dream(大象之梦)。这是一部由主要由开源软件Blender制作的电影短片,证明了用开源软件也能制作出效果媲美大公司的作品。…

1-9 随机算法【手写+Xmind笔记】

文章目录 1 Min-Cut【手写笔记】1.1 问题描述1.2 解决方案1.3 概率证明 2 赠券收集【手写笔记】3 快排期望【手写笔记】4 素数性质【手写笔记】4.1 基本性质4.2 解决方案4.3 群论4.4 费马小定理4.5 Miller Rabin素性测试 5-6 力矩与偏差【手写笔记】5.1 基础不等式5.2 矩生成函…

[图表]pyecharts模块-柱状图

[图表]pyecharts模块-柱状图 先来看代码: from pyecharts.charts import Bar from pyecharts.faker import Faker from pyecharts.globals import ThemeTypec (Bar({"theme": ThemeType.MACARONS}).add_xaxis(Faker.choose()).add_yaxis("商家A&q…

Spring 核心概念之一 IoC

前言 欢迎来到本篇文章!通过上一篇什么是 Spring?为什么学它?的学习,我们知道了 Spring 的基本概念,知道什么是 Spring,以及为什么学习 Spring。今天,这篇就来说说 Spring 中的核心概念之一 Io…

day2 -- 数据库的安全管理和维护

brief 访问控制的目的不仅仅是防止用户的恶意企图。数据梦魇更为常见的是无意识错误的结果,如错打MySQL语句,在不合适的数据库中操作或其他一些用户错误。通过保证用户不能执行他们不应该执行的语句,访问控制有助于避免这些情况的发生。管理…

Makerbase SimpleFOC ESP32 例程6 双电机闭环位置力矩互控

Makerbase SimpleFOC ESP32 例程6 双电机闭环位置力矩互控 第一部分 硬件介绍 1.1 硬件清单 序号品名数量1ESP32 FOC V1.0 主板12YT2804电机2312V电源适配器14USB 线156pin杜邦线2 注意:YT2804是改装的云台无刷电机,带有AS5600编码器,可实现360连续运…

Go 字节跳动—从需求到上线全流程

走进后端开发流程 整个课程会带大家先从理论出发,思考为什么有流程 大家以后工作的团队可能不一样,那么不同的团队也会有不同的流程,这背后的逻辑是什么 然后会带大家按照走一遍从需求到上线的全流程,告诉大家在流程的每个阶段&am…

angular环境安装 (含nodejs详细安装步骤)

在安装本次环境之前,需要先把本机上的nodejs环境卸载,环境变量手动删除!安装过程种环境才不会产生副作用!实际项目安装的一次记录,踩了太多坑,记录一下,旨在记录!项目需要两个不用版…

常用设计模式介绍~~~ Java实现 【概念+案例+代码】

前言 想要读懂源码、让自己的代码写的更加优雅,重构系统等。理解设计模式的思想,可以让我们事半功倍。以下稍微整理了常用的设计模式、每一种设计模式都有详细的概念介绍、案例说明、代码实例、运行截图等。这里给出目录导航。 目录 一、创建型模式 【一…

现在的面试把我卷崩溃了....

现在的面试也太卷了,前几天组了一个软件测试面试的群,没想到效果直接拉满,看来大家对面试这块的需求还是挺迫切的。昨天我就看到群友们发的一些面经,感觉非常有参考价值,于是我就问他还有没有。 结果他给我整理了一份…