机器学习周报(9.9-9.15)-Pytorch学习(三)

news2025/1/12 18:19:47

文章目录

    • 摘要
    • Abstract
    • 1 损失函数与反向传播
      • 1.1 L1Loss损失函数
      • 1.2 MSELoss损失函数
      • 1.3 交叉熵损失函数(CrossEntropyLoss)
      • 1.4 反向传播
    • 2 优化器
    • 3 现有网络模型的使用及修改
    • 4 网络模型的保存与读取
      • 4.1 保存模型
      • 4.2 读取
    • 总结

摘要

本次学习对Pytorch中有关常用的损失函数进行了相关学习和实操,并对Pytorch中交叉熵损失函数的原理进行学习和相关公式的推导;并学习了优化器通过计算模型的损失函数进行模型的优化;同时学习了现在训练成熟的网络模型的使用、修改以及网络模型的保存和读取。

Abstract

In this study, the common loss functions in Pytorch are studied and implemented, and the principle of cross-entropy loss function in Pytorch is studied and related formulas are derived. The optimizer can optimize the model by calculating the loss function of the model. At the same time, we learned how to use and modify the network model and how to save and read the network model.

1 损失函数与反向传播

1.1 L1Loss损失函数

import torch
from torch.nn import L1Loss

inputs = torch.tensor([1, 2, 3], dtype=torch.float32)
targets = torch.tensor([1, 2, 5], dtype=torch.float32)

l1 = L1Loss()
result = l1(inputs, targets)
print(result)   # tensor(0.6667)

l1 = L1Loss(reduction=‘mean’)

  1. 默认reduction=‘mean’,求每个数据差的绝对值再取平均
  2. 当reduction=‘sum’,即求每个数据差的绝对值求和,此时输出为:tensor(2.)

1.2 MSELoss损失函数

torch.nn.MSELoss
在这里插入图片描述

import torch
from torch.nn import L1Loss, MSELoss

inputs = torch.tensor([1, 2, 3], dtype=torch.float32)
targets = torch.tensor([1, 2, 5], dtype=torch.float32)

# MSELoss损失函数
m1 = MSELoss(reduction='mean') #默认值
result = m1(inputs, targets)
print(result)     # tensor(1.3333)

m2 = MSELoss(reduction='sum')
result = m2(inputs, targets)
print(result)     # tensor(4.)

1.3 交叉熵损失函数(CrossEntropyLoss)

在这里插入图片描述

softmax函数又称归一化指数函数,是基于 sigmoid 二分类函数在多分类任务上的推广;在多分类网络中,常用 Softmax 作为最后一层进行分类。

import torch
import torch.nn as nn

input1 = torch.tensor([-0.5, -0.3, 0, 0.3, 0.5])
input2 = torch.tensor([-3, -1, 0, 1, 3], dtype=torch.float32)

softmax = nn.Softmax(dim=0)
output1 = softmax(input1)
output2 = softmax(input2)
print(output1) # tensor([0.1135, 0.1386, 0.1871, 0.2525, 0.3084])
print(output2) # tensor([0.0021, 0.0152, 0.0413, 0.1122, 0.8292])
  1. Softmax 可以使正样本(正数)的结果趋近于 1,使负样本(负数)的结果趋近于 0;且样本的绝对值越大,两极化越明显。
  2. Softmax 可以使数值较大的值获得更大的概率

Pytorch中nn.CrossEntropyLoss,结合了nn.LogSoftmax()和nn.NLLLoss()两个函数,在做分类训练时非常有用

在这里插入图片描述

import torch
import torch.nn as nn

input2 = torch.tensor([0.1, 0.2, 0.3])
target2 = torch.tensor([1])
input2 = torch.reshape(input2, (1, 3))
l = crossEntropyLoss(input2, target2)
print(l) # tensor(1.1019)
# 计算公式:
# -0.2 + ln(exp(0.1)+exp(0.2)+exp(0.3))
import torch
import torch.nn as nn

crossEntropyLoss = nn.CrossEntropyLoss()
input = torch.tensor([[-0.1342, -2.5835, -0.9810],
                     [0.1867, -1.4513, -0.3225],
                     [0.6272, -0.1120, 0.3048]])
