Pytorch入门(三)深度学习模型的训练的基本步骤

news2025/1/11 0:35:20

文章目录

  • 一、修改现有的网络模型
  • 二、模型的保存
  • 三、模型的加载
  • 四、模型的评估
  • 五、训练模型的完整套路
  • 六、使用GPU加速模型的训练
  • 七、模型训练完整的验证套路

一、修改现有的网络模型

import torchvision
from torch import nn
# pretrained 为True时会自动下载模型所对应的权重
vgg16_false=torchvision.models.vgg16(pretrained=False)
vgg16_true=torchvision.models.vgg16(pretrained=True)
print(vgg16_true)
# 向神经网络中添加训练层数
vgg16_true.add_module("linex",nn.Linear(1000,10))
print(vgg16_true)
# 修改神经网络模型中的,某一层
vgg16_false.classifier[6]=nn.Linear(4096,10)
print(vgg16_false)

这里加载了两个模型一个是带预训练权重的,一个是不带的。
pretrained为True时
在这里插入图片描述

pretrained为False时
在这里插入图片描述

我们可以通过上述的代码打印一下网络结构:

model1
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
VGG(
  (features): Sequential(
		    ...
		    ...
		    ...

  (linex): Linear(in_features=1000, out_features=10, bias=True)
)
VGG(
  (features): Sequential(
			  ...
			  ...
			  ...
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
	  		...
		  	...
		  	...
    (6): Linear(in_features=4096, out_features=10, bias=True)
  )
)

Process finished with exit code 0

可以发现vgg16_true模型多了一层 (linex): Linear(in_features=1000, out_features=10, bias=True)
vgg16_false中的

  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
	  		...
		  	...
		  	...
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )

变成了

  (classifier): Sequential(
	  		...
		  	...
		  	...
    (6): Linear(in_features=4096, out_features=10, bias=True)
  )

二、模型的保存

训练好的模型要进行保存,模型的参数需要保存,模型的结构需要保存
最常用的有两种方式

  • ①直接保存torch.save()
    该方法有弊端,如果是自己搭建的模型,在加载的时候必须要有网络模型的声明
    将模型与保存的文件名传进去,该方式会将模型的参数与结构都保存下来
  • ②使用torch.save(vgg16.state_dict(),“文件名”)进行保存
    该方法会将模型的参数进行保存
import torch
import torch.nn as nn
import torchvision
# 加载没有经过训练的vgg16模型
vgg16=torchvision.models.vgg16(pretrained=False)
# 第一种模型保存方式(结构参数都进行保存)
torch.save(vgg16,"vgg16_save.pth")
# 第二种模型保存方式(只保存参数)
torch.save(vgg16.state_dict(),"vgg16_dict.pth")
class Model(nn.Module):
    def __init__(self):
        super(Model,self).__init__()
    def forwark(self):
        pass

三、模型的加载

也是两种方式,分别对应上述两种保存方式。

import torch
from torch import nn
import torchvision
vgg=torchvision.models.vgg16(pretrained=False)
# 第一种加载方式
model_load=torch.load("vgg16_save.pth")
print(model_load)

# 第二种加载方式
# 打印model_load1是一些参数,没有目录结构
model_load1=torch.load("vgg16_dict.pth")
# 将参数传进加载函数中
model_load1=vgg.load_state_dict(model_load1)
print(model_load1)

或者直接使用二进制方式打开文件,将数据加载进来。

with open("pth/resnet18_200.pth",'rb') as f:
    resnet18.load_state_dict(torch.load(f))

四、模型的评估

模型的好坏通常需要使用测试集进行测试,tensor给出了很方便的测试方式。
可以使用argmax()方法,很容易的得出每行最高概率或者每列最高概率。

import torch
# output可以视为两个图片进行训练后得到的在三个类别概率分别是多少
# 进行训练之前会将图像数据,与标签放在两个数组内并对应
output= torch.tensor([
    [0.1,0.3,0.2],
    [0.3,0.4,0.7]])

# 获取到的是一列或一行数据对应概率值最大的位置对应的位置下标(以这个最大概率预测这个图像是什么)
# 参数为1的时候是对行进行操作
# 参数为0的时候会对列进行操作
print(output.argmax(1))
#这个传进去图像对应的类别(两个样本都是1)
targets=torch.tensor([1,1])
# 打印出预测准确的数据个数
true=(output.argmax(1)==targets).sum()
# 一个样本概率最高位置的下标是否与样本所打标签相同(相同代表预测正确)
# 可以通过下面方式批量对比,然后得出准确率
print((output.argmax(1)==targets).sum())
# 得出正确率
print(int(true)*100/2,'%')

五、训练模型的完整套路

大致可以分为以下几步:

  • 加载数据
  • 构建网络模型
  • 设置训练参数
  • 开始训练(如果想接着之前的训练可以先加载模型)
  • 模型保存与评估
