Lnton羚通关于Optimization在【PyTorch】中的基础知识

news2025/1/13 13:37:36

OPTIMIZING MODEL PARAMETERS (模型参数优化)
现在我们有了模型和数据,是时候通过优化数据上的参数来训练了,验证和测试我们的模型。训练一个模型是一个迭代的过程,在每次迭代中,模型会对输出进行猜测,计算猜测数据与真实数据的误差(损失),收集误差对其参数的导数(正如前一节我们看到的那样),并使用梯度下降优化这些参数。

Prerequisite Code ( 先决代码 )
We load the code from the previous sections on

import torch 
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

training_data = datasets.FashionMNIST(
    root = "../../data/",
    train = True,
    download = True, 
    transform = transforms.ToTensor()
)

test_data = datasets.FashionMNIST(
    root = "../../data/",
    train = False,
    download = True, 
    transform = transforms.ToTensor()
)

train_dataloader = DataLoader(training_data, batch_size = 32, shuffle = True)
test_dataloader = DataLoader(test_data, batch_size = 32, shuffle = True)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)  
        )
  
    def forward(self, x):
        out = self.flatten(x)
        out = self.linear_relu_stack(out)
        return out


model = NeuralNetwork()

Hyperparameters ( 超参数 )
超参数是可调节的参数,允许控制模型优化过程,不同的超参数会影响模型的训练和收敛速度。read more

我们定义如下的超参数进行训练:

Number of Epochs: 遍历数据集的次数
Batch Size: 每一次使用的数据集大小,即每一次用于训练的样本数量
Learning Rate: 每个 batch/epoch 更新模型参数的速度,较小的值会导致较慢的学习速度,而较大的值可能会导致训练过程中不可预测的行为,例如训练抖动频繁,有可能会发散等。

learning_rate = 1e-3
batch_size = 32
epochs = 5

Optimization Loop ( 优化循环 )
我们设置完超参数后,就可以利用优化循环训练和优化模型;优化循环的每次迭代称为一个 epoch, 每个 epoch 包含两个主要部分:

The Train Loop: 遍历训练数据集并尝试收敛到最优参数。
The Validation/Test Loop: 验证/测试循环—遍历测试数据集以检查模型性能是否得到改善。
让我们简单地熟悉一下训练循环中使用的一些概念。跳转到前面以查看优化循环的完整实现。

Loss Function ( 损失函数 )
当给出一些训练数据时,我们未经训练的网络可能不会给出正确的答案。 Loss function 衡量的是得到的结果与目标值的不相似程度,是我们在训练过程中想要最小化的 Loss function。为了计算 loss ,我们使用给定数据样本的输入进行预测,并将其与真实的数据标签值进行比较。

常见的损失函数包括nn.MSELoss (均方误差)用于回归任务,nn.NLLLoss(负对数似然)用于分类神经网络。nn.CrossEntropyLoss 结合 nn.LogSoftmax 和 nn.NLLLoss 。

我们将模型的输出 logits 传递给 nn.CrossEntropyLoss ,它将规范化 logits 并计算预测误差。

# Initialize the loss function
loss_fn = nn.CrossEntropyLoss()

Optimizer ( 优化器 )
优化是在每个训练步骤中调整模型参数以减少模型误差的过程。优化算法定义了如何执行这个过程(在这个例子中,我们使用随机梯度下降)。所有优化逻辑都封装在优化器对象中。这里,我们使用 SGD 优化器; 此外,PyTorch 中还有许多不同的优化器,如 ADAM 和 RMSProp ,它们可以更好地用于不同类型的模型和数据。

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

在训练的循环中,优化分为3个步骤:

调用 optimizer.zero_grad() 重置模型参数的梯度,默认情况下,梯度是累加的。为了防止重复计算,我们在每次迭代中显式将他们归零。
通过调用 loss.backward() 反向传播预测损失, PyTorch 保存每个参数的损失梯度。
一旦我们有了梯度,我们调用 optimizer.step() 在向后传递中收集梯度调整参数。
Full Implementation (完整实现)
我们定义了遍历优化参数代码的 train loop, 以及根据测试数据定义了test loop。

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

## 数据集
training_data = datasets.FashionMNIST(
    root="../../data/",
    train=True,
    download=True,
    transform=transforms.ToTensor()
)