target = torch.tensor([0, 2, 1])
loss = crossEntropyLoss(input, target)
print(loss)
'''
    [-(-0.1342)+ln(exp(-0.1342)+exp(-2.5835)+exp(-0.9810)) 
    -(-0.3225)+ln(exp(0.1867)+exp(-1.4513)+exp(-0.3225))
    -(-0.1120)+ln(exp(0.6272)+exp(-0.1120)+exp(0.3048))]/3 = 3.03842655071/3 = 1.01280885024

'''

1.4 反向传播

import torch.nn as nn
import torchvision
from torch.nn import Conv2d, MaxPool2d, Sequential, Linear, Flatten, CrossEntropyLoss
from torch.utils.data import DataLoader

#数据集
dataset = torchvision.datasets.CIFAR10("dataset2", train=False, transform= torchvision.transforms.ToTensor())

data_loader = DataLoader(dataset, batch_size=1)

class seq(nn.Module):
    def __init__(self):
        super(seq, self).__init__()
        self.model = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )
    def forward(self,x):
        x = self.model(x)
        return x

s = seq()
# 交叉熵
loss = CrossEntropyLoss()

for data in data_loader:
    imgs, target = data
    output = s(imgs)
    # print(output)
    # print(target)
    result_loss = loss(output, target)
    # print(result_loss)
    # 反向传播
    # 计算出来的 loss 值有 backward 方法属性,
    # 反向传播来计算每个节点的更新的参数。
    # 这里查看网络的属性 grad 梯度属性刚开始没有,
    # 反向传播计算出来后才有,后面优化器会利用梯度优化网络参数。      
    result_loss.backward()
    print('ok')

还未执行反向传播
在这里插入图片描述

执行反向传播之后,进行了gradient descent,grad值进行了更新

在这里插入图片描述

2 优化器

torch.optim

# 数据集
dataset = torchvision.datasets.CIFAR10("dataset2", train=False, transform= torchvision.transforms.ToTensor())
data_loader = DataLoader(dataset, batch_size=1)

# 定义模型
model=...

#训练模型
for data in data_loader:
    imgs, target = data
    output = seq(imgs)
    result_loss = loss(output, target)
    # 优化器先将网络中的每个参数的梯度清零
    optim.zero_grad()
    # 调用损失函数的反向传播求出每个节点的梯度
    result_loss.backward()
    # 更新参数
    optim.step()

Debug:
将这三行代码打上断点,依次执行观察grad和data的变化

在这里插入图片描述

在这里插入图片描述

执行第42行代码前跟执行42行代码之后,grad都是没有值的

在这里插入图片描述

在这里插入图片描述

执行44行反向传播代码之后,grad由none变化,出现参数

在这里插入图片描述

执行46行代码前后,data的数值发生了变化

在这里插入图片描述
在这里插入图片描述

  • 训练20个回合(epoch)

训练20个回合,看每个回合的loss值

#训练模型:
for epoch in range(20):
    sum_loss = 0
    for data in data_loader:
        imgs, target = data
        output = seq(imgs)
        result_loss = loss(output, target)
        # 优化器先将网络中的每个参数的梯度清零
        optim.zero_grad()
        # 调用损失函数的反向传播求出每个节点的梯度
        result_loss.backward()
        # 更新参数
        optim.step()
        # print(result_loss)
        sum_loss = sum_loss+result_loss

    print(sum_loss)

在这里插入图片描述

3 现有网络模型的使用及修改

vgg16模型为例,它是以ImageNet数据集进行训练得到的,但是ImageNet数据集不公开并且数据量非常庞大,不下载,仅用于增加和修改该网络模型的学习

import torchvision

vgg16_true = torchvision.models.vgg16()

print(vgg16_true)

在这里插入图片描述

在这里插入图片描述

可以看到,该网络模型最后是一个线性变化:Linear(4096,1000),现在想该网络模型最后的线性变化改为10输出

方法一:在VGG16后面添加一个线性层

vgg16_true.add_module('add_linear', nn.Linear(1000, 10))

或者

vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))

方法二:直接修改VGG16的最后一个线性层

vgg16_true.classifier[6] = nn.Linear(4096, 10)

4 网络模型的保存与读取

4.1 保存模型

import torch
import torchvision

vgg16 = torchvision.models.vgg16()

# 保存方式一,模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")

# 保存方式二,模型参数(官方推荐)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

4.2 读取

import torch
import torchvision

