pytorch -- CIFAR10 完整的模型训练套路

news2024/9/29 21:33:46
  1. 网络结构
    在这里插入图片描述
  2. 代码
# CIFAR 10
'''
完整的模型训练套路:

'''
import torch.optim
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from model import *

# 1. 准备数据集
train_data = torchvision.datasets.CIFAR10('data',train=True,
                                       transform=torchvision.transforms.ToTensor(),
                                       download=True)
test_data = torchvision.datasets.CIFAR10('data',train=False,
                                         transform=torchvision.transforms.ToTensor(),
                                         download=True)
# 数据集大小
train_data_size = len(train_data)
test_data_size = len(test_data)
print('训练数据集的长度为{}'.format(train_data_size))
print('测试数据集的长度为{}'.format(test_data_size))

# 2 利用DataLoader加载数据集
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)

# 3 搭建神经网络
# 4 创建网络模型
tudui = Tudui()

# 5 损失函数
loss_fn = nn.CrossEntropyLoss()

# 6 优化器 1e-2=1x10^(-2)
learning_rate = 0.01
optimizer = torch.optim.SGD(tudui.parameters(),lr=learning_rate)

# 7 设置训练网络的一些参数
total_train_step = 0 # 记录训练次数
total_test_step = 0 # 记录测试次数
epoch = 10 #训练轮数
# 添加tensorboard
writer = SummaryWriter('logs_model')

for i in range(epoch):
    print('-----------第{}轮训练开始-----------'.format(i+1))
    # 训练开始
    # 训练步骤开始 dropout batchNorm仅对某些层次有作用
    tudui.train()
    for data in train_dataloader:
        imgs, targets = data
        output = tudui(imgs) #训练模型的预测输出
        loss = loss_fn(output,targets)
        # 优化器优化模型
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step += 1
        if total_train_step % 100 == 0:
            print('训练次数是{}时,loss是{}'.format(total_train_step,loss.item()))# 加了item() tensor变成了数字
            writer.add_scalar('train_loss',loss.item(),total_train_step)

    # 训练完一轮,看是否训练好,有没有达到想要的需求,测试数据集中跑一篇看准确率或者损失
    # 测试步骤开始
    tudui.eval()
    total_test_loss = 0
    total_accuracy = 0
    # 测试不需要对梯度进行调整
    with torch.no_grad():
        for data in test_dataloader:
            imgs,targets = data
            outputs = tudui(imgs)
            loss = loss_fn(outputs,targets)
            total_test_loss += loss.item()
            # accuracy 正确预测的样本数量
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy += accuracy
    print('整体测试集上的loss是{}'.format(total_test_loss))
    print('整体测试集上的正确率是{}'.format(total_accuracy/test_data_size))
    writer.add_scalar('test_loss',total_test_loss,total_test_step)
    writer.add_scalar('test_accuracy', total_accuracy, total_test_step)
    total_test_step+=1

    torch.save(tudui,'tudui_{}.pth'.format(i))
    print('模型已保存')

writer.close()

# model.py
import torch
from torch import nn

# 3 搭建神经网络
class Tudui(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3,32,5,1,2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(1024,64),
            nn.Linear(64, 10)
        )
    def forward(self,x):
        x = self.model(x)
        return x

if __name__ == '__main__':
    tudui = Tudui()
    # 验证一下输入输出尺寸
    input = torch.ones((64,3,32,32))
    output = tudui(input)
    print(output.shape)

运行结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1472992.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

用 React 实现搜索 GitHub 用户功能

用 React 实现搜索 GitHub 用户功能 在本篇博客中,我们将介绍如何在 React 应用中搜索 GitHub 用户并显示他们的信息。 创建 React 应用 首先,我们使用 Create React App 创建一个新的 React 应用。Create React App 是一个快速搭建 React 项目的工具…

【QT+QGIS跨平台编译】之五十四:【QGIS_CORE跨平台编译】—【qgssqlstatementlexer.cpp生成】

文章目录 一、Flex二、生成来源三、构建过程一、Flex Flex (fast lexical analyser generator) 是 Lex 的另一个替代品。它经常和自由软件 Bison 语法分析器生成器 一起使用。Flex 最初由 Vern Paxson 于 1987 年用 C 语言写成。 “flex 是一个生成扫描器的工具,能够识别文本中…

Vue笔记(一)

常用指令 1.v-show与v-if底层原理的区别 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>创建一个V…

【教3妹学编程-算法题】匹配模式数组的子数组数目 I

3妹&#xff1a;2哥2哥&#xff0c;你有没有看到上海女老师出轨男学生的瓜啊。 2哥 : 看到 了&#xff0c;真的是太毁三观了&#xff01; 3妹&#xff1a;是啊&#xff0c; 老师本是教书育人的职业&#xff0c;明确规定不能和学生谈恋爱啊&#xff0c;更何况是出轨。 2哥 : 是啊…

好用的IP反查接口

IP-API.com - Geolocation API - Documentation - JSON 自定义返回参数调用&#xff08;1&#xff09;&#xff1a; http://ip-api.com/json/24.48.0.1?fieldsstatus,message,country,countryCode,region,regionName,cityhttp://ip-api.com/json/24.48.0.1?fieldscountry,co…

读人工不智能:计算机如何误解世界笔记04_数据新闻学

1. 计算化和数据化的变革 1.1. 每一个领域都在进行计算化和数据化的变革 1.1.1. 出现了计算社会科学、计算生物学、计算化学或其他数字人文学科 1.1.2. 生活已走向计算化&#xff0c;人们却一点也没有变 1.2. 在如今的计算化和数据化世界中&#xff0c;调查性新闻的实践必须…

