PyTorch 基础学习(5)- 神经网络

news2025/1/23 13:11:37

系列文章:
PyTorch 基础学习(1) - 快速入门
PyTorch 基础学习(2)- 张量 Tensors
PyTorch 基础学习(3) - 张量的数学操作
PyTorch 基础学习(4)- 张量的类型
PyTorch 基础学习(5)- 神经网络

介绍

PyTorch 提供了一套强大的工具来构建和训练神经网络。其中的核心组件之一是 torch.nn,它提供了模块和类以帮助您创建和定制神经网络。

参数和模块

torch.nn.Parameter

  • torch.nn.Parameter() 是一种特殊的 Variable,常用于模块参数。
  • Parameter 被赋值给模块的属性时,它会自动添加到模块的参数列表中,成为模型可学习的参数。
  • VariableParameter 的区别:
    • Parameter 不能是 volatile,并且默认 requires_grad=True,而 Variable 默认 requires_grad=False

torch.nn.Module

  • 所有神经网络模块的基类。
  • 您的模型应继承此类。
  • 模块可以包含其他模块,形成树形结构。将子模块赋值为属性会自动注册它们。
示例
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

模块方法

  • add_module(name, module): 向当前模块添加子模块。
  • children(): 返回当前模块的子模块迭代器。
  • modules(): 返回网络中所有模块的迭代器,包括自身和所有子模块。

移动模块

  • cpu(): 将模块参数和缓冲区移动到 CPU。
  • cuda(device_id=None): 将模块参数和缓冲区移动到 GPU。
  • double(): 将参数和缓冲区的数据类型转换为 double
  • float(): 将参数和缓冲区的数据类型转换为 float
  • half(): 将参数和缓冲区的数据类型转换为 half

评估和训练模式

  • eval(): 将模块设置为评估模式,影响诸如 Dropout 和 BatchNorm 等模块。
  • train(mode=True): 将模块设置为训练模式。

保存和加载模型

  • load_state_dict(state_dict): 从状态字典中加载参数和缓冲区。
  • state_dict(): 返回包含模块状态的字典。

线性层

torch.nn.Linear

  • 对输入数据进行线性变换:( y = Ax + b )。
示例
import torch.nn as nn
m = nn.Linear(20, 30)

卷积层

torch.nn.Conv2d

  • 进行 2D 卷积操作。
示例
import torch.nn as nn
m = nn.Conv2d(16, 33, 3, stride=2)

池化层

torch.nn.MaxPool2d

  • 进行 2D 最大池化操作。
示例
import torch.nn as nn
m = nn.MaxPool2d(3, stride=2)

torch.nn.AvgPool2d

  • 进行 2D 平均池化操作。
示例
import torch.nn as nn
m = nn.AvgPool2d(3, stride=2)

激活函数

常用激活函数

  • ReLU: 修正线性单元, R e L U ( x ) = m a x ( 0 , x ) ReLU(x)=max(0,x) ReLU(x)=max(0,x)
  • Sigmoid: S i g m o i d ( x ) = 1 / 1 + e − x Sigmoid(x)=1/1 + e^{-x} Sigmoid(x)=1/1+ex
  • Tanh: 双曲正切函数, t a n h ( x ) tanh(x) tanh(x)
示例
import torch.nn as nn
m = nn.ReLU()

循环神经网络层

循环神经网络(RNN)是一类用于处理序列数据的神经网络。PyTorch 提供了多种循环层,包括 RNNLSTMGRU,用于构建复杂的序列模型。下面我们详细介绍这些循环层及其使用方法。

torch.nn.RNN

torch.nn.RNN 实现了多层 Elman RNN,适用于输入序列的处理。它通过循环连接来保持序列中每个时间步的信息。可以选择使用 tanhrelu 作为激活函数。

示例
import torch
import torch.nn as nn
from torch.autograd import Variable

# 创建一个 RNN 层,输入维度为 10,隐状态维度为 20,使用两层堆叠
rnn = nn.RNN(input_size=10, hidden_size=20, num_layers=2)

# 输入数据,形状为 (序列长度, 批量大小, 特征维度)
input = Variable(torch.randn(5, 3, 10))

# 初始隐状态,形状为 (层数, 批量大小, 隐状态维度)
h0 = Variable(torch.randn(2, 3, 20))

# 前向传播,计算输出和新的隐状态
output, hn = rnn(input, h0)

# 输出是最后一层的输出,hn 是最后一个时间步的隐状态

torch.nn.LSTM