test_data = datasets.FashionMNIST(
    root="../../data/",
    train=False,
    download=True,
    transform=transforms.ToTensor()
)

## dataloader
train_dataloader = DataLoader(training_data, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=True)

## 定义神经网络
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        out = self.flatten(x)
        out = self.linear_relu_stack(out)
        return out

## 实例化模型
model = NeuralNetwork()

## 损失函数
loss_fn = nn.CrossEntropyLoss()

## 优化器
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

## 超参数
learning_rate = 1e-3
batch_size = 32
epochs = 5

## 训练循环
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        # 计算预测和损失
        pred = model(X)
        loss = loss_fn(pred, y)

        ## 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

## 测试循环
def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}\n")


## 训练网络
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")

Lnton羚通专注于音视频算法、算力、云平台的高科技人工智能企业。 公司基于视频分析技术、视频智能传输技术、远程监测技术以及智能语音融合技术等, 拥有多款可支持ONVIF、RTSP、GB/T28181等多协议、多路数的音视频智能分析服务器/云平台。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/897182.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

mqtt开关实现

这个项目的主要需求其实并不复杂,只是需要让用户可以在小程序上控制预约后的自习室座位的灯和柜子等的开关。这里的关键是需要通过一个网络应用来转发用户对智能硬件的控制请求。 物联网应用的主要几个难点及对应的思路如下: 通信数据量小、通信环境不…

优酷视频码率、爱奇艺视频码率、B站视频码率、抖音视频码率对比

