Python 机器学习求解 PDE 学习项目 基础知识(4)PyTorch 库函数使用详细案例

news2024/9/22 20:32:03

PyTorch 库函数使用详细案例

在这里插入图片描述

前言

在深度学习中,PyTorch 是一个广泛使用的开源机器学习库。它提供了强大的功能,用于构建、训练和评估深度学习模型。本文档将详细介绍如何使用以下 PyTorch 相关库函数,并提供相应的案例示例:

  • torch
  • torch.nn.functional
  • torch.optim.lr_scheduler
    这些库函数的使用将成为后续我们使用 机器学习求解 PDE 的基础。

1. torch

示例:张量操作

import torch

# 创建张量
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([4.0, 5.0, 6.0])

# 张量加法
z = x + y
print(z)  # 输出: tensor([5., 7., 9.])

# 张量乘法
z = x * y
print(z)  # 输出: tensor([ 4., 10., 18.])

# 张量的加法和乘法的其他操作
z = torch.add(x, y)
print(z)  # 输出: tensor([5., 7., 9.])
z = torch.mul(x, y)
print(z)  # 输出: tensor([ 4., 10., 18.])

2. torch.nn.functional(简称 F)

torch.nn.functional(通常简写为torch.nn.f或简单地称为F)是PyTorch中一个非常重要的模块,它包含了构建神经网络所需的大部分激活函数、损失函数、归一化层等函数式接口。这些函数不保留任何内部状态,即它们是无状态的,每次调用时都会接收输入并返回输出,而不会保存任何关于之前输入或输出的信息。这使得torch.nn.functional中的函数非常适合用于定义前向传播逻辑,同时也使得模型定义更加灵活和清晰。

主要功能分类

  1. 激活函数:如ReLU、Sigmoid、Tanh等,用于在神经网络层之间添加非线性。
  2. 损失函数:如MSELoss、CrossEntropyLoss等,用于计算预测值和真实值之间的差异。
  3. 归一化函数:如BatchNorm、LayerNorm等,用于对输入数据进行归一化处理,加速训练过程并提升模型性能。
  4. 卷积和池化操作:如conv2d、max_pool2d等,用于图像等数据的特征提取。
  5. 其他操作:如dropout、padding、embedding等,提供了丰富的网络构建工具。
示例:激活函数和损失函数
import torch
import torch.nn.functional as F

# 创建张量
x = torch.tensor([-1.0, 0.0, 1.0])

# ReLU 激活函数
relu_x = F.relu(x)
print(relu_x)  # 输出: tensor([0., 0., 1.])

# Sigmoid 激活函数
sigmoid_x = torch.sigmoid(x)
print(sigmoid_x)  # 输出: tensor([0.2689, 0.5000, 0.7311])

# 计算均方误差损失
target = torch.tensor([0.0, 1.0, 1.0])
loss = F.mse_loss(sigmoid_x, target)
print(loss)  # 输出: tensor(0.2201)

使用torch.nn.functional中的ReLU激活函数和CrossEntropyLoss损失函数:
import torch  
import torch.nn.functional as F  
  
# 假设我们有以下简单的模型参数(通常这些参数会由torch.nn.Module的子类管理)  
# 假设输入图像大小为1x28x28(1个通道,28x28像素)  
# 第一个全连接层将784(28*28)个输入转换为128个输出  
weight1 = torch.randn(784, 128)  
bias1 = torch.zeros(128)  
# 第二个全连接层将128个输入转换为10个输出(对应10个类别)  
weight2 = torch.randn(128, 10)  
bias2 = torch.zeros(10)  
  
# 模拟一个批次的数据(假设批次大小为1,即一张图像)  
# 这里我们随机生成一个1x28x28的图像,并展平为1x784  
x = torch.randn(1, 1, 28, 28)  # [batch_size, channels, height, width]  
x = x.view(1, -1)  # 展平为 [batch_size, 784]  
  
# 前向传播  
# 第一层全连接 + ReLU激活  
h1 = x.mm(weight1) + bias1  # [batch_size, 128]  
h1 = F.relu(h1)  
  
# 第二层全连接  
output = h1.mm(weight2) + bias2  # [batch_size, 10]  
  
# 假设真实标签是3(即手写数字3)  
label = torch.tensor([3], dtype=torch.long)  
  
# 计算损失  
loss = F.cross_entropy(output, label)  
  
print(f'Loss: {loss.item()}')

注意事项

  • 在实际使用中,通常会通过继承torch.nn.Module来构建和管理网络参数,因为这样可以更方便地利用PyTorch提供的自动求导、模型保存/加载等功能。
  • torch.nn.functional中的函数通常与torch.nn模块中的层(Layer)相对应,但函数式接口更加灵活,适合用于快速原型设计或简单网络构建。
  • 在进行模型训练时,通常会使用torch.optim中的优化器来更新模型参数,而torch.nn.functional中的函数则用于定义前向传播逻辑和计算损失。

