PyTorch快速入门教程【小土堆】之完整模型训练套路

news2025/1/7 14:55:34

视频地址完整的模型训练套路(一)_哔哩哔哩_bilibili

import torch
import torchvision
from model import *
from torch import nn
from torch.utils.data import DataLoader

# 准备数据集
train_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)
# Length 长度
train_data_size = len(train_data)
test_data_size = len(test_data)
# 如果train_data_size=10,训练数据集的长度为:10
# print("训练数据集的长度为: {}".format(train_data_size))
# print("测试数据集的长度为: {}".format(test_data_size))

# 利用 DataLoader 来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# 创建网络模型
tudui = Tudui()

# 损失函数
loss_fn = nn.CrossEntropyLoss()

# 优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)

# 设置训练网络的一些参数
# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练的轮数
epoch = 10

for i in range(epoch):
    print("--------第{}轮训练开始---------".format(i + 1))

    # 训练步骤开始
    for data in train_dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        loss = loss_fn(outputs, targets)

        # 优化器优化模型
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step += 1
        if total_train_step % 100 == 0:
            print("训练次数:{},Loss:{}".format(total_train_step, loss.item()))

    # 测试步骤开始
    total_test_loss =0
    total_accuracy = 0
    with torch.no_grad():# 保证不会调优
        for data in test_dataloader:
            imgs, targets = data
            outputs = tudui(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss = total_test_loss + loss.item()
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy = total_accuracy + accuracy

    print("整体测试集上的Loss: {}".format(total_test_loss))
    print("整体测试集上的正确率:{}".format(total_accuracy / test_data_size))
    torch.save(tudui, "tudui_{}.pth".format(i))
    print("模型已保存")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2271872.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

蓝桥杯备赛:C++基础,顺序表和vector(STL)

目录 一.C基础 1.第一个C程序: 2.头文件: 3.cin和cout初识: 4.命名空间: 二.顺序表和vector(STL) 1.顺序表的基本操作: 2.封装静态顺序表: 3.动态顺序表--vector:…

node.js之---事件循环机制

事件循环机制 Node.js 事件循环机制(Event Loop)是其核心特性之一,它使得 Node.js 能够高效地处理大量并发的 I/O 操作。Node.js 基于 非阻塞 I/O,使用事件驱动的模型来实现异步编程。事件循环是 Node.js 实现异步编程的基础&…

如何在 Ubuntu 22.04 上部署 Nginx 并优化以应对高流量网站教程

简介 本教程将教你如何优化 Nginx,使其能够高效地处理高流量网站。 Nginx 是一个强大且高性能的 Web 服务器,以其高效处理大量并发连接的能力而闻名,这使得它成为高流量网站的流行选择。 正确优化 Nginx 可以显著提高服务器的性能&#xff0…

AIRemoveBackground:用 AI 技术轻松去除背景图的前端程序

在当今数字化时代,图像处理技术不断发展,其中 AI 去除背景图的功能备受关注。本文将介绍一款名为 AIRemoveBackground 的前端程序,它利用人工智能技术,为用户提供便捷、高效的背景去除解决方案。 一、简介 随着互联网的普及和多媒…

【踩坑指南2.0 2025最新】Scala中如何在命令行传入参数以运行主函数

这个地方基本没有任何文档记录,在学习的过程中屡屡碰壁,因此记录一下这部分的内容,懒得看可以直接跳到总结看结论。 踩坑步骤 首先来看看书上让我们怎么写: //main.scala object Start {def main(args:Array[String]) {try {v…

Excel VBA 自动填充空白并合并相同值的解决方案

文章目录 Excel VBA: 自动填充空白并合并相同值的解决方案问题背景解决方案1. VBA代码实现2. 代码说明3. 使用方法4. 注意事项 扩展优化总结 Excel VBA: 自动填充空白并合并相同值的解决方案 问题背景 在Excel中经常会遇到这样的数据处理需求:一列数据中存在多个空…

SpringSecurity中的过滤器链与自定义过滤器

关于 Spring Security 框架中的过滤器的使用方法,系列文章: 《SpringSecurity中的过滤器链与自定义过滤器》 《SpringSecurity使用过滤器实现图形验证码》 1、Spring Security 中的过滤器链 Spring Security 中的过滤器链(Filter Chain)是一个核心的概念,它定义了一系列过…

【STC库函数】Compare比较器的使用

如果我们需要比较两个点的电压,当A点高于B点的时候我们做一个操作,当B点高于A点的时候做另一个操作。 我们除了加一个运放或者比较器,还可以直接使用STC内部的一个比较器。 正极输入端可以是P37、P50、P51,或者从ADC的十六个通道…

Postgresql 命令还原数据库

因为PgAdmin打不开,但是数据库已经安装成功了,这里借助Pg命令来还原数据库 C:\Program Files\PostgreSQL\15\bin\psql.exe #链接数据库 psql -U postgres -p 5432#创建数据库 CREATE DATABASE "数据库名称"WITHOWNER postgresENCODING UTF8…

Backend - C# 的日志 NLog日志

目录 一、注入依赖和使用 logger 二、配置记录文件 1.安装插件 NLog 2.创建 nlog.config 配置文件 3. Programs配置日志信息 4. 设置 appsettings.json 的 LogLevel 5. 日志设定文件和日志级别的优先级 (1)常见的日志级别优先级 (2&…

急需升级,D-Link 路由器漏洞被僵尸网络广泛用于 DDoS 攻击

僵尸网络活动增加 :新的“FICORA”和“CAPSAICIN”僵尸网络(Mirai 和 Kaiten 的变体)的活动激增。 被利用的漏洞 :攻击者利用已知的 D-Link 路由器漏洞(例如 CVE-2015-2051、CVE-2024-33112)来执行恶意命…

[ubuntu-22.04]ubuntu不识别rtl8153 usb转网口

问题描述 ubuntu22.04插入rtl8153 usb转网口不识别 解决方案 安装依赖包 sudo apt-get install libelf-dev build-essential linux-headers-uname -r sudo apt-get install gcc-12 下载源码 Realtek USB FE / GBE / 2.5G / 5G Ethernet Family Controller Softwarehttps:/…

WinForm开发-自定义组件-1. 工具栏: UcompToolStrip

这里写自定义目录标题 1. 工具栏: UcompToolStrip1.1 展示效果1.2 代码UcompToolStrip.csUcompToolStrip.Designer.cs 1. 工具栏: UcompToolStrip 自定义一些Winform组件 1.1 展示效果 1)使用效果 2)控件事件 1.2 代码 设计 编码 UcompToolStrip.…