# 方式一 -> 保存方式一,加载模型
model = torch.load("vgg16_method1.pth")
print(model)


# 方式二:对应保存方式2
vgg16 = torchvision.models.vgg16()
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16)

  • 方式一保存模型有陷阱
# save.py
# 保存方式一存在陷阱
class modelcc(nn.Module):
    def __init__(self):
        super(modelcc, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 5, 1)

    def forward(self, x):
        return self.conv1(x)

cc = modelcc()
torch.save(cc, "cctest.pth")
# load.py

cc = torch.load("cctest.pth")
print(cc)

在实际运用时,一般把model单独写一个python文件,然后通过下面这行代码在使用时进行引入

from model_save import *

总结

本周学习了Pytorch中一些小简单的损失函数的数学公式和使用,搜索相关资料更深刻学习了交叉熵损失函数,学习了网络模型的使用、修改、保存和读取。下周,我将通过学习minist数据集相关任务,加深对CNN原理的学习

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2133596.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

8.3Sobel算子边缘检测

实验原理 Sobel算子是一种广泛使用的一阶导数边缘检测算子,它通过计算图像在水平和垂直方向上的梯度来检测边缘。Sobel算子使用一对3x3的掩模来实现这一功能。相比于其他边缘检测算子,Sobel算子在检测边缘的同时还能提供一定的抗噪能力。 在OpenCV中&a…

【射频通信电子线路基础第四讲】LC匹配网络、史密斯圆图、噪声与噪声系数

一、LC匹配网络 1、L-I型(负载与电抗并联) 2、L-II型(负载与电抗串联) 3、T型网络和π型网络例子 二、Smith圆图 这里先附上知乎大神的讲解链接,推荐直接去看非常适合入门理解,看完之后茅塞顿开 https://…

MySQL 安全机制全面解析

‍ 在如今的数字化时代,数据库安全 变得越来越重要。为了防止对数据库进行非法操作,MySQL 定义了一套完整的安全机制,包括用户管理、权限管理 和 角色管理。本文将为你深入浅出地介绍这三大安全机制,帮助你轻松掌握MySQL的安全管…

MPP数据库之SelectDB

SelectDB 是一个高性能、云原生的 MPP(大规模并行处理)数据库,旨在为分析型数据处理场景提供快速、弹性和高效的解决方案。它专为处理大规模结构化和半结构化数据设计,常用于企业级业务分析、实时分析和决策支持。 SelectDB 是在…

Vue2时间轴组件(TimeLine/分页、自动顺序播放、暂停、换肤功能、时间选择,鼠标快速滑动)

目录 1介绍背景 2实现原理 3组件介绍 4代码 5其他说明 1介绍背景 项目背景是 一天的时间轴 10分钟为一间隔 一天被划分成144个节点 一页面12个节点 代码介绍的很详细 可参考或者借鉴 2实现原理 对Element-plus滑块组件的二次封装 基于Vue2(2.6.14&#x…

数字孪生引领智慧医院革新:未来医疗的智能化之路

数字孪生(Digital Twin) 是指将物理实体或系统的数字化模型与其实时运行数据相结合,以反映实体的状态、行为和性能,并通过数据分析和仿真来优化决策和管理。在智慧医院建设中,数字孪生技术扮演着关键角色。 1. 数字孪生…

基于SpringBoot+Vue的瑜伽体验课预约管理系统

作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、SSM项目源码 系统展示 【2025最新】基于JavaSpringBootVueMySQL的…

国内按月/季/年使用GPT4.0及OpenAI最新的模型

其实gpt官方版本不仅对于网络要求很高,且订阅用户对高级模型的使用也是有次数限制的, 国内想要稳定且最快同步官网的最新模型,我推荐一个地址,可以方便的不限次数的使用GPT4.0等模型, 今天早上刚出的OpenAI全新的草莓模型&#xf…

uniapp 发布苹果IOS详细流程,包括苹果开发者公司账号申请、IOS证书、.p12证书文件等

记录一下uniapp发布苹果IOS的流程。 一、苹果开发者公司账号申请 1、邓白氏编码申请(先申请公司邓白氏编码,这一步需要1-2周,没有这个编码苹果开发者没法申请,已有编码的跳过此步骤): 1)联系苹…

【C++ Primer Plus习题】16.1