torch.nn.LSTM 实现了长短时记忆网络(LSTM),用于处理更复杂的序列模式,特别是长序列。LSTM 使用门控机制(包括输入门、遗忘门和输出门)来控制信息的流动,从而有效地捕捉序列中的长期依赖关系。

示例
import torch
import torch.nn as nn
from torch.autograd import Variable

# 创建一个 LSTM 层,输入维度为 10,隐状态和细胞状态维度为 20,使用两层堆叠
lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)

# 输入数据,形状为 (序列长度, 批量大小, 特征维度)
input = Variable(torch.randn(5, 3, 10))

# 初始隐状态和细胞状态,形状为 (层数, 批量大小, 隐状态维度)
h0 = Variable(torch.randn(2, 3, 20))
c0 = Variable(torch.randn(2, 3, 20))

# 前向传播,计算输出、最后的隐状态和细胞状态
output, (hn, cn) = lstm(input, (h0, c0))

# 输出是最后一层的输出,hn 和 cn 分别是最后一个时间步的隐状态和细胞状态

torch.nn.GRU

torch.nn.GRU 实现了门控循环单元(GRU)网络,是一种比 LSTM 更简单的结构,常用于处理序列数据。GRU 通过合并输入门和遗忘门,简化了门控机制,同时保持了捕捉长期依赖的能力。

示例
import torch
import torch.nn as nn
from torch.autograd import Variable

# 创建一个 GRU 层,输入维度为 10,隐状态维度为 20,使用两层堆叠
gru = nn.GRU(input_size=10, hidden_size=20, num_layers=2)

# 输入数据,形状为 (序列长度, 批量大小, 特征维度)
input = Variable(torch.randn(5, 3, 10))

# 初始隐状态,形状为 (层数, 批量大小, 隐状态维度)
h0 = Variable(torch.randn(2, 3, 20))

# 前向传播,计算输出和新的隐状态
output, hn = gru(input, h0)

# 输出是最后一层的输出,hn 是最后一个时间步的隐状态

以上这些循环层可以用于处理序列数据,如时间序列预测、自然语言处理等。选择合适的循环层和参数设置可以帮助您构建出性能优异的序列模型。

Dropout 层

torch.nn.Dropout

  • 随机将输入张量中的部分元素置零。
示例
import torch.nn as nn
m = nn.Dropout(p=0.5)

损失函数

常用损失函数

  • L1Loss: 平均绝对误差损失。
  • MSELoss: 均方误差损失。
  • CrossEntropyLoss: 将 LogSoftMax 和 NLLLoss 集成在一个类中。
示例
import torch.nn as nn
criterion = nn.MSELoss()

工具

torch.nn.utils.clip_grad_norm

  • 裁剪参数梯度的范数。

torch.nn.utils.rnn

  • 用于处理变长序列的 RNN 的函数。
序列的打包和填充
  • **pack_padded_sequence

应用实例:多项式回归

以下是一个使用 PyTorch 构建和训练循环神经网络(RNN)进行简单时间序列预测的完整示例。该脚本展示了如何使用 LSTM 层来处理序列数据,包括数据准备、模型定义、训练和评估。

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.autograd import Variable
from sklearn.preprocessing import MinMaxScaler

# 生成示例数据:一个正弦波
# 设置随机种子以确保可重复性
np.random.seed(0)
torch.manual_seed(0)

# 生成一个正弦波序列
def generate_data(seq_length=50, num_samples=1000):
    x = np.linspace(0, 100, num_samples)
    y = np.sin(x) + 0.1 * np.random.randn(num_samples)  # 添加一些噪声
    return y

# 数据预处理:将数据归一化到 [0, 1] 区间,并构造序列样本
def create_dataset(data, seq_length):
    scaler = MinMaxScaler(feature_range=(0, 1))
    data_normalized = scaler.fit_transform(data.reshape(-1, 1)).flatten()

    sequences = []
    targets = []

    for i in range(len(data_normalized) - seq_length):
        sequences.append(data_normalized[i:i+seq_length])
        targets.append(data_normalized[i+seq_length])

    return np.array(sequences), np.array(targets), scaler

# 定义 LSTM 模型
class LSTMModel(nn.Module):
    def __init__(self, input_size=1, hidden_size=50, num_layers=1):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # 初始化隐藏状态和细胞状态
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).requires_grad_()
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).requires_grad_()

        # 前向传播 LSTM
        out, _ = self.lstm(x, (h0.detach(), c0.detach()))

        # 从最后一个时间步提取输出
        out = self.fc(out[:, -1, :])
        return out

# 参数设置
seq_length = 50
num_samples = 1000
batch_size = 16
num_epochs = 200
learning_rate = 0.01

