训练与优化

news2025/2/21 10:25:25

训练与优化

损失函数与反向传播

损失函数能够衡量神经网络输出与目标值之间的误差,同时为反向传播提供依据,计算梯度来优化网络中的参数。

torch.nn.L1Loss 计算所有预测值与真实值之间的绝对差。参数为 reduction

  • 'none':不对损失进行任何求和或平均,返回每个元素的损失。
  • 'mean':对损失进行平均,默认选项。
  • 'sum':对所有样本的损失进行求和。
import torch

input = torch.tensor([1, 2, 3], dtype=torch.float32)
target = torch.tensor([1, 3, 5], dtype=torch.float32)

loss = torch.nn.L1Loss(reduction="none")
res = loss(input, target)
print(res)
# tensor([0., 1., 2.])

loss = torch.nn.L1Loss(reduction="mean")
res = loss(input, target)
print(res)
# tensor(1.)

loss = torch.nn.L1Loss(reduction="sum")
res = loss(input, target)
print(res)
# tensor(3.)

torch.nn.MSELoss 计算每个样本的预测值与真实值之间的差距的平方,参数为 reduction

import torch

input = torch.tensor([1, 2, 3], dtype=torch.float32)
target = torch.tensor([1, 3, 5], dtype=torch.float32)

loss = torch.nn.MSELoss(reduction="none")
res = loss(input, target)
print(res)
# tensor([0., 1., 4.])

loss = torch.nn.MSELoss(reduction="mean")
res = loss(input, target)
print(res)
# tensor(1.6667)

loss = torch.nn.MSELoss(reduction="sum")
res = loss(input, target)
print(res)
# tensor(5.)

torch.nn.CrossEntropyLoss 计算实际类别分布预测类别分布之间的差异。输入 input 为预测的类别得分(不是概率),维度为 (N,C) ,其中 N 是样本数量,C 是类别数量,每个样本是一个未经过softmax 的类别得分。真实标签索引 target 维度为 (N) ,每个标签是一个整数,表示该样本的真实类别索引。

CrossEntropyLoss自动计算 input 的 softmax ,然后根据交叉熵公式计算每个样本的损失。

import torch
from torch import nn

# 2个样本,3个类别的得分
input = torch.tensor([[1, 2, 3], [1, 2, 3]], dtype=torch.float32)
# 真实标签:第1个样本属于类别2,第2个样本属于类别1
target = torch.tensor([2, 1])

loss = nn.CrossEntropyLoss()

res = loss(input, target)
print(res)
# tensor(0.9076)

如果数据集中的类别不平衡,可以通过 weight 参数对每个类别的损失进行加权。这样可以让模型在训练时更加关注某些类别。

import torch
from torch import nn

# 2个样本,3个类别的得分
input = torch.tensor([[1, 2, 3], [1, 2, 3]], dtype=torch.float32)
# 真实标签:第1个样本属于类别2,第2个样本属于类别1
target = torch.tensor([2, 1])

# 类别0权重为1,类别1权重为2,类别2权重为0.5
weight = torch.tensor([1.0, 2.0, 0.5])
loss = nn.CrossEntropyLoss(weight)

res = loss(input, target)
print(res)
# tensor(1.2076)

当计算出损失函数后,便可计算出每一个节点参数的梯度,从而进行反向传播,只需要加上一行:

result_loss.backward()

训练与推理

在 PyTorch 中,神经网络的 train()eval() 模式控制着 Batch NormalizationDropout 这两类层的行为,确保模型在训练和推理(测试)时的表现一致。

model.train() 负责启动 BN 和 Dropout 层的训练模式。BatchNorm 会计算当前批次的均值和方差,用于归一化数据,这些均值和方差会随着训练逐步更新。Dropout 会随机丢弃一部分神经元,以减少过拟合。

model.eval() 负责关闭训练模式,进入推理模式,确保计算的均值、方差、Dropout 影响不会波动,保证结果稳定。计算归一化时,会使用训练期间学到的全局均值和方差,而不是当前批次的统计量。也不再随机丢弃神经元,而是使用完整的网络进行预测。

在训练的时候,还需要关闭梯度计算,减少内存占用,加速推理。因为推理时不需要计算梯度,不需要 backward() 进行反向传播。

with torch.no_grad():
    output = model(input)

train() 模式下,PyTorch 默认存储计算图,以支持 backward() 计算梯度torch.no_grad() 关闭计算图,避免存储不必要的梯度信息,减少显存占用。

