pytorch学习day3

news2025/4/17 14:27:39

一、模型创建(Module)

网络创建流程

上面的图表展示了使用PyTorch创建神经网络模型的主要步骤。每个步骤按顺序连接,展示了从导入必要的库到最终测试模型的整个流程:

  1. 导入必要的库:首先导入PyTorch及其相关模块。
  2. 定义网络结构:通过继承 nn.Module 类定义神经网络的层和前向传播过程。
  3. 实例化模型:使用定义的结构实例化模型对象。
  4. 定义损失函数和优化器:选择并定义损失函数和优化器。
  5. 准备数据:加载并预处理数据,创建数据加载器。
  6. 训练模型:通过训练循环进行前向传播、计算误差和反向传播更新权重。
  7. 测试模型:在测试数据上评估模型的性能。

模型构建的两个要素

在PyTorch中,构建神经网络模型的关键在于两个要素:构建子模块拼接子模块。这两个要素分别在模型类的 __init__() 方法和 forward() 方法中实现。

1. 构建子模块

在自定义模型中,通过继承 nn.Module 类,并在 __init__() 方法中定义子模块。这些子模块通常是神经网络的各层,例如卷积层、全连接层、激活函数等。

示例:

import torch.nn as nn

class CustomModel(nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()
        # 定义子模块
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(in_features=16*16*16, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

在上面的代码中,我们定义了一个卷积层 conv1,一个池化层 pool,以及三个全连接层 fc1fc2fc3。这些子模块是模型的基本组成部分。

2. 拼接子模块

forward() 方法中定义子模块的拼接方式。forward() 方法描述了输入数据如何经过这些子模块的传递过程,最终输出结果。

示例:

class CustomModel(nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(in_features=16*16*16, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # 拼接卷积层和池化层
        x = x.view(-1, 16*16*16)              # 展平张量
        x = F.relu(self.fc1(x))               # 拼接第一个全连接层
        x = F.relu(self.fc2(x))               # 拼接第二个全连接层
        x = self.fc3(x)                       # 拼接第三个全连接层
        return x

forward() 方法中,我们定义了输入数据的传递路径。数据首先通过卷积层 conv1 和池化层 pool,然后展平为一维张量,依次通过三个全连接层 fc1fc2fc3,最后输出结果。

通过这两个步骤,我们可以构建出一个功能齐全的神经网络模型。以下是流程图,帮助理解这两个要素在模型构建中的位置和作用。

当然,以下是关于模型构建的两个要素的表格,可以直接复制使用:

| 模型构建的两个要素 |                  描述                  |
|--------------------|----------------------------------------|
| 构建子模块         | 在自定义模型(继承 nn.Module)的 `__init__()` 方法中定义各个层(卷积层、池化层、全连接层等)|
| 拼接子模块         | 在自定义模型的 `forward()` 方法中定义层的连接方式,描述前向传播过程              |

这个表格简明地展示了模型构建的两个关键步骤和它们分别在哪个方法中实现。希望这能帮助你更好地理解和使用PyTorch进行模型构建。

通过以上两个步骤,我们可以灵活地定义各种复杂的神经网络模型,并通过 forward() 方法灵活地组合这些子模块,实现数据的前向传播过程。

二、nn.Mudule的属性

nn.Module 是 PyTorch 中所有神经网络模块的基类。它提供了一些关键属性和方法,用于构建和管理神经网络模型。以下是 nn.Module 的一些重要属性和方法:

1. parameters()

  • 描述:返回模型所有参数的迭代器。
  • 用途:通常用于优化器来获取模型参数进行训练。
for param in model.parameters():
    print(param.size())

2. named_parameters()

  • 描述:返回一个包含模型参数名字和参数本身的迭代器。
  • 用途:当你需要获取特定层的参数时特别有用。
for param in model.parameters():
    print(param.size())

3. children()

  • 描述:返回模型所有子模块的迭代器。
  • 用途:用于递归遍历模型的各个子模块。
for child in model.children():
    print(child)

4. named_children()

  • 描述:返回一个包含模型子模块名字和子模块本身的迭代器。
  • 用途:用于详细查看每个子模块。
for name, child in model.named_children():
    print(name, child)

5. modules()

  • 描述:返回模型所有模块(包括模型本身和其子模块)的迭代器。
  • 用途:用于遍历所有模块。
for module in model.modules():
    print(module)

6. named_modules()

  • 描述:返回一个包含模型模块名字和模块本身的迭代器。
  • 用途:当你需要以层级结构查看所有模块时使用。
  • for name, module in model.named_modules():
        print(name, module)

7. add_module(name, module)

  • 描述:将一个子模块添加到当前模块。
  • 用途:动态地添加子模块。
model.add_module('extra_layer', nn.Linear(10, 10))

8. forward()

  • 描述:定义前向传播逻辑。用户需要在自己的子类中重载这个方法。
  • 用途:定义输入数据如何通过网络层进行传递。
def forward(self, x):
    x = self.layer1(x)
    x = self.layer2(x)
    return x

9. train(mode=True)

  • 描述:将模块设置为训练模式。
  • 用途:启用或禁用 Dropout 和 BatchNorm。
model.train()  # 设置为训练模式
model.eval()   # 设置为评估模式

10. zero_grad()

  • 描述:将所有模型参数的梯度清零。
  • 用途:在每次反向传播前清除旧的梯度。
model.zero_grad()

这些属性和方法提供了强大的功能,使得 nn.Module 能够灵活且高效地管理神经网络模型。通过这些接口,你可以构建、管理和训练复杂的神经网络。

三、模型容器Containers

模型容器(Containers)

在 PyTorch 中,模型容器(Containers)是用于组织和管理神经网络层的一种方式。通过使用模型容器,可以更方便地构建和管理复杂的神经网络结构。以下是 PyTorch 中常用的几种模型容器:

1. nn.Sequential

描述:

nn.Sequential 是一个按顺序执行子模块的容器。它将子模块按定义顺序串联起来,适合用于简单的前向传播模型。

用途:

用于快速构建按顺序堆叠的网络结构,例如多层感知机(MLP)和简单的卷积神经网络(CNN)。

示例:

import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(1, 20, 5),
    nn.ReLU(),
    nn.Conv2d(20, 64, 5),
    nn.ReLU()
)

在这个例子中,输入数据依次通过两个卷积层和两个 ReLU 激活函数。

2. nn.ModuleList

描述:

nn.ModuleList 是一个存储子模块的有序列表,但并没有定义前向传播的具体顺序。它主要用于需要灵活前向传播定义的模型。

用途:

适用于需要在前向传播过程中动态选择层或者有条件执行层的情况。

示例:

import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.layers = nn.ModuleList([nn.Conv2d(1, 20, 5), nn.Conv2d(20, 64, 5)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

在这个例子中,layers 存储了两个卷积层,并在 forward 方法中以循环的方式应用它们。

3. nn.ModuleDict

描述:

nn.ModuleDict 是一个存储子模块的字典,可以使用键来访问子模块。它提供了灵活的模块管理方式,可以通过键值对的方式存取模块。

用途:

适用于需要命名访问子模块,且不需要严格的前向传播顺序的情况,例如多分支的模型结构。

示例:

import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.layers = nn.ModuleDict({
            'conv1': nn.Conv2d(1, 20, 5),
            'conv2': nn.Conv2d(20, 64, 5)
        })

    def forward(self, x):
        x = self.layers['conv1'](x)
        x = self.layers['conv2'](x)
        return x

在这个例子中,layers 存储了两个卷积层,可以通过键名 'conv1''conv2' 进行访问。

4. nn.ParameterListnn.ParameterDict

描述:

这两个容器分别用于存储参数列表和参数字典,与 ModuleListModuleDict 类似,但它们存储的是参数而不是模块。

用途:

适用于需要灵活管理模型参数的情况。

示例:

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(3)])
        self.param_dict = nn.ParameterDict({
            'param1': nn.Parameter(torch.randn(10, 10)),
            'param2': nn.Parameter(torch.randn(10, 10))
        })

    def forward(self, x):
        # 使用 self.params 和 self.param_dict 进行前向传播
        pass

在这个例子中,params 存储了三个参数,而 param_dict 则存储了两个命名参数。

4 总结

通过使用这些模型容器,PyTorch 提供了灵活且高效的方式来组织和管理神经网络模型的层和参数。nn.Sequential 适用于简单的顺序结构,nn.ModuleListnn.ModuleDict 提供了更多的灵活性,适用于更复杂的网络结构。nn.ParameterListnn.ParameterDict 则用于更灵活的参数管理。利用这些容器,可以更方便地构建和管理复杂的神经网络模型。

5 实现一个简单VGG网络

创建一个简单的VGG网络

VGG网络是一种深度卷积神经网络,因其简单且具有良好的性能而广泛应用。下面我们利用PyTorch提供的模型容器,构建一个简化版的VGG网络。我们将主要使用nn.Sequential来按顺序堆叠卷积层和全连接层。

相关论文地址:https://arxiv.org/abs/1409.1556

1. 导入必要的库

import torch
import torch.nn as nn
import torch.nn.functional as F

2. 定义VGG块

VGG块由多个卷积层和一个池化层组成。我们定义一个函数来创建这些块。

def vgg_block(num_convs, in_channels, out_channels):
    layers = []
    for _ in range(num_convs):
        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        layers.append(nn.ReLU(inplace=True))
        in_channels = out_channels
    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*layers)

3. 定义VGG网络

我们利用nn.Sequential来堆叠多个VGG块,最后添加全连接层。

class SimpleVGG(nn.Module):
    def __init__(self):
        super(SimpleVGG, self).__init__()
        self.features = nn.Sequential(
            vgg_block(2, 3, 64),
            vgg_block(2, 64, 128),
            vgg_block(3, 128, 256),
            vgg_block(3, 256, 512),
            vgg_block(3, 512, 512)
        )
        self.classifier = nn.Sequential(
            nn.Linear(512*7*7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

4. 实例化和测试模型

我们创建模型实例并打印其结构,确保其正确性。

model = SimpleVGG()
print(model)

5. 测试模型结构

为了确保模型构建正确,我们可以打印模型结构或者传递一个随机张量进行测试。

if __name__ == "__main__":
    model = SimpleVGG()
    print(model)
    # 测试输入数据
    input_tensor = torch.randn(1, 3, 224, 224)
    output = model(input_tensor)
    print(output.shape)  # 应输出 torch.Size([1, 10])

通过这些步骤,我们利用PyTorch提供的模型容器创建了一个简化版的VGG网络。这个网络由五个VGG块和三个全连接层组成,适用于图像分类任务。根据需求可以进一步调整网络结构和参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1719273.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

六一见!|Post Microsoft Build and AI Day 上海开发者日

编辑/排版:Alan Wang 大小朋友明天见! 6月1日,Microsoft Azure & Microsoft Reactor 面向大小朋友特别推出六一特辑,「Post Microsoft Build and AI Day 上海开发者日」 探讨 Microsoft Build 2024 带来的最新发布&#xff0…

民国漫画杂志《时代漫画》第36期.PDF

时代漫画36.PDF: https://url03.ctfile.com/f/1779803-1248636233-8a4a9d?p9586 (访问密码: 9586) 《时代漫画》的杂志在1934年诞生了,截止1937年6月战争来临被迫停刊共发行了39期。 ps: 资源来源网络!

整合框架(spring...) 统一异常处理

1、 我们想让异常结果也显示为统一的返回结果对象,并且统一处理系统的异常信息,那么需要统一异常处理。 附加:创建封装错误状态码和错误消息VO 代码如下: Result import io.swagger.v3.oas.annotations.media.Schema; impo…

探索 Android Studio 中的 Gemini:加速 Android 开发的新助力

探索 Android Studio 中的 Gemini:加速 Android 开发的新助力 在 Gemini 时代的下一篇章中,Gemini融入了更多产品中,Android Studio 正在使用 Gemini 1.0 Pro 模型,使 Android 开发变得更快、更简单。 Studio Bot 现已更名为 And…

【Text2SQL 论文】DIN-SQL:分解任务 + 自我纠正 + in-context 让 LLM 完成 Text2SQL

论文:DIN-SQL: Decomposed In-Context Learning of Text-to-SQL with Self-Correction ⭐⭐⭐⭐ NeurIPS 2023, arXiv:2304.11015 Code: Few-shot-NL2SQL-with-prompting | GitHub 文章目录 一、论文速读1.1 Schema Linking Module1.2 Classification & Decompo…

浅谈线性化

浅谈线性化 原文:浅谈线性化 - 知乎 (zhihu.com) All comments and opinions expressed on Zhihu are mine alone and do not necessarily reflect those of my employers, past or present. 本文内容所有内容仅代表本人观点,和Mathworks无关 (这里所说…

AIGC智能办公实战 课程,祝你事业新高度

在数字化时代,人工智能(AI)已经渗透到我们生活的方方面面,从智能家居到自动驾驶,从医疗诊断到金融分析,AI助手正在改变我们的工作方式和生活质量。那么,你是否想过自己也能从零开始,…

服务器怎么被远程桌面连接不上,远程桌面连接不上服务器的问题有效解决方案

远程桌面连接不上服务器是一个极其严重的问题,它可能直接影响到我们的工作效率、数据安全,甚至是整个业务运营的顺畅。因此,这个问题必须得到迅速且有效的解决。 当我们尝试远程桌面连接服务器时,可能会遇到连接不上的情况。这其中…

thinkphp6 自定义的查询构造器类

前景需求&#xff1a;在查询的 时候我们经常会有一些通用的&#xff0c;查询条件&#xff0c;但是又不想每次都填写一遍条件&#xff0c;这个时候就需要重写查询类&#xff08;Query&#xff09; 我目前使用的thinkphp版本是6.1 首先自定义CustomQuery类继承于Query <?p…

Ubuntu server 24 (Linux) Snort3 3.2.1.0 Guardian IPtables 联动实战 主动防御系统(ids+ips)

一 Snort3 安装配置&#xff0c;参考:Ubuntu server 24 安装配置 snort3 3.2.1.0 网络入侵检测防御系统 配置注册规则集-CSDN博客 二 安装主动防御程序Guardian 1 下载&#xff0c;解压 tar zxvf guardian-1.7.tar.gz cd guardian-1.7/ 2 配置 #拷贝文件 sudo cp guard…

【文献阅读】基于模型设计的汽车软件质量属性

参考文献&#xff1a;《基于模型设计满足汽车软件质量和快速交付的挑战》&#xff0c;深向科技在2024年MATLAB XEPO大会的演讲 Tips&#xff1a;KISS原则&#xff0c;全称为“Keep It Simple, Stupid”&#xff0c;直译为“保持简单&#xff0c;愚蠢的人也能懂”

使用 EBS 和构建数据库服务器并使用应用程序与数据库交互

实验 4&#xff1a;使用 EBS 实验概览 本实验着重介绍 Amazon Elastic Block Store (Amazon EBS)&#xff0c;这是一种适用于 Amazon EC2 实例的重要底层存储机制。在本实验中&#xff0c;您将学习如何创建 Amazon EBS 卷、将其附加到实例、向卷应用文件系统&#xff0c;然后进…

师彼长技以助己(2)产品思维

师彼长技以助己&#xff08;2&#xff09;产品思维 前言 我把产品思维称之为&#xff1a;人生底层的能力以及蹉跎别人还蹉跎自己的能力&#xff0c;前者说明你应该具备良好产品思维原因&#xff0c;后者是你没有好的产品思维去做产品带来的灾难。 人欲即天理 请大家谈谈看到这…

JavaWeb ServletContext 对象 应用

ServletContext: ServletContext是Java EE Servlet 定义的一个让 Web 应用中的 Servlet 能够和服务器交流的一个接口&#xff0c;每个应用都有自己的 ServletContext&#xff0c;除了分布式应用中的每个服务器实例的 ServletContext 是独立的&#xff0c;不能用来共享数据外 Se…

vue3组件通信与props

title: vue3组件通信与props date: 2024/5/31 下午9:00:57 updated: 2024/5/31 下午9:00:57 categories: 前端开发 tags: Vue3组件Props详解生命周期数据通信模板语法Composition API单向数据流 Vue 3 组件基础 在 Vue 3 中&#xff0c;组件是构建用户界面的基本单位&#…

最佳 Mac 数据恢复:恢复 Mac 上已删除的文件

尝试过许多 Mac 数据恢复工具&#xff0c;但发现没有一款能达到宣传的效果&#xff1f;我们重点介绍最好的 Mac 数据恢复软件 没有 Mac 用户愿意担心数据丢失&#xff0c;但您永远不知道什么时候会发生这种情况。无论是意外删除 Mac 上的重要文件、不小心弄湿了 Mac、感染病毒…

模型 STORY评估框架

说明&#xff1a;系列文章 分享 模型&#xff0c;了解更多&#x1f449; 模型_思维模型目录。故事五要素&#xff1a;结构、时间、观点、现实、收益 。 1 STORY评估框架的应用 1.1 STORY模型展示其个性化在线学习解决方案的优势 一家在线教育平台想要通过一个故事来展示其个性…

【高校科研前沿】南大王栋、吴吉春教授团队在深度学习助力水库生态调度和优化管理方面取得新进展,成果以博士生邱如健为一作发表于水环境领域国际权威期刊

1.文章简介 论文名称&#xff1a;Integration of deep learning and improved multi-objective algorithm to optimize reservoir operation for balancing human and downstream ecological needs 第一作者及单位&#xff1a;邱如健&#xff08;博士生 南京大学&#xff09;…

在Android Studio中使用谷歌Gemini代码助手

今天在做android开发的时候&#xff0c;一个项目使用到了gradle8.0&#xff0c;但是我的Android Studuio根本不支持&#xff0c;无可奈何只能从小蜜蜂版本升级了水母 | 2023.3.1版本&#xff0c;但突然发现AS已经集成了Gemini助手。 首先我们需要下载这个版本的&#xff1a; h…

【Unity脚本】使用脚本操作游戏对象的组件

【知识链】Unity -> Unity脚本 -> 游戏对象 -> 组件 【知识链】Unity -> Unity界面 -> Inspector【摘要】本文介绍如何使用脚本添加、删除组件&#xff0c;以及如何访问组件 文章目录 引言第一章 游戏对象与组件1.1什么是组件&#xff1f;1.2 场景、游戏对象与组…