# 生成和处理数据
data = generate_data(seq_length, num_samples)
sequences, targets, scaler = create_dataset(data, seq_length)

# 转换为 PyTorch 的张量格式
sequences = torch.from_numpy(sequences).float().unsqueeze(2)  # (样本数, 序列长度, 特征数)
targets = torch.from_numpy(targets).float().unsqueeze(1)  # (样本数, 1)

# 构造数据集和数据加载器
dataset = torch.utils.data.TensorDataset(sequences, targets)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 创建模型、定义损失函数和优化器
model = LSTMModel(input_size=1, hidden_size=50, num_layers=1)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# 训练模型
for epoch in range(num_epochs):
    for batch_seqs, batch_targets in dataloader:
        # 前向传播
        outputs = model(batch_seqs)
        loss = criterion(outputs, batch_targets)

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if (epoch+1) % 20 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# 评估模型
model.eval()
with torch.no_grad():
    # 使用训练数据进行预测
    train_pred = model(sequences).detach().numpy()
    train_pred_rescaled = scaler.inverse_transform(train_pred)

    # 原始数据逆归一化
    targets_rescaled = scaler.inverse_transform(targets.numpy())

# 绘制结果
plt.figure(figsize=(10, 6))
plt.plot(data, label='Original Data')
plt.plot(range(seq_length, seq_length + len(train_pred_rescaled)), train_pred_rescaled, label='LSTM Prediction')
plt.xlabel('Time')
plt.ylabel('Value')
plt.legend()
plt.show()

输出结果:
在这里插入图片描述

代码说明

  1. 生成数据:

    • 生成一个正弦波,并添加噪声以模拟真实数据。
    • 使用 np.linspace 创建一个线性间隔的数组来表示时间。
  2. 数据预处理:

    • 使用 MinMaxScaler 将数据归一化到 [0, 1] 区间,以帮助模型更快地收敛。
    • 将数据转换为固定长度的序列样本,每个样本的长度为 seq_length
  3. LSTM 模型定义:

    • 定义 LSTMModel 类,继承自 nn.Module
    • 使用 LSTM 层和全连接层来实现序列到序列的映射。
  4. 训练过程:

    • 使用 MSELoss 作为损失函数,Adam 作为优化器。
    • 在每个 epoch 内,迭代数据加载器进行批次训练,并更新模型参数。
  5. 评估和可视化:

    • 在训练结束后,用训练数据进行预测,并将结果与原始数据对比。
    • 使用 matplotlib 绘制原始数据和预测结果。

该示例展示了如何使用 PyTorch 实现基本的时间序列预测任务,您可以根据需要对数据和模型进行调整以适应不同的应用场景,如:股票预测。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2036479.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

<数据集>路面坑洼识别数据集<目标检测>

数据集格式:VOCYOLO格式 图片数量:681张 标注数量(xml文件个数):681 标注数量(txt文件个数):681 标注类别数:1 标注类别名称:[pothole] 使用标注工具:labelImg 标注规则:对类…

15.3 模型评估与调优

欢迎来到我的博客,很高兴能够在这里和您见面!欢迎订阅相关专栏: 工💗重💗hao💗:野老杂谈 ⭐️ 全网最全IT互联网公司面试宝典:收集整理全网各大IT互联网公司技术、项目、HR面试真题.…

给SQL server数据库表字段添加注释SQL,附修改、删除注释SQL及演示

目录 一. 前提小知识(数据库连接,数据库,SCHEMA,Table的关系) 二. 添加备注 2.1 添加备注基本语法(sys.sp_addextendedproperty) 2.2 SQL演示 2.3 fn_listextendedproperty函数查询备注个数 2.4 开发常用添加注释语法 三. 修改备注 3…

深入理解 PHP 高性能框架 Workerman 守护进程原理

大家好,我是码农先森。 守护进程顾名思义就是能够在后台一直运行的进程,不会霸占用户的会话终端,脱离了终端的控制。相信朋友们对这东西都不陌生了吧?如果连这个概念都还不能理解的话,建议回炉重造多看看 Linux 进程管…

C++:vector类(default关键字,迭代器失效)

目录 前言 成员变量结构 iterator定义 size capacity empty clear swap []运算符重载 push_back pop_back reserve resize 构造函数 默认构造函数 default 迭代器构造 拷贝构造函数 赋值重载函数 析构函数 insert erase 迭代器失效问题 insert失效 er…

Linux使用学习笔记3 系统运维监控基础

系统运维监控类命令 查询每个进程的线程数 for pid in $(ps -ef | grep -v grep|grep "systemd" |awk {print $2});do echo ${pid} > /tmp/a.txt;cat /proc/${pid}/status|grep Threads > /tmp/b.txt;paste /tmp/a.txt /tmp/b.txt;done|sort -k3 -rn for pid…