3. torch.optim.lr_scheduler

PyTorch 学习率调度器详细案例

背景

在训练深度学习模型时,学习率的设置和调整对模型的训练效果和速度有着重要的影响。PyTorch 提供了多种学习率调度器,可以在训练过程中动态调整学习率。下面将详细解释如何使用 StepLRMultiStepLR 学习率调度器,并演示它们的使用。

示例代码

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR, MultiStepLR

# 创建一个简单的模型
model = torch.nn.Linear(10, 1)

# 创建优化器
optimizer = SGD(model.parameters(), lr=0.1)

# 创建学习率调度器
scheduler_step = StepLR(optimizer, step_size=10, gamma=0.1)
scheduler_multistep = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)

# 模拟训练过程
for epoch in range(100):
    optimizer.step()  # 更新模型参数
    scheduler_step.step()  # 更新学习率
    scheduler_multistep.step()  # 更新学习率
    print(f"Epoch {epoch}: StepLR LR={scheduler_step.get_last_lr()}, MultiStepLR LR={scheduler_multistep.get_last_lr()}")

解释:

  • StepLR
    StepLR 是一种按固定步数调整学习率的调度器。
    step_size=10 表示每 10 个 epoch 调整一次学习率。
    gamma=0.1 表示每次调整时,将学习率乘以 0.1.
  • MultiStepLR
    MultiStepLR 是一种在指定的 epoch 列表中调整学习率的调度器。
    milestones=[30, 80] 表示在第 30 和第 80 个 epoch 时调整学习率。
    gamma=0.1 表示在这些 epoch 调整时,将学习率乘以 0.1.

请添加图片描述


本专栏致力于普及各种偏微分方程的不同数值求解方法,所有文章包含全部可运行代码。欢迎大家支持、关注!

作者 :计算小屋
个人主页 : 计算小屋的主页

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1996994.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

我在高职教STM32——I2C通信之读写EEPROM(1)

大家好,我是老耿,高职青椒一枚,一直从事单片机、嵌入式、物联网等课程的教学。对于高职的学生层次,同行应该都懂的,老师在课堂上教学几乎是没什么成就感的。正是如此,才有了借助CSDN平台寻求认同感和成就感的想法。在这里,我准备陆续把自己花了很多心思设计的教学课件分…

日撸Java三百行(day18:循环队列)

目录 一、顺序队列与循环队列 二、代码实现 1.循环队列创建 2.循环队列遍历 3.循环队列入队 4.循环队列出队 5.数据测试 6.完整的程序代码 总结 一、顺序队列与循环队列 在昨天,我们提到队列实现除了采用链式存储结构,还可以采用顺序存储结构&…

数字电路波形图绘制工具WaveDrom简介

最近写东西的时候,需要画波形图,无意中找到了一蛮好用的工具:WaveDrom WaveDrom 是一个 JavaScript 应用程序。WaveJSON 是一种描述数字时序图的格式。WaveDrom 直接在浏览器内部渲染这些时序图。元素 “signal” 是一个 WaveLane 数组。每个…

NO.4 软件外包公司

今天我们来聊聊国内的四大软件外包公司。这些公司不仅在国内市场中占据重要地位,还对全球软件外包行业产生了影响。 部分数据来源网络排名,按照职位量、增长速度排名,排名仅供参考,去某家公司一定要多方位参考,比如企查…

uniapp基础知识【搬代码】

基础知识 HTML、css、javaScript&#xff08;ES6&#xff09; HTML结构 1.View 类似于传统html中的div&#xff0c;用于包裹各种元素内容。 2.text文本 3.swoper 4.image 5.video 6.button 7.input <template><!-- <view class"content"><imag…

泛微E-office 10 schema_mysql接口敏感信息泄露漏洞复现 [附POC]

文章目录 泛微E-office 10 schema_mysql接口敏感信息泄露漏洞复现 [附POC]0x01 前言0x02 漏洞描述0x03 影响版本0x04 漏洞环境0x05 漏洞复现1.访问漏洞环境2.构造POC3.复现泛微E-office 10 schema_mysql接口敏感信息泄露漏洞复现 [附POC] 0x01 前言 免责声明:请勿利用文章内…

两篇论文同时获最佳论文荣誉提名,SIGGRAPH上首个Real-Time Live的中国团队用生成式AI创建3D世界

专注于计算机图形学的全球学术顶会 SIGGRAPH&#xff0c;正在出现新的趋势。 点击访问我的技术博客https://ai.weoknow.comhttps://ai.weoknow.com 在上周举行的 SIGGRAPH 2024 大会上&#xff0c;最佳论文等奖项中&#xff0c;来自上海科技大学 MARS 实验室的团队同时拿到两篇…

HTML表单元素

