pytorch学习(四)绘制loss和correct曲线

news2024/12/28 2:21:01

这一次学习的时候静态绘制loss和correct曲线,也就是在模型训练完成后,对统计的数据进行绘制。

以minist数据训练为例子

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import numpy as np


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')



trainning_data =datasets.MNIST(root="data",train=True,transform=ToTensor(),download=True)
print(len(trainning_data))
test_data = datasets.MNIST(root="data",train=True,transform=ToTensor(),download=False)

train_loader = DataLoader(trainning_data, batch_size=64,shuffle=True)
test_loader = DataLoader(test_data, batch_size=64,shuffle=True)




print(len(train_loader)) #分成了多少个batch
print(len(trainning_data)) #总共多少个图像
# for x, y in train_loader:
#     print(x.shape)
#     print(y.shape)



class MinistNet(nn.Module):
    def __init__(self):
        super().__init__()
        # self.flat = nn.Flatten()
        self.conv1 = nn.Conv2d(1,1,3,1,1)
        self.hideLayer1 = nn.Linear(28*28,256)
        self.hideLayer2 = nn.Linear(256,10)
    def forward(self,x):
        x= self.conv1(x)
        x = x.view(-1,28*28)
        x = self.hideLayer1(x)
        x = torch.sigmoid(x)
        x = self.hideLayer2(x)
        # x = nn.Sigmoid(x)
        return x

model = MinistNet()
model = model.to(device)
cuda = next(model.parameters()).device
print(model)
criterion = nn.CrossEntropyLoss()
optimer = torch.optim.RMSprop(model.parameters(),lr= 0.001)

def train():
    train_losses = []
    train_acces = []
    eval_losses = []
    eval_acces = []
    #训练
    model.train()
    for epoch in range(10):
        batchsizeNum = 0
        train_loss = 0
        train_acc = 0
        train_correct = 0
        for x,y in train_loader:
            # print(epoch)
            # print(x.shape)
            # print(y.shape)
            x = x.to('cuda')
            y = y.to('cuda')
            bte = type(x)==torch.Tensor
            bte1 = type(y)==torch.Tensor
            A = x.device
            B = y.device
            pred_y = model(x)
            loss = criterion(pred_y,y)
            optimer.zero_grad()
            loss.backward()
            optimer.step()
            loss_val = loss.item()
            batchsizeNum = batchsizeNum +1
            train_acc += (pred_y.argmax(1) == y).type(torch.float).sum().item()
            train_loss += loss.item()
            # print("loss: ",loss_val,"  ",epoch, "  ", batchsizeNum)
        train_losses.append(train_loss / len(trainning_data))
        train_acces.append(train_acc / len(trainning_data))

        #测试
        model.eval()
        with torch.no_grad():
            num_batch = len(test_data)
            numSize = len(test_data)
            test_loss, test_correct = 0,0
            for x,y in test_loader:
                x = x.to(device)
                y = y.to(device)
                pred_y = model(x)
                test_loss += criterion(pred_y, y).item()
                test_correct += (pred_y.argmax(1) == y).type(torch.float).sum().item()
            test_loss /= num_batch
            test_correct /= numSize
            eval_losses.append(test_loss)
            eval_acces.append(test_correct)
            print("test result:",100 * test_correct,"%  avg loss:",test_loss)
        PATH = "dict_model_%d_dict.pth"%(epoch)
        torch.save({"epoch": epoch,
                    "model_state_dict": model.state_dict(), }, PATH)


    plt.plot(np.arange(len(train_losses)), train_losses, label="train loss")

    plt.plot(np.arange(len(train_acces)), train_acces, label="train acc")

    plt.plot(np.arange(len(eval_losses)), eval_losses, label="valid loss")

    plt.plot(np.arange(len(eval_acces)), eval_acces, label="valid acc")
    plt.legend()  # 显示图例
    plt.xlabel('epoches')
    # plt.ylabel("epoch")
    plt.title('Model accuracy&loss')
    plt.show()

    torch.save(model,"mode_con_line2.pth")#保存网络模型结构
    # torch.save(model,) #保存模型中的参数
    torch.save(model.state_dict(),"model_dict.pth")