mfc100u.dll丢失问题分析,详细讲解mfc100u.dll丢失解决方法

面对mfc100u.dll文件丢失带来的挑战时,许多用户都可能感到有些无助,尤其是当这一问题影响到他们日常使用的软件时。但实际上,存在几种有效方法可以帮助您快速恢复该关键的系统文件。为了方便不同水平的用户,本文将详细解析各种处理…

自动化测试工具Selenium IDE

简介 Selenium IDE 是实现Web自动化的一种便捷工具,本质上它是一种浏览器插件。该插件支持Chrome和Firefox浏览器,拥有录制、编写及回放操作等功能,能够快速实现Web的自动化测试。 使用场景 1、Selenium IDE本身的定位并不是用于复杂的自动…

Ps:首选项 - 技术预览

Ps菜单:编辑/首选项 Edit/Preferences 快捷键:Ctrl K Photoshop 首选项中的“技术预览” Technology Previews选项卡允许用户启用或禁用一些实验性功能,以测试或使用 Adobe 提供的最新技术。 技术预览 Technology Previews 启用保留细节 2.0…

如何解决浏览器页面过曝,泛白等问题

问题描述,分别对应edge和chrome浏览器这是什么原因?

使用C#禁止Windows系统插入U盘(除鼠标键盘以外的USB设备)

试用网上成品的禁用U盘的相关软件,发现使用固态硬盘改装的U盘以及手机等设备,无法被禁止,无奈下,自己使用C#手搓了一个。 基本逻辑: 开机自启;启动时,修改注册表,禁止系统插入USB存…

字符串函数!!!(续)(C语言)

一. strtok函数的使用 继续上次的学习,今天我们来认识一个新的函数strtok,它的原型是char* strtok(char* str,const char* sep),sep参数指向了一个字符串,定义了用作分隔符的字符合集,第一个参数指定⼀个字符串&#…

DataStream API的Joining操作

目录 Window Join Tumbling Window Join Sliding Window Join Session Window Join Interval Join Window CoGroup Window Join 窗口连接(window join)将两个流的元素连接在一起,这两个流共享一个公共键,并且位于同一窗口。…

从老旧到智慧病房,全视通智慧病房方案减轻医护工作负担

传统的老旧病房面临着诸多挑战。例如,患者风险信息的管理仍依赖于手写记录和人工核查;病房呼叫系统仅支持基本的点对点呼叫,缺乏与其它系统的集成能力;信息传递主要依靠医护人员口头传达;护理信息管理分散且不连贯&…

JavaEE-多线程

前情提要:本文内容过多,建议搭配目录食用,想看哪里点哪里~ PC端目录 手机端目录 话不多说,我们正式开始~~ 目录 多线程的概念进程和线程的区别和联系:使用Java代码进行多线程编程Thread类中的方法和属性线程的核心操作1. 启动…

【mamba学习】(一)SSM原理与说明

mamba输入输出实现与transformer几乎完全一样的功能,但速度和内存占用具有很大优势。对比transformer,transformer存在记忆有限的情况,如果输入或者预测的序列过长可能导致爆炸(非线性),而mamba不存在这种情…

物理网卡MAC修改器v3.0-直接修改网卡内部硬件MAC地址,重装系统不变!

直接在操作系统里就能修改网卡硬件mac地址,刷新网卡mac序列号硬件码机器码,电脑主板集成网卡,pcie网卡,usb有线网卡,usb无线网卡,英特尔网卡,瑞昱网卡全支持! 一键修改mac&#xff0…

百问网全志系列开发板音频ALSA配置步骤详解

8 ALSA 8.1 音频相关概念 ​ 音频信号是一种连续变化的模拟信号,但计算机只能处理和记录二进制的数字信号,由自然音源得到的音频信号必须经过一定的变换,成为数字音频信号之后,才能送到计算机中作进一步的处理。 ​ 数字音频系…

MySQL本地Windows安装

下载MySQL 官网:MySQL 下载完成后文件 安装类型 选择功能 功能过滤筛选,系统为64位操作系统,所以选【64-bit】,32位可选【32.bit】 下方勾选后自动检查安装环境,如果提示缺少运行库,请先安装VC_redist.x64。…

【Dash】plotly.express 气泡图

一、More about Visualization The Dash Core Compnents module dash.dcc includes a componenet called Graph. Graph renders interactive data visualizations using the open source plotly.js javaScript graphing library.Plotly.js supports over 35 chart types and …