import os

import torch.optim
import torchvision
# 准备数据集
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from Model.model import Model
import torch.nn as nn
basepath=os.path.split(os.getcwd())[0]
# ----------------------------------------加载数据-----------------------------#
train_data=torchvision.datasets.CIFAR10(
            root=basepath+r"\数据集",
            train=True,
            transform=torchvision.transforms.ToTensor(),
            download=True)
test_data=torchvision.datasets.CIFAR10(
    root=basepath+r"\数据集",
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True)
# 查看需要训练的数据长度
train_data_size=len(train_data)
test_data_size=len(test_data)
print("训练集长度",train_data_size)
print("测试集长度",test_data_size)

# 加载数据集
train_dataloader=DataLoader(train_data,64)
test_dataloader=DataLoader(test_data,64)

# ------------------------------构建网络模型-----------------------------------#
# 创建网络模型
myModel=Model()

# 损失函数
loss_F=nn.CrossEntropyLoss()

# 优化器
learn_rate=0.01
optimizer=torch.optim.SGD(myModel.parameters(),lr=learn_rate)

writer=SummaryWriter(basepath+r"\logss\log_model")
# ------------------------------设置训练参数-----------------------------------#
# 记录训练次数
total_train_step=0
# 记录测试次数
total_test_step=0
# 训练的轮数
epoch=10
# 训练数据在所有类别中最大概率对应的类别与实际类别可以对照上的总数
total_true=0

# ----------------------------------开始训练-----------------------------------#
# ------------------
# 开始训练
# myModel.train()
# 开始测试
# myModel.eval()
# 这段话只针对某些神经层有意义
# ------------------
for i in range(epoch):
    print(f"-------------第{i+1}轮训练开始----------------")
    total_train_step = 0
    total_test_step = 0
    # 使用训练数据集对模型进行训练
    for data in train_dataloader:
        # 数据通过神经网络
        imgs,targets=data
        output=myModel(imgs)
        loss=loss_F(output,targets)
        # 优化器优化模型
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_train_step=total_train_step+1
        if total_train_step%100==0:
            # loss.item会将数值打印出来,不会带多余的东西
            print(f"训练{total_train_step}次,损失{loss.item()}")
            # 将数据加入到图像中
            writer.add_scalar("train_loss",loss.item(),total_train_step)

    # 使用测试数据集对模型进行测试,将得到的数据以图像的形式展示出来以及模型准确率评估
    total_test_loss=0
    with torch.no_grad():
        for data in test_dataloader:
            imgs,targets=data
            # 使用损失函数测试模型的好坏
            output=myModel(imgs)
            loss=loss_F(output,targets)
            total_test_loss=total_test_loss+loss.item()
            # 获取本轮数据能够对应的样本数
            temp_true=(output.argmax(1)==targets).sum()
            # 获取所有数据能够对应上的样本数
            total_true=total_true+temp_true
    print(f"整体测试集上的损失{total_test_loss}")
    print(f"整体测试集上的正确率{total_true/test_data_size}")
    writer.add_scalar("test_loss",total_test_loss,total_test_step)
    # 对训练好的模型进行测试,将正确率加入到图像中显示
    writer.add_scalar("test_accuracy",total_true/test_data_size,total_test_step)
    total_test_step=total_test_step+1

    # 模型的保存
    # 保存模型
    torch.save(myModel,os.getcwd()+rf"\Model\model_01\model_{i}.pth")
    # 保存模型对应的参数
    with open(os.getcwd()+rf"\Model\model_02.model_{i}.txt",'a') as f:
        f.write(str(total_test_loss))
    print("模型已保存!")
writer.close()

在这里插入图片描述
使用tensorboard查看训练过程。(过拟合大王哈哈哈哈)
在这里插入图片描述

六、使用GPU加速模型的训练

使用GPU训练我们的模型,可以提高很快的速度。

第一种调用GPU的方式(先判断是否有可用GPU)

# 创建网络模型
myModel=Model()
if torch.cuda.is_available():
    myModel=myModel.cuda()

# 损失函数
loss_F=nn.CrossEntropyLoss()
if torch.cuda.is_available():
    loss_F=loss_F.cuda()

第二种调用GPU的方式(常用)

# gpu   cuda均可以加入其中,如果有多个gpu可以指定每一步使用那个gpu
# 先判断有没有可用GPU
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 加载模型到GPU
# 创建网络模型
myModel=Model()
myModel=myModel.to(device)

# 损失函数
loss_F=nn.CrossEntropyLoss()
loss_F=loss_F.to(device)

还有很多使用方式,大家可以在网上自己搜一下。通过这两种方式我们可以轻松地实现数据在GPU与CPU之间进行转换。

七、模型训练完整的验证套路