# Press the green button in the gutter to run the script.
if __name__ == '__main__':
    train()

绘制的图如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1933318.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

GESP CCF C++ 三级认证真题 2024年6月

第 1 题 小杨父母带他到某培训机构给他报名参加CCF组织的GESP认证考试的第1级,那他可以选择的认证语言有()种。 A. 1 B. 2 C. 3 D. 4 第 2 题 下面流程图在yr输入2024时,可以判定yr代表闰年,并输出 2月是29天 &#x…

python-字符金字塔(赛氪OJ)

[题目描述] 请打印输出一个字符金字塔,字符金字塔的特征请参考样例。输入格式: 输入一个字母,保证是大写。输出格式: 输出一个字母金字塔,输出样式见样例。样例输入 C样例输出 A ABA …

【前端8】element ui常见页面布局:注意事项

【前端8】element ui常见页面布局:注意事项 写在最前面遇到的问题Element UI 常见页面布局:注意事项1. 了解基本布局组件常用的菜单1多一个下角 常用的菜单2 2. 栅格系统的使用3. 响应式布局4. Flex 布局的应用5. 避免滥用嵌套6. 处理边距和填充 小结 &a…

基于STC89C51单片机的烟雾报警器设计(煤气火灾检测报警)(含文档、源码与proteus仿真,以及系统详细介绍)

本篇文章论述的是基于STC89C51单片机的烟雾报警器设计的详情介绍,如果对您有帮助的话,还请关注一下哦,如果有资源方面的需要可以联系我。 目录 摘要 原理图 实物图 仿真图 元件清单 代码 系统论文 资源下载 摘要 随着现代家庭用火、…

TikTok内嵌跨境商城全开源_搭建教程/前端uniapp+后端源码

多语言跨境电商外贸商城 TikTok内嵌商城,商家入驻一键铺货一键提货 全开源完美运营,接在tiktok里面的商城内嵌,也可单独分开出来当独立站运营 二十一种语言,可以做很多国家的市场,支持商家入驻,多店铺等等…

服务器IP和电脑IP有什么不同

服务器IP和电脑IP有什么不同?在当今的信息化时代,IP地址作为网络世界中不可或缺的元素,扮演着举足轻重的角色。然而,对于非专业人士来说,服务器IP和电脑IP之间的区别往往模糊不清。本文旨在深入探讨这两者之间的不同&a…

若依前端和后端时间相差8小时

原因基类未设置时区 实体类继承 BaseEntity 加上timezone"GMT8" /** 创建时间 */ JsonFormat(pattern "yyyy-MM-dd HH:mm:ss" , timezone"GMT8") private Date createTime; 解决

golang程序性能提升改进篇之文件的读写---第一篇

背景:接手的项目是golang开发的(本人初次接触golang)经常出现oom。这个程序是计算和io密集型,调用流量属于明显有波峰波谷,但是因为各种原因,当前无法快速通过serverless或者动态在高峰时段调整资源&#x…

MViTv2:Facebook出品,进一步优化的多尺度ViT | CVPR 2022

论文将Multiscale Vision Transformers (MViTv2) 作为图像和视频分类以及对象检测的统一架构进行研究,结合分解的相对位置编码和残差池化连接提出了MViT的改进版本 来源:晓飞的算法工程笔记 公众号 论文: MViTv2: Improved Multiscale Vision Transforme…

Fiddler抓包过滤host及js、css等地址

1、如上图所示 在Filter页面中勾选Hide if URL contains;输入框输入 REGEX:\.(js|css|png|google|favicon\?.*) 隐藏掉包含js、css、png、google等的地址: Hide if URL contains: REGEX:\.(js|css|png|google|favicon\?.*) 2、使Filters设置生效 A…

