PyTorch初探:基本函数与案例实践

news2025/1/11 14:17:12

正文

        在熟悉了PyTorch的安装和环境配置后,接下来让我们深入了解PyTorch的基本函数,并通过一个简单的案例来实践这些知识。

1. 基本函数

        PyTorch的核心是张量(Tensor),它类似于多维数组,但可以在GPU上运行以加速计算。张量上的操作是构建神经网络层的基础。

 


    以下是PyTorch中一些常用的张量操作函数:

  • torch.tensor(): 创建一个新的张量。
  • torch.ones()torch.zeros(): 创建全1或全0的张量。
  • torch.randn(): 创建一个具有随机数的张量,这些随机数服从均值为0和标准差为1的正态分布(标准正态分布)。
  • torch.matmul(): 执行矩阵乘法。
  • torch.sum(): 计算张量中所有元素的和。


         此外,PyTorch还提供了自动求导机制,这是训练神经网络的关键。通过设置张量的requires_grad属性为True,PyTorch会跟踪对该张量执行的所有操作,以便后续计算梯度。

2. 案例实践:线性回归

        为了演示PyTorch的基本用法,我们将实现一个简单的线性回归模型。线性回归是一种预测模型,其中输出是输入的线性组合。

步骤如下:

  • 导入必要的库

import torch  
import torch.nn as nn  
import torch.optim as optim
  • 准备数据

这里我们使用简单的人工数据来演示。

# 输入数据  
x_data = torch.tensor([[1.0], [2.0], [3.0]])  
# 输出数据  
y_data = torch.tensor([[2.0], [4.0], [6.0]])
  • 定义模型

线性回归模型可以表示为y = wx + b,其中w是权重,b是偏置。

class LinearRegressionModel(nn.Module):  
    def __init__(self):  
        super(LinearRegressionModel, self).__init__()  
        self.linear = nn.Linear(1, 1)  # 输入和输出都是1维的  
  
    def forward(self, x):  
        y_pred = self.linear(x)  
        return y_pred  
  
model = LinearRegressionModel()
  • 定义损失函数和优化器

criterion = nn.MSELoss()  # 均方误差损失  
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降
  • 训练模型

# 训练周期  
epochs = 100  
  
for epoch in range(epochs):  
    # 前向传播  
    outputs = model(x_data)  
    loss = criterion(outputs, y_data)  
  
    # 反向传播和优化  
    optimizer.zero_grad()  
    loss.backward()  
    optimizer.step()  
  
    # 打印损失  
    if (epoch+1) % 10 == 0:  
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item()}')  
  
# 测试模型  
with torch.no_grad():  
    prediction = model(torch.tensor([[4.0]]))  
    print(f'Prediction after training: {4.0} => {prediction.item()}')

  • 总结       

  • 在这个简单的案例中,我们展示了如何使用PyTorch构建、训练和测试一个基本的线性回归模型。通过这个过程,你应该对PyTorch的基本函数和工作流程有了更深刻的理解在实际应用中,你可能会处理更复杂的模型和数据集,但基本的原理和操作是相似的。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1414774.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

野火霸道V2学习笔记

野火霸道V2学习笔记 STM32F103学习笔记说明基础配置配置KeilMDK配置串口下载程序美化Keil界面配置VScode 理论知识STM32命名方式例子 置位与清零GPIOGPIO简介GPIO和引脚的区别引脚的分类 GPIO 框图讲解保护二极管推挽输出开漏输出补充: 高阻态与悬空复用功能输出输入模式GPIO框…

SkiaSharp:.NET强大而灵活的跨平台图形库

在.Net 6之前,我们一般是使用System.Drawing.Common来生成图像。 但在.Net 6平台需要配置,才能在非Windows平台使用System.Drawing.Common。而从.Net 7开始,非Windows不再允许使用,官方也给我们推荐了几个替代库。 今天我们一起来…

兄弟HL-1208黑白激光打印机清零方法

兄弟HL-1208黑白激光打印机基本参数: 产品类型:黑白激光打印机(定位类型家用) 最大打印幅面:A4 最高分辨率:600600dpi 黑白打印速度:20ppm 内存标配:1MB,最大&#…

代码随想录算法训练营29期|day31 任务以及具体安排

理论基础 关于贪心算法,你该了解这些! 题目分类大纲如下: #算法公开课 《代码随想录》算法视频公开课 (opens new window):贪心算法理论基础! (opens new window),相信结合视频再看本篇题解,更有助于大家…

廖雪峰Python教程实战Day 2 - 编写Web App骨架,运行后不显示网页如何解决

教程代码如下&#xff1a; import logging; logging.basicConfig(levellogging.INFO)import asyncio, os, json, time from datetime import datetimefrom aiohttp import webdef index(request):return web.Response(bodyb<h1>Awesome</h1>)asyncio.coroutine de…

快速幂算法详解

目录 介绍 原理1 实现过程 原理2 取余运算 介绍 快速幂算法的目的就是让计算机很快地求出&#xff0c;暴力相乘的话&#xff0c;电脑要计算b次。用快速幂&#xff0c;计算次数在级别&#xff0c;很实用。 原理1 (1)如果将a自乘一次&#xff0c;就会变成。再把自乘一次就…

VR拍摄+制作