基本步骤是:

  • 加载测试数据集
  • 加载训练权重到模型
  • 要测试的图片通过模型预测
  • 对比预测出的标签与原始标签是否一致
  • 得出准确率
import os.path

import torch
import torch.nn as nn
from torchvision import transforms
import torchvision
from Model.model import Model
from PIL import Image
basepath=os.path.split(os.getcwd())[0]
# 处理要进行测试的数据
image_path1=basepath+r"\数据集\air.png"
image_path2=basepath+r"\数据集\dog.png"
image1=Image.open(image_path1)
image2=Image.open(image_path2)
# png有4种颜色通道,RGB还有一个透明度通道(要将png图片转换成rgb)
image1=image1.convert("RGB")
image2=image2.convert("RGB")
transform=transforms.Compose([transforms.Resize((32,32)),transforms.ToTensor()])
img1=transform(image1)
img2=transform(image2)
# 导入训练好的模型
# 如果是通过GPU训练的模型导入的时候要进行 map_location=torch.device('cpu')参数的传递
print(basepath+r"\3.模型的训练\Model\model_01\model_9.pth")
myModel=torch.load(basepath+r"\3.模型的训练\Model\model_01\model_9.pth",map_location=torch.device('cpu'))
img1=torch.reshape(img1,(1,3,32,32))
img2=torch.reshape(img2,(1,3,32,32))
myModel.eval()
# 不进行反向传播,这里是测试不是训练
with torch.no_grad():
    output1=myModel(img1)
    output2=myModel(img2)
print(output1.argmax(1).item())
print(output2.argmax(1).item())
'''
0 飞机
5 狗
'''
'''
CIFAR10包含哪几类 这10类分别是airplane (飞机),automobile(汽车),bird(鸟),cat(猫),deer(鹿),
dog(狗),frog(青蛙),horse(马),ship(船)和truck(卡车)
'''

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/596243.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ES6-ES13学习笔记(4.0)