微软新版WSL 2.3.11子系统带来“数百个新内核模块“和新功能

微软今天发布了新版的 Windows Subsystem for Linux(WSL)。与当前的 WSL 2.2.4 稳定版相比,WSL 2.3.11 具有许多特性:它从旧版的 Linux 5.15 LTS 内核转到了 Linux 6.6LTS内核。今天的发布说明指出,WSL 2.3.11 基于 Linux 6.6.36.3&#xff0…

【C++刷题】[UVA 489]Hangman Judge 刽子手游戏

题目描述 题目解析 这一题看似简单其实有很多坑,我也被卡了好久才ac。首先题目的意思是,输入回合数,一个答案单词,和一个猜测单词,如果猜测的单词里存在答案单词里的所有字母则判定为赢,如果有一个字母是答…

力扣622.设计循环队列

力扣622.设计循环队列 通过数组索引构建一个虚拟的首尾相连的环当front rear时 队列为空当front rear 1时 队列为满 (最后一位不存) class MyCircularQueue {int front;int rear;int capacity;vector<int> elements;public:MyCircularQueue(int k) {//最后一位不存…

基于python的三次样条插值原理及代码

1 三次样条插值 1.1 三次样条插值的基本概念 三次样条插值是通过求解三弯矩方程组&#xff08;即三次样条方程组的特殊形式&#xff09;来得出曲线函数组的过程。在实际计算中&#xff0c;还需要引入边界条件来完成计算。样条插值的名称来源于早期工程师制图时使用的细长木条&…

【机器学习】--过采样原理及代码详解

过采样&#xff08;Oversampling&#xff09;是一个在多个领域都有应用的技术&#xff0c;其具体含义和应用方法会根据领域的不同而有所差异。以下是对过采样技术的详细解析&#xff0c;主要从机器学习和信号处理两个领域进行阐述。 一、机器学习中的过采样 在机器学习中&…

未来的社交标杆:如何通过AI让Facebook更加智能化?

在当今信息爆炸的时代&#xff0c;社交媒体平台的智能化已成为提高用户体验和互动质量的关键因素。Facebook&#xff0c;作为全球最大的社交平台之一&#xff0c;通过人工智能&#xff08;AI&#xff09;的广泛应用&#xff0c;正不断推进其智能化进程。本文将探讨Facebook如何…

Qt日志库QsLog使用教程

前言 最近项目中需要用到日志库。上一次项目中用到了log4qt库&#xff0c;这个库有个麻烦的点是要配置config文件&#xff0c;所以这次切换到了QsLog。用了后这个库的感受是&#xff0c;比较轻量级&#xff0c;嘎嘎好用&#xff0c;推荐一波。 下载QsLog库 https://github.c…

CSS技巧专栏:一日一例 7 - 纯CSS实现炫光边框按钮特效

CSS技巧专栏&#xff1a;一日一例 7 - 纯CSS实现炫光边框按钮特效 本例效果图 案例分析 相信你可能已经在网络见过类似这样的流光的按钮&#xff0c;在羡慕别人做的按钮这么酷的时候&#xff0c;你有没有扒一下它的源代码的冲动&#xff1f;或者你当时有点冲动&#xff0c;却…

在Oxygen中比较两个目录的差异,用于编写手册两个版本的变更说明

▲ 搜索“大龙谈智能内容”关注公众号▲ 当我们对手册进行改版的时候&#xff0c;我们通常需要编写变更说明&#xff0c;如下图&#xff1a; 改版通常会改动很多文件的很多地方&#xff0c;如何知道哪些地方更改了呢&#xff1f; Oxygen提供了比较两个目录的功能&#xff0c…

载均衡技术全解析:Pulsar 分布式系统的最佳实践

背景 Pulsar 有提供一个查询 Broker 负载的接口&#xff1a; /*** Get load for this broker.** return* throws PulsarAdminException*/ LoadManagerReport getLoadReport() throws PulsarAdminException;public interface LoadManagerReport extends ServiceLookupData { Re…