HTML表单元素 表单把用户的信息发给服务器。 <!DOCTYPE html> <html><head><meta charset"utf-8"><title></title> </head><body><form class"stylin_form1" action"process_form.php" met…

uni-app开发微信小程序注意事项,不要用element-ui

前端扩展组件千万不要用element-ui&#xff0c;开发的时候不报错&#xff0c;发布的时候会报错无法发布。 可以用vant weapp【注意是weapp】 iView weapp 附上hbuilder官方文档 组件的概念 | uni-app官网 (dcloud.net.cn)

git-贮藏区打补丁

1.显示所有贮藏 git stash list 2.将贮藏区的修改打补丁 git stash show -p stash{0} > patchName.patch commit打补丁 git 生成补丁文件及打补丁_git 生成指定目录补丁-CSDN博客 git patch的使用方法_git pattch-CSDN博客

「MyBatis」数据库相关操作

MyBatis 简介 MyBatis 是⼀个持久层框架&#xff0c;用于简化 JDBC 的开发 持久层指的就是持久化操作的层&#xff0c;通常指数据访问层 (dao)&#xff0c;是用来操作数据库的 Mapper 注解的接口表示该接口是 MyBatis 中的 Mapper 接口 回顾一下之前提到过的图 简单来说&…

如何选用合适的开源知识管理系统?10款软件推荐

国内外主流的10款开源知识管理软件对比&#xff1a;PingCode、Worktile、DokuWiki、MediaWiki、GitBook、Nuclino、Think、TiddlyWiki、AFFiNE、Foam。 在管理知识的广阔天地中&#xff0c;选择合适的工具可能会让你感到头痛。开源知识管理软件以其灵活性和成本效益在行业内脱颖…

Java设计模式-单例模式最佳实践

1. 单例模式简介 Java 单例模式是四大设计模式之一&#xff0c;属于创建型设计模式。从定义上看&#xff0c;它似乎是一种简单的设计模式&#xff0c;但在实现时&#xff0c;如若不注意&#xff0c;它会带来很多问题。 在本文中&#xff0c;我们将了解单例设计模式原则&#…

使用 GPU 加速的 XGBoost 预测出租车费用

目录 XGBoost GPU 加速的 XGBoost 用例数据集示例 将文件中的数据加载到 DataFrame 定义特征数组 保存模型 总结 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家&#xff0c; 可以当故事来看&#xf…

小智纯前端js报表实战5-绝对坐标-横向扩展

绝对坐标-横向扩展 概述 绝对坐标-横向扩展&#xff1a;绝对坐标定位 层次坐标是实现复杂报表的一个重要功能。 在进行小智报表模板设计时&#xff0c;单元格尚未进行扩展&#xff0c;但是有些时候需要获取扩展后的单元格并进行计算。例如&#xff0c;A1单元格扩展成A1-D1&am…

VUE+Spring前后台传值的坑,后台接收的String参数在末尾会出现 “=”

一、问题 VUESpringBoot做增删改查时&#xff0c;前端使用axios.post发起请求&#xff0c;传输主键字符型参数 taskId 到后台&#xff0c;后台再进行删除处理。 实际过程中发现后台拿到的数据再末尾多了一个等号&#xff0c;但是通过console.log(taskId)前台打印参数是正常的…

巴洛克风格的现代演绎,戴上亚法银耳机,感受古典雕花与现代声学的碰撞

flipears耳机品牌以其独特的风格、精细的配置和卓越的音质在耳机市场中很受欢迎&#xff0c;像是我最近用过的一款Artha Argentum亚法银&#xff0c;就采用了纯银外壳&#xff0c;而且用料扎实&#xff0c;具有出众的声学表现&#xff0c;带来了更干净清澈的声底。内在配置方面…

[Linux] LVM挂载的硬盘重启就掉的问题解决

问题&#xff1a;系统重启后挂在逻辑卷的盘会掉&#xff08;必现&#xff09; 环境&#xff1a;SUSE Linux 11 SP4 原因&#xff1a;boot.lvm是关闭的 解决&#xff1a;boot.lvm设置开启 参考资料&#xff1a; linux下lvm状态Not avaliable问题排查及处理(常见Suse操作系统…

使用ubuntu串口数据收和发不一致问题

串口配置 使用virtual Serial Port Driver Pro模拟串口两个串口&#xff0c;com2和com3&#xff0c;使用默认配置&#xff1b;通过virtual box 串口映射功能&#xff0c;在Ubuntu里使用CuteCom打开com2接受和发送数据&#xff0c;在windows里使用com3发送和接收数据。 遇到问…

24/8/9算法笔记 随机森林

"极限森林"&#xff08;Extremely Randomized Trees&#xff0c;简称ERT&#xff09;是一种集成学习方法&#xff0c;它属于决策树的变体&#xff0c;通常被归类为随机森林&#xff08;Random Forest&#xff09;的一种。极限森林的核心思想是在构建决策树时引入极端…