vue - - - - Vue3+i18n多语言动态国际化设置

Vue3i18n多语言动态国际化设置 前言一、 i18n 介绍二、插件安装三、i18n配置3.1 创建i18n对应文件夹/文件3.2 en-US.js3.3 zh-CN.js3.4 index.js 四、 mian.js 引入 i18n配置文件五、 组件内使用六、使用效果 前言 继续【如何给自己的网站添加中英文切换】一文之后&#xff0c…

【MySQL】SQL 优化

MySQL - SQL 优化 1. 在 MySQL 中&#xff0c;如何定位慢查询&#xff1f; 1.1 发现慢查询 现象&#xff1a;页面加载过慢、接口压力测试响应时间过长&#xff08;超过 1s&#xff09; 可能出现慢查询的场景&#xff1a; 聚合查询多表查询表数据过大查询深度分页查询 1.2 通…

【Flink精讲】Flink反压调优

Flink 网络流控及反压的介绍&#xff1a; Apache Flink学习网 反压的理解 简单来说&#xff0c; Flink 拓扑中每个节点&#xff08;Task&#xff09;间的数据都以阻塞队列的方式传输&#xff0c;下游来不及消费导致队列被占满后&#xff0c;上游的生产也会被阻塞&#xff0c;…

Jessibuca 插件播放直播流视频

jessibuca官网&#xff1a;http://jessibuca.monibuca.com/player.html git地址&#xff1a;https://gitee.com/huangz2350_admin/jessibuca#https://gitee.com/link?targethttp%3A%2F%2Fjessibuca.monibuca.com%2F 项目需要的文件 1.播放组件 <template ><div i…

Qt项目:网络1

文章目录 项目&#xff1a;网路项目1&#xff1a;主机信息查询1.1 QHostInfo类和QNetworkInterface类1.2 主机信息查询项目实现 项目2&#xff1a;基于HTTP的网络应用程序2.1 项目中用到的函数详解2.2 主要源码 项目&#xff1a;网路 项目1&#xff1a;主机信息查询 使用QHostI…

浅析ARMv8体系结构:原子操作

文章目录 概述LL/SC机制独占内存访问指令多字节独占内存访问指令 独占监视器经典自旋锁实现 LSE机制原子内存操作指令CAS指令交换指令 相关参考 概述 在编程中&#xff0c;当多个处理器或线程访问共享数据&#xff0c;并且至少有一个正在写入时&#xff0c;操作必须是原子的&a…

JAVA集合进阶(Set、Map集合)

一、Set系列集合 1.1 认识Set集合的特点 Set集合是属于Collection体系下的另一个分支&#xff0c;它的特点如下图所示 下面我们用代码简单演示一下&#xff0c;每一种Set集合的特点。 //Set<Integer> set new HashSet<>(); //无序、无索引、不重复 //Set<…

【kubernetes】关于k8s集群中kubectl的陈述式资源管理

目录 一、k8s集群资源管理方式分类&#xff1a; &#xff08;1&#xff09;陈述式资源管理方式&#xff1a;增删查比较方便&#xff0c;但是改非常不方便 &#xff08;2&#xff09;声明式资源管理方式&#xff1a;yaml文件管理 二、陈述式资源管理方法&#xff1a; 三、ku…

计算机设计大赛 深度学习大数据物流平台 python

文章目录 0 前言1 课题背景2 物流大数据平台的架构与设计3 智能车货匹配推荐算法的实现**1\. 问题陈述****2\. 算法模型**3\. 模型构建总览 **4 司机标签体系的搭建及算法****1\. 冷启动**2\. LSTM多标签模型算法 5 货运价格预测6 总结7 部分核心代码8 最后 0 前言 &#x1f5…

Nodejs 第四十二章(jwt)

什么是jwt? JWT&#xff08;JSON Web Token&#xff09;是一种开放的标准&#xff08;RFC 7519&#xff09;&#xff0c;用于在网络应用间传递信息的一种方式。它是一种基于JSON的安全令牌&#xff0c;用于在客户端和服务器之间传输信息。 https://jwt.io/ JWT由三部分组成&…

FL Studio All Plugins Edition2024中文完整版Win/Mac

FL Studio All Plugins Edition&#xff0c;常被誉为数字音频工作站&#xff08;DAW&#xff09;的佼佼者&#xff0c;是音乐制作人和声音工程师钟爱的工具。它集音频录制、编辑、混音以及MIDI制作为一体&#xff0c;为用户提供了从创作到最终作品输出的完整工作流程。这个版本…

下载huggingface数据集到本地并读取.arrow文件遇到的问题

文章目录 1. 524MB中文维基百科语料&#xff08;需要下载的数据集&#xff09;2. 下载 hugging face 网站上的数据集3. 读取 .arrow 文件报错代码4. 纠正后代码 1. 524MB中文维基百科语料&#xff08;需要下载的数据集&#xff09; 2. 下载 hugging face 网站上的数据集 要将H…

CrossOver24破解版下载安装与激活

在 Mac 上运行Windows 软件&#xff0c;CrossOver Mac 可以轻松地从 Dock 本地启动 Windows 应用程序&#xff0c;并将 Mac 操作系统功能&#xff08;如跨平台复制和粘贴以及共享文件系统&#xff09;集成到您的 Windows 程序中。 CrossOver 产品特性 无需重启 CrossOver 可以…

python Matplotlib Tkinter(pack)-->最大化去掉,仅剩关闭按钮

环境 python:python-3.12.0-amd64 包: matplotlib 3.8.2 pillow 10.1.0 import matplotlib.pyplot as plt from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk import tkinter as tk import tkinter.messagebox as messagebox import …