定义损失函数并以此训练和评估模型

news2024/11/26 3:36:30

基础神经网络模型搭建 

【Pytorch】数据集的加载和处理(一)

【Pytorch】数据集的加载和处理(二)

损失函数计算模型输出和目标之间的距离。通过torch.nn 包可以定义一个负对数似然损失函数,负对数似然损失对于训练具有多个类的分类问题比较有效,负对数似然损失函数的输入为对数概率,而在模型搭建的输出层部分接触过log_softmax,它能从模型中获取对数概率

目录

基础模型搭建

数据集的加载和处理

定义损失函数

定义优化器

训练并评估模型


基础模型搭建

import torch
from torch import nn
import torch.nn.functional as F
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
    def forward(self, x):
         pass
def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(1, 20, 5, 1)
    self.conv2 = nn.Conv2d(20, 50, 5, 1)
    self.fc1 = nn.Linear(4*4*50, 500)
    self.fc2 = nn.Linear(500, 10)
def forward(self, x):
    x = F.relu(self.conv1(x))
    x = F.max_pool2d(x, 2, 2)
    x = F.relu(self.conv2(x))
    x = F.max_pool2d(x, 2, 2) 
    x = x.view(-1, 4*4*50)
    x = F.relu(self.fc1(x))
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)
Net.__init__ = __init__
Net.forward = forward
model = Net()

检查搭建情况 

print(model)

 

原位置为cpu 

 

 转移至所需CUDA设备

device = torch.device("cuda:0")
model.to(device)
print(next(model.parameters()).device)

 

数据集的加载和处理

导入MNIST训练数据集和验证数据集并处理

from torch import nn
from torchvision import datasets
from torch.utils.data import TensorDataset
path2data="./data"
train_data=datasets.MNIST(path2data, train=True, download=True)
x_train, y_train=train_data.data,train_data.targets
val_data=datasets.MNIST(path2data, train=False, download=True)
x_val,y_val=val_data.data, val_data.targets
if len(x_train.shape)==3:
    x_train=x_train.unsqueeze(1)
print(x_train.shape)
if len(x_val.shape)==3:
    x_val=x_val.unsqueeze(1)
print(x_val.shape)
train_ds = TensorDataset(x_train, y_train)
val_ds = TensorDataset(x_val, y_val)
for x,y in train_ds:
    print(x.shape,y.item())
    break

from torch.utils.data import DataLoader 
train_dl = DataLoader(train_ds, batch_size=8)
val_dl = DataLoader(val_ds, batch_size=8)

 

 

定义损失函数

损失函数计算模型输出和目标之间的距离。Pytorch 中的 optim 包提供了各种优化算法的实现,例如SGD、Adam、RMSprop 等。

通过torch.nn 包可以定义一个负对数似然损失函数,负对数似然损失对于训练具有多个类的分类问题比较有效,负对数似然损失函数的输入为对数概率,而在模型搭建的输出层部分接触过log_softmax,它能从模型中获取对数概率。

loss_func = nn.NLLLoss(reduction="sum")
for xb, yb in train_dl:
    # move batch to cuda device
    xb=xb.type(torch.float).to(device)
    yb=yb.to(device)
    out=model(xb)
    loss = loss_func(out, yb)
    print (loss.item())
    break

得到一个测试值 

 

定义优化器

定义一个Adam优化器,优化器的输入是模型参数和学习率

from torch import optim
opt = optim.Adam(model.parameters(), lr=1e-4)

通过opt .step()自动更新模型参数,同时需要注意计算下一批的梯度之前需将梯度归0

opt.step()
opt.zero_grad()

训练并评估模型

定义一个辅助函数 loss_batch来计算每个小批量的损失值。函数的 opt 参数引用优化器,如果给定,则计算梯度并按小批量更新模型参数。

def  loss_batch(loss_func,  xb,  yb,yb_h,  opt=None): 
    loss = loss_func(yb_h, yb) 
    metric_b =  metrics_batch(yb,yb_h) 
    if opt is  not None: 
        loss.backward()
        opt.step()
        opt.zero_grad()
    return loss.item(),metric_b

 定义一个辅助函数metrics_batch来计算每个小批量的性能指标,这里以准确率作为分类任务的性能指标,并使用 output.argmax 来获取概率最高的预测类

def metrics_batch(target, output):
    pred = output.argmax(dim=1, keepdim=True)
    corrects=pred.eq(target.view_as(pred)).sum().item()
    return corrects

定义一个辅助函数loss_epoch来计算整个数据集的损失和指标值。使用数据加载器对象获取小批量,将它们提供给模型,并计算每个小批量的损失和指标,通过两个运行变量来分别添加损失值和指标值。