训练模式

model.train()  # 训练模式(启用 BatchNorm 统计 和 Dropout)
for data in dataloader:
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

推理模式

model.eval()  # 进入推理模式
with torch.no_grad():  # 关闭梯度计算
    output = model(input)

优化器

优化器利用通过反向传播计算得到的梯度来更新模型参数,从而减小损失函数值,提升模型的性能。

在每次训练过程中,首先使用 optimizer.zero_grad() 清零上一步的梯度,然后通过 loss.backward() 执行反向传播,计算当前模型参数的梯度,最后使用 optimizer.step() 根据梯度更新模型参数。

**SGD(随机梯度下降)**是基本的梯度下降法,每次更新一个小批量的数据(mini-batch)参数,需要调整学习率(lr)和可能的动量(momentum)等超参数。

Adam、Adagrad、Adadelta、RMSProp 是不同的优化算法,每种算法有不同的超参数调整方法,Adam自适应调整学习率,Adagrad适用于稀疏数据,Adadelta主要针对自适应学习率的调整。

学习速率不能太大(太大模型训练不稳定)也不能太小(太小模型训练慢),一般建议先采用较大学习速率,后采用较小学习速率。

优化器构造方法:

# SGD(Stochastic Gradient Descent) 随机梯度下降
# 模型参数、学习速率、动量
**optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)**  

优化器调用方法:

for input, target in dataset:
    optimizer.zero_grad()            # 清空梯度
    output = model(input)
    res= loss(output, target)        # 计算损失函数
    res.backward()                   # 反向传播计算梯度
    optimizer.step()                 # 根据梯度优化参数

以 CIFAR-10 数据集为例:

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# 加载数据集
dataset = torchvision.datasets.CIFAR10(root="Dataset", train=False,
                                       transform=torchvision.transforms.ToTensor(), download=False)