优酷视频码率、爱奇艺视频码率与YouTube视频码率对比 优酷视频码率: 优酷的视频码率可以根据视频质量、分辨率和内容类型而变化。一般而言,优酷提供了不同的码率选项,包括较低的标清(SD)码率和较高的高清(…

[Openwrt-21.02]MT7981 增加 USB RNDIS功能支持操作说明

环境说明 ubuntu18.04编译环境,openwrt-21.02版本,MT7981开发板 openwrt配置项 make menuconfig配置 ​​ ​​​​​​ 配置后.config配置 CONFIG_PACKAGE_kmod-usb-core=y CONFIG_PACKAGE_kmod-usb-ehci=y CONFIG_PACKAGE_kmod-usb-net=y CONFIG_PACKAGE_kmod-usb-net-…

centOS7.6虚拟机设置桥接方式联网

1、虚拟机设置 设置添加进来的虚拟机,选择“网络适配器”,网络连接方式选择“桥接模式”。点击确定。 2、虚拟网络编辑器设置 VMware中选择编辑中的“虚拟机网络编辑器”,选中桥接模式,“已桥接至”选择当前本机电脑的网络信息。…

百度云BOS云存储的图片如何在访问时,同时进行格式转换、缩放等处理

前言 之前做了一个图片格式转换和压缩的服务,结果太占内存。后来查到在访问图片链接时,支持进行图片压缩和格式转换,本来想着先格式转换、压缩图片再上传到BOS,现在变成了上传后,访问时进行压缩和格式转换。想了想&am…

【java】为什么文件上传要转成Base64?

文章目录 1 前言2 multipart/form-data上传3 Base64上传3.1 Base64编码原理3.2 Base64编码的作用 4 总结 1 前言 最近在开发中遇到文件上传采用Base64的方式上传,记得以前刚开始学http上传文件的时候,都是通过content-type为multipart/form-data方式直接…

虫情测报系统的工作原理及功能优势

KH-CQPest虫情测报系统能够在不对虫体造成任何破坏的情况下,无公害的杀死虫子,利用高倍显微镜和高清摄像头拍摄虫体照片,并将虫体照片发送到远端平台,让工作人员无需要到现场,通过平台就可以观察害虫的种类和数量&…

Visual Studio 2022离线源码编译onnxruntime

1. 首先参考前述文章《Visual Studio 2019源码编译cpu版本onnxruntime_xunan003的博客-CSDN博客》第1~3步,将anaconda python3.8虚拟环境copy至内网离线环境envs中。 并将下载的onnxruntime包迁移至内网固定位置; 2.查看onnxruntime/cmake/external所依…

USB3.2链路训练及状态机解析

1.简介 LTSSM(Link Training and Status State Machine)定义了USB3.2总线链路层连接性及链路层电源管理。LTSSM由12种不同的链路状态组成,可以根据它们的功能对其进行表征。 LTSSM有4个可操作的link状态,分别为U0、U1、U2及U3。U0是使能Enhanced Super…

Spring框架中JavaBean的生命周期及单例模式与多列模式

Spring框架中JavaBean的生命周期及单例模式与多列模式 1. Spring框架中JavaBean的管理过程1.1 #定义Bean1.2 Bean的实例化1.3 属性注入1.4 初始化方法1.5 Bean的使用和引用1.6 销毁方法 2. 单例模式与原型模式在JavaBean管理中的应用1.在Spring管理JavaBean的过程中&#xff0c…

STM32 CubeMX (第三步Freertos中断管理和软件定时)

STM32 CubeMX STM32 CubeMX (第三步Freertos中断管理和软件定时) STM32 CubeMX一、STM32 CubeMX设置时钟配置HAL时基选择TIM1(不要选择滴答定时器;滴答定时器留给OS系统做时基)使用STM32 CubeMX 库,配置Fre…

Java请求Http接口-hutool的HttpUtil(超详细-附带工具类)

概述 HttpUtil是应对简单场景下Http请求的工具类封装&#xff0c;此工具封装了HttpRequest对象常用操作&#xff0c;可以保证在一个方法之内完成Http请求。 此模块基于JDK的HttpUrlConnection封装完成&#xff0c;完整支持https、代理和文件上传。 导包 <dependency>&…

第二章MyBatis入门程序

入门程序 创建maven程序 导入MyBatis依赖。pom.xml下导入如下依赖 <dependencies><dependency><groupId>org.mybatis</groupId><artifactId>mybatis</artifactId><version>3.5.6</version></dependency><dependen…

vue3 简易用对话框实现点击头像放大查看

设置头像悬停手势 img:hover{cursor: pointer;}效果&#xff1a; 编写对话框 <el-dialog class"bigAvatar"style"border-radius: 4px;"v-model"deleteDialogVisible"title"查看头像"top"5px"><div><img src&…

[python] Kmeans文本聚类算法+PAC降维+Matplotlib显示聚类图像

0 前言 本文主要讲述以下几点&#xff1a; 1.通过scikit-learn计算文本内容的tfidf并构造N*M矩阵(N个文档 M个特征词)&#xff1b; 2.调用scikit-learn中的K-means进行文本聚类&#xff1b; 3.使用PAC进行降维处理&#xff0c;每行文本表示成两维数据&…

vscode 安装勾选项解释

1、通过code 打开“操作添加到windows资源管理器文件上下文菜单 &#xff1a;把这个两个勾选上&#xff0c;可以对文件使用鼠标右键&#xff0c;选择VSCode 打开。 2、将code注册为受支持的文件类型的编辑器&#xff1a;不建议勾选&#xff0c;这样会默认使用VSCode打开支持的相…

opencv简单使用

cv2库安装&#xff0c; conda install opencv-python注意cv2使用时&#xff0c;路径不能有中文。&#xff08;不然会一直’None’ _ update # 处理中文路径问题 def cv_imread(file_path): #使用之前需要导入numpy、cv2库&#xff0c;file_path为包含中文的路径return cv2.imd…

使用sklearn函数对模型进行交叉验证

使用sklearn函数对模型进行交叉验证 交叉验证用来做什么sklearn 中的函数 交叉验证用来做什么 交叉验证&#xff08;Cross-Validatio&#xff09;&#xff0c;是用于在驯良过程中对训练模型的性能和参数进行评估选择的技术。 它的意义在于能够充分利用优先的数据集&#xff0…

08-信息收集-架构、搭建、WAF等

信息收集-架构、搭建、WAF等 信息收集-架构、搭建、WAF等一、前言说明二、CMS识别技术三、源码获取技术四、架构信息获取技术五、站点搭建分析1、搭建习惯-目录型站点2、搭建习惯-端口类站点3、搭建习惯-子域名站点4、搭建习惯-类似域名站点5、搭建习惯-旁注&#xff0c;c段站点…

汽车OTA活动高质量发展的“常”与“新”

伴随着车主的频繁崔更&#xff0c;车企除了卷硬件、拼价格&#xff0c;逐渐将精力转移到汽车全生命周期的常用常新。时至下半年&#xff0c;车企OTA圈愈发热闹&#xff0c;以新势力、新实力为代表新一代车企&#xff0c;OTA运营活动逐渐进入高质量发展期。 所谓高质量&#xf…