PyTorch 提示和技巧:从张量到神经网络

news2025/1/10 20:39:11

在这里插入图片描述

张量和梯度

我们将深入探讨使用 PyTorch 构建自己的神经网络必须了解的 2 个基本概念:张量和梯度。

张量

张量是 PyTorch 中的中央数据单元。它们是类似于数组的数据结构,在功能和属性方面与 Numpy 数组非常相似。它们之间最重要的区别是 PyTorch 张量可以在 GPU 的设备上运行以加速计算。
在这里插入图片描述

# 使用Tensor对象创建了一个 3x3 形状的未初始化张量。
import torch
tensor_uninitialized = torch.Tensor(3, 3)
tensor_uninitialized
"""
tensor([[1.7676e-35, 0.0000e+00, 3.9236e-44],
        [0.0000e+00,        nan, 0.0000e+00],
        [1.3733e-14, 1.2102e+25, 1.6992e-07]])
"""
# 我们还可以创建用零、一或随机值填充的张量。
tensor_rand = torch.rand(3, 3)
tensor_rand
"""
tensor([[0.6398, 0.3471, 0.6329],
        [0.4517, 0.2253, 0.8022],
        [0.9537, 0.1698, 0.5718]])
"""

就像 Numpy 数组一样,PyTorch 允许我们在张量之间执行数学运算,同样的 Numpy 数组中的其他常见操作,如索引和切片,也可以使用 PyTorch 中的张量来实现。

# 数学运算
x = torch.Tensor([[1, 2, 3],
                  [4, 5, 6]])
tensor_add = torch.add(x, x)
"""
tensor([[ 2.,  4.,  6.],
        [ 8., 10., 12.]])
"""

梯度📉

假设有 2 个参数 a 和 b ,梯度是一个参数相对于另一个参数的偏导数。导数告诉你当你稍微改变其他一些量时,给定量会发生多少变化。在神经网络中,梯度是损失函数相对于模型权重的偏导数。我们只想找到带来损失函数梯度最低的权重。
在这里插入图片描述

PyTorch 使用torch库中的Autograd包来跟踪张量上的操作。

# 01. 默认情况下,张量没有关联的梯度。
tensor= torch.Tensor([[1, 2, 3],
                      [4, 5, 6]])
tensor.requires_grad
"""
False
"""
# 02. 可以通过调用requires_grad_函数在张量上启用跟踪历史记录。
tensor.requires_grad_()
"""
tensor([[1., 2., 3.],
        [4., 5., 6.]], requires_grad=True)
"""
# 03. 但是目前该 Tensor 还没有梯度
print(tensor.grad)
"""
None
"""
# 04. 现在,让我们创建一个等于前一个张量中元素均值的新张量,以计算张量相对于新张量的梯度。
mean_tensor = tensor.mean()
mean_tensor
"""
tensor(3.5000, grad_fn=<MeanBackward0>)
"""
# 05. 要计算梯度,我们需要显式执行调用backward()函数的反向传播。
mean_tensor.backward()
print(tensor.grad)
"""
tensor([[0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667]])
"""

使用 PyTorch 的神经网络

我们可以将神经网络定义为扩展 torch.nn.Module 类的 Python 类。在这个类中,我们必须定义 2 个基本方法:

init()是类的构造函数。在这里,我们必须定义构成我们网络的层。forward()是我们定义网络结构以及各层连接方式的地方。这个函数接受一个输入,代表模型将被训练的特征。我将向你展示如何构建可用于分类问题的简单卷积神经网络并在 MNIST 数据集上训练它。
在这里插入图片描述

首先,我们必须导入torch和我们需要的所有模块。可以创建我们的模型了

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

# CNN 由 2 个卷积层组成,后面是一个全局平均池化层。最后,我们有 2 个全连接层和一个softmax来获得最终的输出概率。

class My_CNN(nn.Module):
   def __init__(self):
       super(My_CNN, self).__init__()
       self.conv1 = nn.Conv2d(1, 64, kernel_size=(3, 3), padding=1)
       self.conv2 = nn.Conv2d(64, 64, kernel_size=(3, 3), padding=1)
       self.avg_pool = nn.AvgPool2d(28)
       self.fc1 = nn.Linear(64, 64)
       self.fc2 = nn.Linear(64, 10)
   def forward(self, x):
       x = F.relu(self.conv1(x))
       x = F.relu(self.conv2(x))
       x = self.avg_pool(x)
       x = x.view(-1, 64)
       x = F.relu(self.fc1(x))
       x = self.fc2(x)
       x = F.softmax(x)
       
       return x

其次,加载数据集,直接从 PyTorch 检索 MNIST 数据集,并使用 PyTorch 实用程序将数据集拆分为训练集和验证集。

from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
mnist = MNIST("data", download=True, train=True)
## create training and validation split
split = int(0.8 * len(mnist))
index_list = list(range(len(mnist)))
train_idx, valid_idx = index_list[:split], index_list[split:]
## create sampler objects using SubsetRandomSampler
train = SubsetRandomSampler(train_idx)
valid = SubsetRandomSampler(valid_idx)