# 批量加载数据
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
writer = SummaryWriter("logs")

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.model1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(in_features=1024, out_features=64),
            nn.Linear(in_features=64, out_features=10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x

model = Model()
# 定义损失函数
loss = torch.nn.CrossEntropyLoss()
# 定义优化器
**optimizer = torch.optim.SGD(model.parameters(), lr=0.01)**

# 训练 20 个 epoch
for epoch in range(20):
    totalloss = 0.0
    for data in dataloader:
        optimizer.zero_grad()    # 清空梯度
        imgs, targets = data
        outputs = model(imgs)
        lossres = loss(outputs, targets)    # 计算损失
        totalloss = totalloss + lossres     # 累加损失
        lossres.backward()                  # 反向传播计算梯度
        optimizer.step()                    # 更新模型参数
    print("Epoch{} : {}".format(epoch, totalloss))
    # 写入 TensorBoard
    writer.add_scalar("train_loss", totalloss, epoch)

writer.close()

如果模型在训练时过早出现 nan 或损失不收敛,可以尝试调整学习率,使用更小的学习率或更高级的优化器(如 Adam)。

预训练模型

PyTorch 主要提供搭建神经网络的核心工具,TorchVision 提供了一系列预训练模型、标准数据集(如 ImageNet、CIFAR-10 等)和图像变换工具(transforms)。预训练模型(如 VGG16)在 ImageNet 数据集上已经训练好,可以直接使用或者在此基础上微调。

VGG16 是一种经典的卷积神经网络,主要用于图像分类任务。VGG16 由多层卷积层、池化层和全连接层组成,features 部分用于提取图像特征,classifier 部分用于分类,最终输出1000个类别。

torchvision.models.vgg16(weights, progress)

progess 代表是否显示下载进度条,默认 True,表示在下载权重时显示进度条。

weights 是预训练权重,默认为 None 不加载预训练模型。权重 VGG16_Weights.IMAGENET1K_V1 适用于分类任务,基于 ImageNet 训练,包含完整的分类器(classifier 层),VGG16_Weights.DEFAULT 等同于 VGG16_Weights.IMAGENET1K_V1

import torchvision

# 无预训练权重(随机初始化参数)
vgg16_false = torchvision.models.vgg16(weights=None)
# 使用 ImageNet 预训练参数
vgg16_true = torchvision.models.vgg16(
                     weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)
# 默认使用 ImageNet 预训练权重
vgg16_default = torchvision.models.vgg16(
                     weights=torchvision.models.VGG16_Weights.DEFAULT)

但是 VGG16 对于图像输入有严格要求,输入维度必须是 224 × 224 224 \times 224 224×224

# 图像预处理(按 VGG16 需要的格式)
transform = transforms.Compose([
    transforms.Resize(256),                 # 先缩放到 256
    transforms.CenterCrop(224),             # 再中心裁剪到 224
    transforms.ToTensor(),
    # 归一化
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

权重 VGG16_Weights.IMAGENET1K_FEATURES 用于特征提取,不包含 classifier 部分权重,只能提取特征,不能进行分类(只是不包含预训练的分类器权重,并没有移除分类器层)。适用于迁移学习,可以用 features 层进行特征提取。

import torchvision

vgg16_feature = torchvision.models.vgg16(
                    weights=torchvision.models.VGG16_Weights.IMAGENET1K_FEATURES)

VGG Model Structure

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    ......
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

分类任务

import torch
from PIL import Image
from torchvision import models, transforms

# 加载 VGG16 预训练模型
model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)

# 定义图像预处理步骤
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img_path = r"Dataset/airplane.png"
img = Image.open(img_path)
input = transform(img)
# 添加 batch 维度
input = torch.reshape(input, (1, 3, 224, 224))

# 进入推理模式
model.eval()
# 前向传播
with torch.no_grad():
    output = model(input)

# 获取预测类别索引
predicted_class = torch.argmax(output)
# 获取 ImageNet 1000 类的类别名称
classes = models.VGG16_Weights.IMAGENET1K_V1.meta["categories"]
print(classes[predicted_class])

迁移学习微调模型

如果要迁移到 CIFAR-10 的分类任务,需要修改最后一层

from torch import nn
from torchvision import models

# 加载 VGG16 预训练模型
model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)

# 修改 classifier 部分(改为 10 类)
**model.classifier[6]** = nn.Linear(in_features=4096, out_features=10)

或者添加新层

model.classifier.add_module("7", nn.Linear(in_features=1000, out_features=10))

如果只训练最后一层,可以冻结前面的参数

for param in model.features.parameters():
    param.requires_grad = False  # 冻结 features 部分(不更新)

这样可以 保留原有的卷积特征,仅微调分类层,提高训练效率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2301141.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

VsCode美化 Json

1.扩展中输入:pretty json 2. (CtrlA)选择Json文本 示例:{ "name" : "runoob" , "alexa" :10000, "site" : null , "sites" :[ "Google" , "Runoob" , "T…

ssm121基于ssm的开放式教学评价管理系统+vue(源码+包运行+LW+技术指导)

项目描述 临近学期结束,还是毕业设计,你还在做java程序网络编程,期末作业,老师的作业要求觉得大了吗?不知道毕业设计该怎么办?网页功能的数量是否太多?没有合适的类型或系统?等等。这里根据疫情当下,你想解决的问…

《深度学习》——ResNet网络

文章目录 ResNet网络ResNet网络实例导入所需库下载训练数据和测试数据设置每个批次的样本个数判断是否使用GPU定义残差模块定义ResNet网络模型导入GPU定义训练函数定义测试函数创建损失函数和优化器训练测试数据结果 ResNet网络 ResNet(Residual Network&#xff0…

【Windows软件 - HeidiSQL】导出数据库

HeidSQL导出数据库 软件信息 具体操作 示例文件 选项分析 选项(1) 结果(1) -- -------------------------------------------------------- -- 主机: 127.0.0.1 -- 服务器版本: …

【达梦数据库】dblink连接[SqlServer/Mysql]报错处理

目录 背景问题1:无法测试以ODBC数据源方式访问的外部链接!问题分析&原因解决方法 问题2:DBLINK连接丢失问题分析&原因解决方法 问题3:DBIINK远程服务器获取对象[xxx]失败,错误洋情[[FreeTDS][SQL Server]Could not find stored proce…

java断点调试(debug)

在开发中,新手程序员在查找错误时, 这时老程序员就会温馨提示,可以用断点调试,一步一步的看源码执行的过程,从而发现错误所在。 重要提示: 断点调试过程是运行状态,是以对象的运行类型来执行的 断点调试介绍 断点调试是…

最新智能优化算法:牛优化( Ox Optimizer,OX)算法求解经典23个函数测试集,MATLAB代码

一、牛优化算法 牛优化( OX Optimizer,OX)算法由 AhmadK.AlHwaitat 与 andHussamN.Fakhouri于2024年提出,该算法的设计灵感来源于公牛的行为特性。公牛以其巨大的力量而闻名,能够承载沉重的负担并进行远距离运输。这种…

Redis7——基础篇(四)

前言:此篇文章系本人学习过程中记录下来的笔记,里面难免会有不少欠缺的地方,诚心期待大家多多给予指教。 基础篇: Redis(一)Redis(二)Redis(三) 接上期内容&…

Git备忘录(三)

设置用户信息: git config --global user.name “itcast” git config --global user.email “ helloitcast.cn” 查看配置信息 git config --global user.name git config --global user.email $ git init $ git remote add origin gitgitee.com:XXX/avas.git $ git pull or…

MySQL 之INDEX 索引(Index Index of MySQL)

MySQL 之INDEX 索引 1.4 INDEX 索引 1.4.1 索引介绍 索引:是排序的快速查找的特殊数据结构,定义作为查找条件的字段上,又称为键 key,索引通过存储引擎实现。 优点 大大加快数据的检索速度; 创建唯一性索引,保证数…

Linux基础24-C语言之分支结构Ⅰ【入门级】

分支结构 问题抛出 我们在程序设计中往往会遇到如下问题,比如下面的函数计算: 也就是我们必须要通过一个条件的结果来选择下一步的操作,算法上属于一个分支结构,处于严重实现分支结构主要使用if语句。 条件判断 根据某个条件成…

LeetCode47

LeetCode47 目录 题目描述示例思路分析代码段代码逐行讲解复杂度分析总结的知识点整合总结 题目描述 给定一个可包含重复数字的整数数组 nums,按任意顺序返回所有不重复的全排列。 示例 示例 1 输入: nums [1, 1, 2]输出: [[1, 1, 2],[1, 2, 1],[2, 1, 1] ]…

【Unity动画】导入动画资源到项目中,Animator播放角色动画片段,角色会跟随着动画播放移动。

导入动画资源到项目中,Animator播放角色动画片段,角色会跟随着动画播放移动,但我只想要角色在原地播放动画。比如:播放一个角色Run动画,希望角色在原地奔跑,而不是产生了移动距离。 问题排查: 1.是否勾选…

图解循环神经网络(RNN)

目录 1.循环神经网络介绍 2.网络结构 3.结构分类 4.模型工作原理 5.模型工作示例 6.总结 1.循环神经网络介绍 RNN(Recurrent Neural Network,循环神经网络)是一种专门用于处理序列数据的神经网络结构。与传统的神经网络不同&#xff0c…

【数据结构】(9) 优先级队列(堆)

一、优先级队列 优先级队列不同于队列,队列是先进先出,优先级队列是优先级最高的先出。一般有两种操作:返回最高优先级对象,添加一个新对象。 二、堆 2.1、什么是堆 堆也是一种数据结构,是一棵完全二叉树&#xff0c…

4、IP查找工具-Angry IP Scanner

在前序文章中,提到了多种IP查找方法,可能回存在不同场景需要使用不同的查找命令,有些不容易记忆,本文将介绍一个比较优秀的IP查找工具,可以应用在连接树莓派或查找IP的其他场景中。供大家参考。 Angry IP Scanner下载…

【Linux】命令操作、打jar包、项目部署

阿华代码,不是逆风,就是我疯 你们的点赞收藏是我前进最大的动力!! 希望本文内容能够帮助到你!! 目录 一:Xshell下载 1:镜像设置 二:阿里云设置镜像Ubuntu 三&#xf…

瑞萨RA-T系列芯片ADCGPT功能模块的配合使用

在马达或电源工程中,往往需要采集多路AD信号,且这些信号的优先级和采样时机不相同。本篇介绍在使用RA-T系列芯片建立马达或电源工程时,如何根据需求来设置主要功能模块ADC&GPT,包括采样通道打包和分组,GPT触发启动…

Unity Shader学习6:多盏平行光+点光源 ( 逐像素 ) 前向渲染 (Built-In)

0 、分析 在前向渲染中,对于逐像素光源来说,①ForwardBase中只计算一个平行光,其他的光都是在FowardAdd中计算的,所以为了能够渲染出其他的光照,需要在第二个Pass中再来一遍光照计算。 而有所区别的操作是&#xff0…

tailwindcss学习01

系列教程 01 入门 02 vue中接入 入门 # 注意使用cmd不要powershell npm init -y # 如果没有npx则安装 npm install -g npx npm install -D tailwindcss3.4.17 --registry http://registry.npm.taobao.org npx tailwindcss init修改tailwind.config.js /** type {import(tai…