大家好,这里是国中之林! ❥前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。点击跳转到网站。有兴趣的可以点点进去看看← 问题: 解答: main.cpp #include <iostream> #include <string> usin…

Linux通配符*、man 、cp、mv、echo、cat、more、less、head、tail、等指令、管道 | 、指令的本质 等的介绍

文章目录 前言一、Linux通配符*二、man 指令三、 cp 指令四、mv指令五、 echo 指令六、cat 指令七、more 指令八、 less 指令九、 head 指令十、 tail指令十一、 管道 |十二、指令的本质总结 前言 Linux通配符*、man 、cp、mv、echo、cat、more、less、head、tail、等指令、管…

[Unity Demo]重启项目之从零开始制作空洞骑士Hollow Knight第一集:导入素材以及建立并远程连接git仓库

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、导入素材二、使用步骤 1.建立并远程连接git2.github和仓库连接总结 前言 好久没来CSDN看看&#xff0c;突然看到前两年自己写的文章从零开始制作空洞骑士只…

【计算机网络 - 基础问题】每日 3 题(一)

✍个人博客&#xff1a;Pandaconda-CSDN博客 &#x1f4e3;专栏地址&#xff1a;http://t.csdnimg.cn/fYaBd &#x1f4da;专栏简介&#xff1a;在这个专栏中&#xff0c;我将会分享 C 面试中常见的面试题给大家~ ❤️如果有收获的话&#xff0c;欢迎点赞&#x1f44d;收藏&…

基于云计算的虚拟电厂负荷预测

基于云计算的虚拟电厂负荷预测 随着电网规模的扩大及新能源的不断应用&#xff0c;并网电网的安全性和经济性备受关注。 电网调度不再是单一或局部控制&#xff0c;而是采用智能网络集成方式调度 。 智能电网应具有以下特点&#xff1a;坚强自愈&#xff0c;可以抵御外来干扰甚…

使用 Milvus、vLLM 和 Llama 3.1 搭建 RAG 应用

vLLM 是一个简单易用的 LLM 推理服务库。加州大学伯克利分校于 2024 年 7 月将 vLLM 作为孵化项目正式捐赠给 LF AI & Data Foundation 基金会。欢迎 vLLM 加入 LF AI & Data 大家庭&#xff01;&#x1f389; 在主流的 AI 应用架构中&#xff0c;大语言模型&#xff0…

【devops】devops-git之介绍以及日常使用

本站以分享各种运维经验和运维所需要的技能为主 《python零基础入门》&#xff1a;python零基础入门学习 《python运维脚本》&#xff1a; python运维脚本实践 《shell》&#xff1a;shell学习 《terraform》持续更新中&#xff1a;terraform_Aws学习零基础入门到最佳实战 《k8…

【GBase 8c V5_3.0.0 分布式数据库常用几个SQL】

1.检查应用连接数 以管理员用户 gbase&#xff0c;登录数据库主节点。 接数据库&#xff0c;并执行如下 SQL 语句查看连接数。 SELECT count(*) FROM (SELECT pg_stat_get_backend_idset() AS backendid) AS s;2.查看空闲连接 查看空闲(state 字段为”idle”)且长时间没有更…

【linux-Day3】linux下的基本指令

【linux-Day3】linux下的基本指令 linux下的基本指令&#x1f4e2;man&#xff1a;访问linux手册页&#x1f4e2;echo&#xff1a;把字符串写入指定文件中&#x1f4e2;cat&#xff1a;查看目标文件的内容&#x1f4e2;cp&#xff1a;复制文件或目录&#x1f4e2;mv&#xff1a…

【【通信协议ARP的verilog实现】】

【【通信协议ARP的verilog实现】】 eth_arp_test.v module eth_arp_test(input sys_clk , //系统时钟input sys_rst_n , //系统复位信号&#xff0c;低电平有效input touch_key , //触摸按键,用于触发开发…

【JVM】判断对象能否回收的两种方法:引用计数算法,可达性分析算法

1、引用计数算法&#xff1a; 给对象添加一个引用计数器&#xff0c;当该对象被其它对象引用时计数加一&#xff0c;引用失效时计数减一&#xff0c;计数为0时&#xff0c;可以回收。 特点&#xff1a;占用了一些额外的内存空间来进行计数&#xff0c;原理简单&#xff0c;判…