1.VR制作需要的图片宽高是2:1&#xff0c;需要360✖️180的图片&#xff0c;拍摄设备主要有两种&#xff1a; 1&#xff09;通过鱼眼相机拍摄&#xff0c;拍摄一组图片&#xff0c;然后通过PTGui来合成(拍摄复杂) 2&#xff09;全景相机&#xff0c;一键拍摄直接就能合成需要的…

【智能家居】6、语音控制及网络控制代码实现

一、语音控制 1、指令结构体编写 这个结构体定义了一个命令输入的模型。在这个模型中,包含以下几个部分: cmdName:一个长度为128的字符串,用于存储命令名称。dvicesName:一个长度为128的字符串,用于存储设备名称。cmd:一个长度为32的字符串,用于存储具体的命令。Init:…

南卡Neo2评测:实力诠释骨传导耳机全能旗舰,细节展现匠心之作

前段时间朋友让我帮他寻找一款佩戴舒适、音质体验好的蓝牙耳机&#xff0c;因为比较忙所以一直把这件事搁置了&#xff0c;刚好这两天比较闲&#xff0c;所以也是在综合个人的经验和目前较为热门的一些品牌款式&#xff0c;决定帮他寻找一款骨传导耳机&#xff0c;因为骨传导耳…

手机视频压缩怎么压缩?一键瘦身~

现在手机已经成为我们日常生活中必不可少的工具&#xff0c;而在手机的应用领域中&#xff0c;文件的传输和存储是一个非常重要的问题。很多用户都会遇到这样一个问题&#xff0c;那就是在手机上存储的文件太多太大&#xff0c;导致手机存储空间不足&#xff0c;那么怎么在手机…

docker-compose Install influxdb1+influxdb2+telegraf

influxd2前言 influxd2 是 InfluxDB 2.x 版本的后台进程,是一个开源的时序数据库平台,用于存储、查询和可视化时间序列数据。它提供了一个强大的查询语言和 API,可以快速而轻松地处理大量的高性能时序数据。 telegraf 是一个开源的代理程序,它可以收集、处理和传输各种不…

[Python] 如何在Windows下安装图形可视化工具graphviz

什么是graphviz? Graphviz是一款开源的图形可视化工具&#xff0c;用于生成各种结构化数据的图形表示。它支持多种图形排列算法&#xff0c;可以将复杂的数据关系用图形的方式直观地展示出来。Graphviz广泛应用于软件工程、数据可视化、计算机网络以及其他领域的可视化分析中…

SpringBoot之时间数据前端显示格式化

背景 在实际我们通常需要在前端显示对数据操作的时间或者最近的更新时间&#xff0c;如果我们只是简单的使用 LocalDateTime.now()来传入数据不进行任何处理那么我们就会得到非常难看的数据 解决方式&#xff1a; 1). 方式一 在属性上加上注解&#xff0c;对日期进行格式…

【JavaEE进阶】 #{}和${}

文章目录 &#x1f343;前言&#x1f333;#{}和${}使⽤&#x1f6a9;Interger类型的参数&#xff08;基础数据类型&#xff09;&#x1f388;使用#{}&#x1f388;使用${} &#x1f6a9;String类型的参数使用&#x1f388;#{}使用&#x1f388;${} &#x1f38d;#{}和${}区别&a…

中国新能源汽车持续跑出发展“加速度”,比亚迪迎来向上突破

2023年已经过去&#xff0c;对于汽车圈而言&#xff0c;2023年是中国车市的分水岭&#xff0c;在这一年&#xff0c;中国汽车工业70年以来首次进入全球序列&#xff0c;自主品牌强势霸榜&#xff0c;销量首次超过合资车。要知道&#xff0c;这是自大众于1984年进入中国市场成立…

【PLC 网络通信及 MODBUS TCP通信测试】

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、PLC的通信方式有&#xff1f;二、PLC网络通信1.MODBUS TCP 三、安装虚拟软件2.测试2.1 测试程序2.2 对接虚拟软件 总结 前言 提示&#xff1a;这里可以添加…

2024年人工智能产业十大发展趋势

2024年人工智能产业十大发展趋势 技术变革1. 多模态预训练大模型将是人工智能产业的标配2. 高质量数据愈发稀缺将倒逼数据智能飞跃3. 智能算力无处不在的计算新范式加速实现 应用创新4. 人工智能生成内容&#xff08;AIGC&#xff09;应用向全场景渗透5. 人工智能驱动科学研究&…

第8章 异常

第8章 异常 学习目标 能够辨别程序中异常和错误 说出异常的分类 说出虚拟机处理异常的方式 列出常见的5个运行时异常 列出常见的5个编译时异常 能够使用try…catch关键字处理异常 能够使用throw抛出异常对象 能够使用throws关键字处理异常 能够自定义异常类 能够处理自定义异常…

【精选推荐】3款强大的API渗透测试工具

1免责声明 请勿利用文章内的相关技术从事非法测试&#xff0c;由于传播、利用此文所提供的信息而造成的任何直接或者间接的后果及损失&#xff0c;均由使用者本人负责&#xff0c;作者不为此承担任何责任。工具来自网络&#xff0c;安全性自测。 2前言 给大家介绍三款优秀的…

单片机学习笔记---独立按键控制LED显示二进制

这节我们来实现独立按键的第三个功能&#xff0c;独立按键控制LED显示二进制 新创建一个工程文件&#xff0c;然后上来我们就要把基本框架写好&#xff0c;这是基本的习惯 老规矩&#xff0c;然后把Delay 1ms的代码复制过来 复制过来后改造一下&#xff1a; 把1ms删掉&#x…