def loss_epoch(model,loss_func,dataset_dl,opt=None):
    loss=0.0
    metric=0.0
    len_data=len(dataset_dl.dataset)
    for xb, yb in dataset_dl:
        xb=xb.type(torch.float).to(device)
        yb=yb.to(device)
        yb_h=model(xb)
        loss_b,metric_b=loss_batch(loss_func, xb, yb,yb_h, opt)
        loss+=loss_b
        if metric_b is not None:
            metric+=metric_b
    loss/=len_data
    metric/=len_data
    return loss, metric

最后,定义一个辅助函数train_val来训练多个时期的模型。在每个时期使用验证数据集评估模型的性能。训练和评估需要分别使用 model.train()和 model.eval()模式。torch.no_grad()可以阻止 autograd 在评估期间计算梯度。

def train_val(epochs, model, loss_func, opt, train_dl, val_dl):
    for epoch in range(epochs):
        model.train()
        train_loss,train_metric=loss_epoch(model,loss_func,train_dl,opt)
        
        model.eval()
        with torch.no_grad():
            val_loss, val_metric=loss_epoch(model,loss_func,val_dl)
        accuracy=100*val_metric
        
        print("epoch: %d, train loss: %.6f, val loss: %.6f,accuracy: %.2f" %(epoch, train_loss,val_loss,accuracy))

 设定时期数为5,调用函数进行训练和评估

num_epochs=5
train_val(num_epochs, model, loss_func, opt, train_dl, val_dl)

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1930119.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

炎炎夏日,这份锂电AGV叉车保养指南赶紧收藏!

AGV 随着工厂自动化、计算机集成制造系统技术的逐步发展以及柔性制造系统、自动化立体仓库的广泛应用,AGV(Automatic GuidedVehicle)即自动导引车作为联系和调节离散型物流系统以使其作业连续化的必要的自动化搬运装卸手段,其应用范围和技术水平得到了迅…

云服务器实际内存与购买不足量问题

君衍 一、本篇缘由二、问题研究1、dmidecode2、dmesg | grep -i memory 三、kdump四、解决方案1、卸载kdump-tools2、清理依赖包3、修改配置文件4、重新生成配置文件5、重启服务器6、再次查看 一、本篇缘由 本篇由于最近买了云服务器,之前基本在本地使用VMware进行虚…

初识单片机之点亮LED灯

1、前言 如果说编程的开始是Hello world,那么单片机的开始就是点亮LED灯,这个操作最直接的展示了单片机强大的控制功能,这里我就以直接点亮指定位置的LED灯的形式演示这个功能。 2、原理介绍 我的单片机的LED灯都是接在单片机的P1口,从P10~P…

【数据结构(邓俊辉)学习笔记】高级搜索树02——B树

文章目录 1. 大数据1.1 640 KB1.2 越来越大的数据1.3 越来越小的内存1.4 一秒与一天1.5 分级I/O1.6 1B 1KB 2. 结构2.1 观察体验2.2 多路平衡2.3 还是I/O2.4 深度统一2.5 阶次含义2.6 紧凑表示2.7 BTNode2.8 BTree 3. 查找3.1 算法过程3.2 操作实例3.3 算法实现3.4 主次成本3.…

YOLOv8白皮书-第Y8周:yolov8.yaml文件解读

本文为365天深度学习训练营中的学习记录博客 原作者:K同学啊|接辅导、项目定制 请根据YOLOv8n、YOLOv8s模型的结构输出,手写出YOLOv8l的模型输出 文件位置:./ultralytics/cfg/models/v8/yolov8.yaml 一、参数配置 # Parameters nc: 80 # n…

Bug:时间字段显示有问题

Bug:时间字段显示有问题 文章目录 Bug:时间字段显示有问题1、问题2、解决方法一:添加注解3、解决方法二:消息转换器自定义对象映射器配置消息转换器 1、问题 ​ 在后端传输时间给前端的时候,发现前端的时间显示有问题…

设计模式总结(设计模式的原则及分类)

1.什么是设计模式? 设计模式(Design pattern)代表了最佳的实践,通常被有经验的面向对象的软件开发人员所采用。设计模式是软件开发人员在软件开发过程中面临的一般问题的解决方案。这些解决方案是众多软件开发人员经过相当长的一段时间的试验和错误总结…

自动化任务调度,轻松管理海量数据采集项目