# 使用DataLoader创建迭代器对象,它提供了使用多处理 worker 并行批处理、随机播放和加载数据的能力。
train_loader = DataLoader(mnist, batch_size=256, sampler=train)
valid_loader = DataLoader(mnist, batch_size=256, sampler=valid)

现在我们拥有了开始训练模型的所有要素。然后再定义损失函数和优化器,Adam将用作优化器,交叉熵用作损失函数。

model = My_CNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_function = nn.CrossEntropyLoss()

最后开始训练,所有 PyTorch 训练循环都将经过每个 epoch 和每个DataPoint(在训练DataLoader 对象中)。


epochs = 10
for epoch in range(epochs):
  train_loss, valid_loss = [], []
  for data, target in train_loader:
    
    # forward propagation  
    outputs = model(data)
    # loss calculation
    loss = loss_function(outputs, target)
    # backward propagation
    optimizer.zero_grad()
    loss.backward()
    # weights optimization
    optimizer.step()
    train_loss.append(loss.item())
  for data, target in valid_loader:
    outputs = model(data)
    loss = los_function(outputs, target)
    valid_loss.append(loss.item())
  print('Epoch: {}, training loss: {}, validation loss: {}'
        .format(epoch, np.mean(train_loss), np.mean(valid_loss)))

在验证阶段,我们必须像在训练阶段所做的那样循环验证集中的数据。不同之处在于我们不需要对梯度进行反向传播。


with torch.no_grad():
  correct = 0
  total = 0
  for data, target in valid_loader:
    outputs = model(images)
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
print('Validation set Accuracy: {} %'.format(100 * correct / total))

就是这样!现在你已准备好构建自己的神经网络。你可以尝试通过增加模型复杂性向网络添加更多层来获得更好的性能。

请关注博主,一起玩转人工智能及深度学习。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/611054.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Hadoop中HDFS概述

Hadoop概述之HDFS HDFS架构概述优缺点HDFS架构HDFS文件块大小HDFS的shell命令HDFS读写流程写数据流程 HDFS读数据流程NameNode 和 SecondaryNameNode工作机制DataNode工作机制DataNode数据完整性如何保证 端口名称Hadoop2.xHadoop3.xNameNode内部通信端口8020/9000NameNode HTT…

【STM32单片机】基于语音识别的智能分类垃圾桶,ld3320语音识别模块如何使用,mp3播放模块如何使用

文章目录 需求语音识别模块MY1690 播放模块舵机源码 需求 对于“可回收物”“有害垃圾”“厨余垃圾”“其它垃圾”&#xff0c;不能分清扔到哪个垃圾桶怎么办&#xff1f; 基于语音识别的智能分类垃圾桶&#xff0c;识别到关键词就打开对应的垃圾桶&#xff0c;完全没有分不清…

echarts中彻底清除所有实例和相关数据

单个实例 dispose销毁实例&#xff0c;销毁后实例无法再被使用。 myChart.dispose();实例比较多的时候 获取Dom元素 let doms document.getElementsByClassName(my-chart)销毁所有实例 if(doms && doms.length) {for (let i 0; i < doms.length; i) {let chartIn…

Jumpserver 2.28.8使用分享

目录 一、Jumpserver 介绍 1、跳板机和堡垒机理解 1.1、跳板机 1.2、堡垒机 2、jumpserver简介 二、Jumpserver 安装部署 2.1、部署规划 2.2 、安装要求 JumpServer 环境要求: 2.3、安装方法介绍 官方提供了多种安装方法 三、Jumpserver平台使用 3.1、Admin登录 3.…

bug(Tomcat):StandardContext.startInternal 由于之前的错误,Context[/day01]启动失败

引出 项目启动失败&#xff0c;一个困扰了一上午的bug 报错信息 org.apache.catalina.core.StandardContext.startInternal 一个或多个筛选器启动失败。完整的详细信息将在相应的容器日志文件中找到 org.apache.catalina.core.StandardContext.startInternal 由于之前的错误…

骨传导是哪个意思,推荐几款性能优的骨传导耳机

​骨传导耳机是通过头部骨迷路传递声音&#xff0c;而不是直接通过耳膜的振动来传递声音。与传统的入耳式耳机相比&#xff0c;骨传导耳机不会堵耳朵&#xff0c;在跑步、骑车等运动时可以更好的接收外界环境音&#xff0c;保护听力&#xff0c;提升安全性。此外&#xff0c;骨…

图解LeetCode——114. 二叉树展开为链表

一、题目 给你二叉树的根结点 root &#xff0c;请你将它展开为一个单链表&#xff1a; 展开后的单链表应该同样使用 TreeNode &#xff0c;其中 right 子指针指向链表中下一个结点&#xff0c;而左子指针始终为 null 。 展开后的单链表应该与二叉树 先序遍历 顺序相同。 二…

公式+ChatGPT:为你的标题创作注入新鲜活力

大家是不是经常遇到文章已经写好了&#xff0c;但是标题却还空着&#xff0c;不是不会写&#xff0c;就是写得平淡无奇&#x1f602;。自己都觉得无趣的标题又怎么能吸引有趣的灵魂呢&#xff1f;何不让chatGPT来试试呢&#xff1f; 首先&#xff0c;我们要明白一个基础理念&am…