Hypium纯血鸿蒙系统 HarmonyOS NEXT自动化测试框架

1、什么是Hypium Hypium是华为官方为鸿蒙操作系统开发的一款以python为语言的自动化测试框架。 引用华为官网介绍如下: DevEco Testing Hypium(以下简称Hypium)是HarmonyOS平台的UI自动化测试框架,支持开发者使用python语言为应用编写UI自动化测试脚本…

基于Spring Boot微信小程序电影管理系统

一、系统背景与意义 随着移动互联网的普及和用户对个性化娱乐需求的不断增长,电影行业迎来了新的发展机遇。然而,传统的电影管理方式存在信息不对称、购票流程繁琐、用户体验不佳等问题。因此,开发一个基于Spring Boot微信小程序的电影管理系…

软件工程实验-实验2 结构化分析与设计-总体设计和数据库设计

一、实验内容 1. 绘制工资支付系统的功能结构图和数据库 在系统设计阶段,要设计软件体系结构,即是确定软件系统中每个程序是由哪些模块组成的,以及这些模块相互间的关系。同时把模块组织成良好的层次系统:顶层模块通过调用它的下层…

深度学习blog- 数学基础(全是数学)

矩阵‌:矩阵是一个二维数组,通常由行和列组成,每个元素可以通过行索引和列索引进行访问。 张量‌:张量是一个多维数组的抽象概念,可以具有任意数量的维度。除了标量(0D张量)、向量(…

JMH338-剑侠情缘2【开服端】-2017版【剑荡三界】+服务端+客户端+登录器+外网

资源介绍: 激情服;剑荡三界基本上可以直接开服玩,总之每个服都有他的特色;云中,红莲山,葬雪城三大地图三种世界BOSS每个小时刷一次 云中押镖劫镖,出城就是PK模式 剑荡烟云副本分为普通和难度…

QML自定义滑动条Slider的样式

代码展示 import QtQuick 2.9 import QtQuick.Window 2.2 import QtQuick.Controls 2.1Window {visible: truewidth: 640height: 480title: qsTr("Hello World")Slider {id: controlvalue: 0.5background: Rectangle {x: control.leftPaddingy: control.topPadding …

什么是.net framework,什么是.net core,什么是.net5~8,版本对应关系

我不知道有多少人和我一样,没学习过.netCore,想要学习,但是版本号太多就蒙了,不知道学什么了,这里解释下各个版本的关系 我们一般开始学习微软的时候,都是开始学习的.netframework,常用的就是4…