includes函数 判断字符串是否存在指定字符 <!--* Author: RealRoad1083425287qq.com* Date: 2023-06-01 08:40:33* LastEditors: Mei* LastEditTime: 2023-06-01 08:58:54* FilePath: \vscode\ECMA\05\01.html* Description: * * Copyright (c) 2023 by ${git_name_ema…

Docker+Jenkins+Gitee自动化部署maven项目

1.简介 各位看官老爷&#xff0c;本文为Jenkins实战&#xff0c;注重实际过程&#xff0c;阅读完会有以下收获&#xff1a; 了解如何使用Docker安装Jenkins了解如何使用Jenkins部署maven项目了解如何使用JenkinsGitee实现自动化部署 2.Jenkins介绍 相信&#xff0c;正在读这…

美国频频对中国芯片出手,却没想到最先倒下的是美芯巨头

据报道指出全球知名的硬盘厂商西部数据已基本敲定与日本存储芯片巨头铠侠的合并计划&#xff0c;不过让人意外的是最终主导者将是铠侠而不是西部数据&#xff0c;这意味着西部数据将从此消失于历史之中。 西部数据是全球最大的硬盘厂商&#xff0c;它先后收购了知名硬盘厂商希捷…

【实用篇】Docker

文章目录 Docker实用篇1.初识Docker1.1.什么是Docker1.1.1.应用部署的环境问题1.1.2.Docker解决依赖兼容问题1.1.3.Docker解决操作系统环境差异1.1.4.小结 1.2.Docker和虚拟机的区别1.3.Docker架构1.3.1.镜像和容器1.3.2.DockerHub1.3.3.Docker架构1.3.4.小结 1.4.安装Docker1.…

springboot+vue高校班级管理系统 java 同学录校友录网站

本海滨学院班级回忆录管理员功能有个人中心&#xff0c;用户信息管理&#xff0c;班委信息管理&#xff0c;班级信息管理&#xff0c;加入班级管理&#xff0c;新闻信息管理&#xff0c;班级相册管理&#xff0c;活动信息管理&#xff0c;捐赠信息管理&#xff0c;论坛信息管理…

界面控件DevExpress ASP.NET新主题——Office 365暗黑主题的应用

DevExpress ASP.NET Web Forms Controls拥有针对Web表单&#xff08;包括报表&#xff09;的110种UI控件&#xff0c;DevExpress ASP.NET MVC Extensions是服务器端MVC扩展或客户端控件&#xff0c;由轻量级JavaScript小部件提供支持的70个高性能DevExpress ASP.NET Core Contr…

2023年6月跟教授学DAMA-CDGA/CDGP数据治理认证到这里

DAMA认证为数据管理专业人士提供职业目标晋升规划&#xff0c;彰显了职业发展里程碑及发展阶梯定义&#xff0c;帮助数据管理从业人士获得企业数字化转型战略下的必备职业能力&#xff0c;促进开展工作实践应用及实际问题解决&#xff0c;形成企业所需的新数字经济下的核心职业…

什么?要求设计一个循环队列?

&#x1f388;个人主页:&#x1f388; :✨✨✨初阶牛✨✨✨ &#x1f43b;推荐专栏: &#x1f354;&#x1f35f;&#x1f32f;C语言初阶 &#x1f354;&#x1f35f;&#x1f32f;C语言进阶 &#x1f511;个人信条: &#x1f335;知行合一 &#x1f349;本篇简介:>:讲解用c…

PortSwigger 基于不安全的反序列化漏洞

一、反序列化漏洞简单介绍 反序列化漏洞是指攻击者通过在应用程序中注入恶意的序列化对象来利用应用程序的反序列化功能&#xff0c;从而导致应用程序受到攻击的漏洞。 在一些编程语言和应用程序中&#xff0c;对象可以被序列化为一些字节流或字符串&#xff0c;然后在不同的应…

基于Java+SpringBoot+Vue前后端分离网课在线学习观看系统

博主介绍&#xff1a;✌全网粉丝20W,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专…

C#正则表达式的使用

C#正则表达式 System.Text.RegularExpressions.Regex 使用时需要引入命名空间 using System.Text.RegularExpressions; 如果不引用则写成 System.Text.RegularExpressions.Regex 使用方法如下&#xff1a; string str"测试123456"; string result""; re…

chatgpt赋能python:Python代码中的符号

Python代码中的符号 Python是一种简单易学的编程语言&#xff0c;拥有着广泛的应用领域&#xff0c;比如数据分析、人工智能、Web开发等等。在Python的编程过程中&#xff0c;符号是我们必须要熟悉的一部分。在本文中&#xff0c;我们将介绍Python代码中常见的符号&#xff0c…

华为OD机试真题B卷 Java 实现【人民币转换】,附详细解题思路

一、题目描述 考试题目和要点&#xff1a; 中文大写金额数字前应标明“人民币”字样。中文大写金额数字应用壹、贰、叁、肆、伍、陆、柒、捌、玖、拾、佰、仟、万、亿、元、角、分、零、整等字样填写。中文大写金额数字到“元”为止的&#xff0c;在“元”之后&#xff0c;应…

【Python PyInstaller】零基础也能轻松掌握的学习路线与参考资料

一、Python PyInstaller介绍 Python PyInstaller是一个用于将Python应用程序打包成可执行文件的工具&#xff0c;支持Windows、Mac OS X和Linux平台。使用PyInstaller可以方便地将Python应用程序和所需的依赖项&#xff08;包括Python解释器本身&#xff09;打包成一个独立的可…

JS的DOM对象获取元素

测试1 getElementById <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"viewport" content"widthdev…

I.MX RT1170:如何在SRAM/SDRAM运行程序

一般Flash为non-XIP时&#xff0c;我们需要在RAM上运行程序。还有一种情况&#xff0c;就是我们不想每次调试都要将程序写入Flash&#xff0c;然后由BootROM进行代码的拷贝和跳转&#xff0c;这样可以减少Flash的烧写次数。本篇文章就来讨论一下如何实现这两种情形的RAM代码运行…

chatgpt赋能python:如何更好地理解Python代码

如何更好地理解Python代码 引言 Python是一种高级编程语言&#xff0c;它越来越受欢迎。由于Python内置的强大功能和易学性&#xff0c;许多开发者选择使用Python来开发应用程序。但是&#xff0c;有时候我们可能会面临一些难以理解的代码&#xff0c;尤其是在阅读其他人的代…

一天吃透Spring面试八股文

内容摘自我的学习网站&#xff1a;topjavaer.cn Spring是一个轻量级的开源开发框架&#xff0c;主要用于管理 Java 应用程序中的组件和对象&#xff0c;并提供各种服务&#xff0c;如事务管理、安全控制、面向切面编程和远程访问等。它是一个综合性框架&#xff0c;可应用于所有…

解决Wsl2中Ubuntu无法更新软件的问题

本文排版不太好&#xff0c;详情可见笔记 有道云笔记 安装wsl2之后&#xff0c;在Ubuntu中更新软件&#xff0c;执行apt-get update命令报错&#xff0c;如下 rootjiangcheng01:~# sudo apt-get update Ign:1 http://mirrors.aliyun.com/ubuntu groovy InRelease Ign:2 http:…

一个注解的事儿,数据脱敏解决了

目录 什么是数据脱敏开整使用 Hutool 工具类实现数据掩码Hutool 信息脱敏工具类使用 Jackson 进行数据序列化脱敏 注解实现数据脱敏1、定义一个注解2、创建一个枚举类3、创建我们的自定义序列化类4、测试 项目 pom 文件 总结 本文主要分享什么是数据脱敏&#xff0c;如何优雅的…