微软 AD 已成过去式,这个身份领域国产化替代方案你了解吗?

随着全球互联网和数字化浪潮的不断发展&#xff0c;信息安全已成为不可忽视的问题&#xff0c;并随着日益复杂的国内外市场格局&#xff0c;其重要性更加凸显。我国政府也相继印发和实施了《数字中国建设整体布局规划》、《全国一体化大数据体系建设指南》等一系列政策&#xf…

【JavaEE】Servlet的API详解

Servlet的API详解O(∩_∩)O~&#xff1a; 文章目录 JavaEE & Servlet的API详解1. HttpServlet抽象类1.1 init方法1.2 destroy方法1.3 service方法 2. HttpRequest接口2.1 在浏览器上显示请求首行2.2 在浏览器上显示请求header2.3 getParameter方法 - 最常用的API之一2.4 js…

Stable Didffusion 学习笔记经验总结

值的概念 在Stable Diffusion中&#xff0c;有很多要设置的参数&#xff0c;这些参数起到的作用非常重要&#xff0c;直接决定了出图的各种样子和质量&#xff0c;经过实践&#xff0c;我大概搞明白他们遵循的规律&#xff0c;因为程序员是要与AI对话的&#xff0c;所以所谓的…

【CMake 入门与进阶(3)】 CMakeLists.txt 语法规则基础及部分常用指令(附使用代码)

在上两篇中&#xff0c;笔者通过几个简单地示例向大家演示了 cmake 的使用方法&#xff0c;由此可知&#xff0c;cmake 的使用方法其实还是非常简单的&#xff0c;重点在于编写 CMakeLists.txt&#xff0c;CMakeLists.txt 的语法规则也简单&#xff0c;并没有 Makefile 的语法规…

操作系统复习2.3.4-进程同步问题

生产者-消费者 系统中有一组生产者进程和一组消费者进程 两者共享一个初始为空&#xff0c;大小为n的缓冲区 缓冲区没满&#xff0c;生产者才能放入 缓冲区没空&#xff0c;消费者才能取出 互斥地访问缓冲区 互斥要在同步之后&#xff0c;不然会导致想要同步&#xff0c;但由…

39从零开始学Java之面向对象的继承到底是怎么回事?

作者&#xff1a;孙玉昌&#xff0c;昵称【一一哥】&#xff0c;另外【壹壹哥】也是我哦 千锋教育高级教研员、CSDN博客专家、万粉博主、阿里云专家博主、掘金优质作者 前言 在上一篇文章中&#xff0c;壹哥给大家讲解了面向对象三大特征之一的封装&#xff0c;现在我们还有另…

JWT strings must contain exactly 2 period characters. Found: 0

登录接口异常报错&#xff1a; 这是登录接口报错&#xff0c;实际上他不走登录接口&#xff0c;直接走的拦截器&#xff0c;拦截器应配置好了登录接口的放行&#xff0c;登录接口写的也没有问题&#xff0c;拦截器解析也没有问题&#xff0c;因为之前都是好用的&#xff0c;本…

人车网租赁软件开发|人车网租赁系统|租赁系统源码功能

经过租赁小程序不只可以使物品得到充沛的运用&#xff0c;还能减少一些资源的浪费&#xff0c;租赁行业这两年因为互联网技术的完善&#xff0c;发展也在不断进步&#xff0c;租赁系统定制开发功能也在不断完善&#xff0c;那么企业想要开发租赁小程序的时分需求留意哪些方面呢…

深入了解Java虚拟机之高效并发

目录 Java内存模型与线程 概述 硬件的效率与一致性 Java内存模型 主内存与工作内存 内存间交互操作 对于volatile型变量的特殊规则 原子性、可见性与有序性 先行发生原则 Java与线程 线程实现 线程调度 状态切换 小结 线程安全与锁优化 概述 线程安全 Java中…

HDR显示技术

什么是HDR? HDR&#xff08;High-Dynamic Range&#xff0c;简称HDR&#xff09;是指高动态范围图像&#xff0c;是一种能够显示更大的亮度范围和对比度的图像技术。HDR可以让暗部的细节变亮&#xff0c;亮部的细节不失真&#xff0c;呈现出更自然、更真实的画面&#xff0c;…

记一次618军演压测TPS上不去排查及优化 | 京东云技术团队

本文内容主要介绍&#xff0c;618医药供应链质量组一次军演压测发现的问题及排查优化过程。旨在给大家借鉴参考。 背景 本次军演压测背景是&#xff0c;2B业务线及多个业务侧共同和B中台联合军演。 现象 当压测商品卡片接口的时候&#xff0c;cpu达到10%&#xff0c;TPS只有…

Tomcat基本原理

1.Tomcat核心&#xff1a; Http服务器Servlet容器 组件分工&#xff1a; 连接器Connector&#xff1a;处理 Socket 连接&#xff0c;负责网络字节流与 Request 和 Response 对象的转化。容器Container&#xff1a;加载和管理 Servlet&#xff0c;以及具体处理 Request 请求。 …