摘要: 本文深入探讨了在数据驱动时代,如何通过自动化任务调度轻松驾驭海量数据采集项目,提升工作效率,确保数据处理的准确性和及时性。我们分享了一系列实用策略与工具推荐,帮助企业和开发者优化数据采集流程&#xf…

SQL 中的 EXISTS 子句:探究其用途与应用

目录 EXISTS 子句简介语法 EXISTS 与 NOT EXISTSEXISTS 子句的工作原理实际应用场景场景一:筛选存在关联数据的记录场景二:优化查询性能 EXISTS 与其他 SQL 结构的比较EXISTS vs. JOINEXISTS vs. IN 多重 EXISTS 条件在 UPDATE 语句中使用 EXISTS常见问题…

基于 AntV F2 的雷达图组件开发

本文由ScriptEcho平台提供技术支持 项目地址:传送门 基于 AntV F2 的雷达图组件开发 应用场景介绍 雷达图是一种多变量统计图表,用于可视化展示多个维度的数据。它通常用于比较不同对象的多个属性或指标,直观地反映各维度之间的差异和整体…

LoRaWAN协议

目录 一、介绍 1、LPWA是什么? 2、LoRa是什么? 3、LoRaWAN是什么? 4、浅谈LoRa与LoRaWAN的区别 5、LoRaWAN开发背景 6、LoRaWAN与NB-IOT如何选择? 二、LoRaWAN网络结构 1、组网结构 2、星型拓扑结构 三、数据格式 1、…

golang AST语法树解析

1. 源码示例 package mainimport ("context" )// Foo 结构体 type Foo struct {i int }// Bar 接口 type Bar interface {Do(ctx context.Context) error }// main方法 func main() {a : 1 }2. Golang中的AST golang官方提供的几个包,可以帮助我们进行A…

代码随想录算法训练营第五十五天|101.孤岛的总面积、102.沉没孤岛、103.水流问题、104.建造最大岛屿

101.孤岛的总面积 题目链接:101.孤岛的总面积沉没孤岛 文档讲解:代码随想录 状态:不会 思路: 步骤1:将边界上的陆地变为海洋 步骤2:计算孤岛的总面积 题解: public class Main {// 保存四个方…

【UE5.1】NPC人工智能——02 NPC移动到指定位置

效果 步骤 1. 新建一个蓝图,父类选择“AI控制器” 这里命名为“BP_NPC_AIController”,表示专门用于控制NPC的AI控制器 2. 找到我们之前创建的所有NPC的父类“BP_NPC” 打开“BP_NPC”,在类默认值中,将“AI控制器类”一项设置为“…

动手学深度学习——3.多层感知机

1.线性模型 线性模型可能出错 例如,线性意味着单调假设: 任何特征的增大都会导致模型输出的增大(如果对应的权重为正), 或者导致模型输出的减小(如果对应的权重为负)。 有时这是有道理的。 例…

R绘制Venn图及其变换

我自己在用R做各种分析时有不少需要反复用到的基础功能,比如一些简单的统计呀,画一些简单的图等等,虽说具体实现的代码也不麻烦,但还是不太想每次用的时候去找之前的代码。 索性将常用的各种函数整成了一个包:pcutils…

前端JS特效第34集:jQuery俩张图片局部放大预览插件

jQuery俩张图片局部放大预览插件&#xff0c;先来看看效果&#xff1a; 部分核心的代码如下(全部代码在文章末尾)&#xff1a; <!DOCTYPE html> <html lang"zh"> <head> <meta charset"UTF-8"> <meta http-equiv"X-UA-Co…

数据结构与算法02迭代|递归

目录 一、迭代(iteration) 1、for循环 2、while循环 二、递归&#xff08;recursion&#xff09; 1、普通递归 2、尾递归 3、递归树 三、对比 简介&#xff1a;在算法中&#xff0c;重复执行某个任务是常见的&#xff0c;它与复杂度息息相关&#xff0c;在程序中实现重…

MySQL MVCC原理

全称Multi-Version Concurrency Control&#xff0c;即多版本并发控制&#xff0c;主要是为了提高数据库的并发性能。 1、版本链 对于使用InnoDB存储引擎的表来说&#xff0c;它的聚簇索引记录中都包含两个必要的隐藏列&#xff1a; 1、trx_id&#xff1a;每次一个事务对某条…

connect by prior 递归查询

connect by prior 以公司组织架构举例&#xff0c;共四个层级&#xff0c;总公司&#xff0c;分公司&#xff0c;中心支公司&#xff0c;支公司 总公司level_code为1 下一层级的parent_id为上一层级的id&#xff0c;建立关联关系 SELECT id, name, LEVEL FROM org_info a STA…