3.PyTorch——常用神经网络层

news2025/1/20 5:44:19
import numpy as np
import pandas as pd
import torch as t
from PIL import Image
from torchvision.transforms import ToTensor, ToPILImage

t.__version__
'2.1.1'

3.1 图像相关层

图像相关层主要包括卷积层(Conv)、池化层(Pool)等,这些层在实际使用中可分为一维(1D)、二维(2D)、三维(3D),池化方式又分为平均池化(AvgPool)、最大值池化(MaxPool)、自适应池化(AdaptiveAvgPool)等。而卷积层除了常用的前向卷积之外,还有逆卷积(TransposeConv)。

除了这里的使用,图像的卷积操作还有各种变体,具体可以参照此处动图[^2]介绍。 [^2]: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

to_tensor = ToTensor()
to_pil = ToPILImage()
lena = Image.open('imgs/lena.png')
lena

在这里插入图片描述

# layer对输入形状都有假设:输入的不是单个数据,而是一个batch。
# 这里输入一个数据,就必须调用tensor.unsqueeze(0)增加一个维度,伪装成batch_size=1的batch
input = to_tensor(lena).unsqueeze(0)

# 锐化卷积核
kernel = t.ones(3, 3) / -9
kernel[1][1] = 1
conv = t.nn.Conv2d(1, 1, (3, 3), 1, bias=False)
conv.weight.data = kernel.view(1, 1, 3, 3)

out = conv(input)
to_pil(out.data.squeeze(0))

在这里插入图片描述

池化层:可视为一种特殊的卷积层,用来下采样。注意池化层是没有可学习参数的,其weight是固定的。

pool = t.nn.AvgPool2d(2, 2)
out = pool(input)
to_pil(out.data.squeeze(0))

在这里插入图片描述

除了卷积层和池化层,深度学习中还将常用到以下几个层:

  • Linear:全连接层。
  • BatchNorm:批规范化层,分为1D、2D和3D。除了标准的BatchNorm之外,还有在风格迁移中常用到的InstanceNorm层。
  • Dropout:dropout层,用来防止过拟合,同样分为1D、2D和3D。
    下面通过例子来说明它们的使用。
# 输入batch_size=2, 维度3
input = t.rand(2,3)
linear = t.nn.Linear(3, 4)
h = linear(input)
h
tensor([[-0.2314, -0.2245,  0.0966,  0.7610],
        [-0.2679, -0.2403,  0.0086,  0.5799]], grad_fn=<AddmmBackward0>)
# 4 channel,初始化标准差4,均值0
bn = t.nn.BatchNorm1d(4)
bn.weight.data = t.ones(4) * 4
bn.bias.data = t.zeros(4)

bn_out = bn(h)
print(bn_out)
bn_out.mean(), bn_out.var(0, unbiased=False)       # 由于计算无偏方差分母会减1, 使用unbiased=1分母不减一
tensor([[ 3.9415,  3.7136,  3.9897,  3.9976],
        [-3.9415, -3.7136, -3.9897, -3.9976]],
       grad_fn=<NativeBatchNormBackward0>)





(tensor(-1.8775e-06, grad_fn=<MeanBackward0>),
 tensor([15.5355, 13.7908, 15.9179, 15.9805], grad_fn=<VarBackward0>))
# 每个元素以0.5的概率舍弃
dropout = t.nn.Dropout(0.5)
o = dropout(bn_out)
o           
tensor([[ 7.8830,  7.4272,  0.0000,  7.9951],
        [-7.8830, -7.4272, -7.9794, -7.9951]], grad_fn=<MulBackward0>)

3.2 激活函数

PyTorch实现了常见的激活函数,其具体的接口信息可参见官方文档1,这些激活函数可作为独立的layer使用。这里将介绍最常用的激活函数ReLU,其数学表达式为:
R e L U ( x ) = m a x ( 0 , x ) ReLU(x)=max(0,x) ReLU(x)=max(0,x)

ReLU函数有个inplace参数,如果设为True,它会把输出直接覆盖到输入中,这样可以节省内存/显存。之所以可以覆盖是因为在计算ReLU的反向传播时,只需根据输出就能够推算出反向传播的梯度。但是只有少数的autograd操作支持inplace操作(如tensor.sigmoid_()),除非你明确地知道自己在做什么,否则一般不要使用inplace操作。

relu = t.nn.ReLU(inplace=True)
input = t.randn(2, 3)
print(input)
output = relu(input)
print(output)        # 负数都被截断为0
tensor([[-0.4064, -0.1886,  0.4812],
        [ 0.8996, -0.3606,  0.6127]])
tensor([[0.0000, 0.0000, 0.4812],
        [0.8996, 0.0000, 0.6127]])

对于此类网络如果每次都写复杂的forward函数会有些麻烦,在此就有两种简化方式,ModuleList和Sequential。其中Sequential是一个特殊的module,它包含几个子Module,前向传播时会将输入一层接一层的传递下去。ModuleList也是一个特殊的module,可以包含几个子module,可以像用list一样使用它,但不能直接把输入传给ModuleList。下面举例说明。

# Sequential
net = t.nn.Sequential(
    t.nn.Conv2d(3, 3, 3),
    t.nn.BatchNorm2d(3),
    t.nn.ReLU()
)
print('net:', net)
net: Sequential(
  (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
  (1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
)
# 可根据名字或序号取出module
net[2]
ReLU()
input = t.randn(1, 3, 4, 4)
output = net(input)
output
tensor([[[[1.2239, 0.0000],
          [0.0000, 0.6354]],

         [[0.1855, 0.0000],
          [0.7218, 0.7777]],

         [[1.3686, 0.0000],
          [0.4861, 0.0000]]]], grad_fn=<ReluBackward0>)
# modellist
modellist = t.nn.ModuleList([t.nn.Linear(3, 4), t.nn.ReLU(), t.nn.Linear(4, 2)])
input = t.randn(1, 3)

for model in modellist:
    input = model(input)
    print(input)
tensor([[-0.1817,  0.3852,  1.3656, -0.5643]], grad_fn=<AddmmBackward0>)
tensor([[0.0000, 0.3852, 1.3656, 0.0000]], grad_fn=<ReluBackward0>)
tensor([[-0.0151, -0.0309]], grad_fn=<AddmmBackward0>)

3.3 RNN循环神经网络

关于RNN的基础知识,推荐阅读colah的文章2入门。PyTorch中实现了如今最常用的三种RNN:RNN(vanilla RNN)、LSTM和GRU。此外还有对应的三种RNNCell。

RNN和RNNCell层的区别在于前者一次能够处理整个序列,而后者一次只处理序列中一个时间点的数据,前者封装更完备更易于使用,后者更具灵活性。实际上RNN层的一种后端实现方式就是调用RNNCell来实现的。

t.manual_seed(1000)
# 输入:batch_size=3, 序列长度为2,序列中每个元素占4维
input = t.randn(2, 3, 4)
# lstm输入向量4维,隐藏元3. 1层
lstm = t.nn.LSTM(4, 3, 1)
# 初始状态:1层,batch_size=3, 3个隐藏元
h0 = t.randn(1, 3, 3)
c0 = t.randn(1, 3, 3)
out, hn = lstm(input, (h0, c0))
out
tensor([[[-0.3610, -0.1643,  0.1631],
         [-0.0613, -0.4937, -0.1642],
         [ 0.5080, -0.4175,  0.2502]],

        [[-0.0703, -0.0393, -0.0429],
         [ 0.2085, -0.3005, -0.2686],
         [ 0.1482, -0.4728,  0.1425]]], grad_fn=<MkldnnRnnLayerBackward0>)
t.manual_seed(1000)
input = t.randn(2, 3, 4)
# 一个LSTMCell对应的层数只能是一层
lstm = t.nn.LSTMCell(4, 3)
hx = t.randn(3, 3)
cx = t.randn(3, 3)
out = []
for i_ in input:
    hx, cx=lstm(i_, (hx, cx))
    out.append(hx)
t.stack(out)
tensor([[[-0.3610, -0.1643,  0.1631],
         [-0.0613, -0.4937, -0.1642],
         [ 0.5080, -0.4175,  0.2502]],

        [[-0.0703, -0.0393, -0.0429],
         [ 0.2085, -0.3005, -0.2686],
         [ 0.1482, -0.4728,  0.1425]]], grad_fn=<StackBackward0>)

3.4 损失函数

损失函数可看作是一种特殊的layer,PyTorch也将这些损失函数实现为nn.Module的子类。然而在实际使用中通常将这些loss function专门提取出来,和主模型互相独立。详细的loss使用请参照文档3,这里以分类中最常用的交叉熵损失CrossEntropyloss为例说明。

# batch_size = 3, 计算对应每个类别的分数(只有两个类别)
score = t.randn(3, 2)
# 三个样本分别属于1, 0, 1类,label必须是LongTensor
label = t.Tensor([1, 0, 1]).long()

# loss与普通的layer无差异
criterion = t.nn.CrossEntropyLoss()
loss = criterion(score, label)
loss
tensor(1.8772)

  1. http://pytorch.org/docs/nn.html#non-linear-activations ↩︎

  2. http://colah.github.io/posts/2015-08-Understanding-LSTMs/ ↩︎

  3. http://pytorch.org/docs/nn.html#loss-functions ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1293400.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

翻译: 生成式人工智能的经济潜力 第2部分行业影响 The economic potential of generative AI

麦肯锡报告 翻译: 生成式人工智能的经济潜力 第一部分商业价值 The economic potential of generative AI 1. 行业影响 在我们分析的63个使用案例中&#xff0c;生成式人工智能有潜力在各行各业创造2.6万亿至4.4万亿美元的价值。其确切影响将取决于各种因素&#xff0c;比如…

SpringBoot框架+原生HTML开发,基于云端SaaS服务方式的电子病历编辑器源码

一体化电子病历编辑器源码&#xff0c;电子病历系统 一体化电子病历系统基于云端SaaS服务的方式&#xff0c;采用B/S&#xff08;Browser/Server&#xff09;架构提供&#xff0c;覆盖了医疗机构电子病历模板制作到管理使用的整个流程。除实现在线制作内容丰富、图文并茂、功能…

MySQL主从复制(一主两从)架构搭建(阿里云服务器)

建立主机master 1.建立数据库master docker run --name master --restart always -p 3308:3306 -v /root/docker/volumes/etc/master:/etc/mysql -v /root/docker/volumes/var/lib/master:/var/lib/mysql -e MYSQL_ROOT_PASSWORDriCXT8zM -d mysql:latest 2.复制master的配置文…

讲一下maven的生命周期

Maven是一种强大的项目管理工具&#xff0c;它可以帮助开发者组织和管理项目的构建过程。Maven的生命周期指的是一系列的活动&#xff0c;包括如何创建、准备、构建和测试项目的过程。以下是对Maven生命周期的主要阶段的简要概述&#xff1a; 获取项目&#xff1a;在这个阶段&…

【开源】基于Vue+SpringBoot的智慧家政系统

项目编号&#xff1a; S 063 &#xff0c;文末获取源码。 \color{red}{项目编号&#xff1a;S063&#xff0c;文末获取源码。} 项目编号&#xff1a;S063&#xff0c;文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块三、系统展示四、核心代码4.1 查询家政服…

扩展卡尔曼滤波技术(Extended Kalman Filter,EKF)

一、概念介绍 卡尔曼滤波是一种高效率的递归滤波器(自回归滤波器), 它能够从一系列的不完全包含噪声的测量中&#xff0c;估计动态系统的状态&#xff0c;然而简单的卡尔曼滤波必须应用在符合高斯分布的系统中。 扩展卡尔曼滤波就是为了解决非线性问题&#xff0c;普通卡尔曼…

HashMap系列-resize

1.resize public class HashMap<K,V> extends AbstractMap<K,V>implements Map<K,V>, Cloneable, Serializable {final Node<K,V>[] resize() {Node<K,V>[] oldTab table;int oldCap (oldTab null) ? 0 : oldTab.length; //老的数组容量in…

【开源】基于Vue+SpringBoot的快乐贩卖馆管理系统

项目编号&#xff1a; S 064 &#xff0c;文末获取源码。 \color{red}{项目编号&#xff1a;S064&#xff0c;文末获取源码。} 项目编号&#xff1a;S064&#xff0c;文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 搞笑视频模块2.3 视…

用 PHP和html做一个简单的注册页面

用 PHP和html做一个简单的注册页面 index.html的设计 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title…

手眼标定 - 最终精度和误差优化心得

手眼标定 - 标定误差优化项 一、TCP标定误差优化1、注意标定针摆放范围2、TCP标定时的点次态与工作姿态尽可能保持相近 二、深度相机对齐矩阵误差1、手动计算对齐矩阵 三、拍照姿态1、TCP标定姿态优先2、水平放置棋盘格优先 为减少最终手眼标定的误差&#xff0c;可做或注意以下…

华为数通---配置Smart Link负载分担案例

定义 Smart Link&#xff0c;又叫做备份链路。一个Smart Link由两个接口组成&#xff0c;其中一个接口作为另一个的备份。Smart Link常用于双上行组网&#xff0c;提供可靠高效的备份和快速的切换机制。 目的 下游设备连接到上游设备&#xff0c;当使用单上行方式时&#x…

【dig命令查询方法】

dig&#xff08;Domain Information Groper&#xff09;是一个用于查询DNS&#xff08;域名系统&#xff09;的命令行工具&#xff0c;它可以帮助您获取关于域名的各种信息&#xff0c;如IP地址、MX记录、NS记录等。下面是dig的详细使用教程。 基本语法&#xff1a; dig [ser…

【华为数据之道学习笔记】3-4主数据治理

主数据是参与业务事件的主体或资源&#xff0c;是具有高业务价值的、跨流程和跨系统重复使用的数据。主数据与基础数据有一定的相似性&#xff0c;都是在业务事件发生之前预先定义&#xff1b;但又与基础数据不同&#xff0c;主数据的取值不受限于预先定义的数据范围&#xff0…

http和https的区别有哪些?

HTTP&#xff08;超文本传输协议&#xff09;和HTTPS&#xff08;HTTP Secure&#xff09;是互联网上用于数据传输的两种协议。它们的主要区别在于HTTPS提供了加密的传输机制&#xff0c;以提高数据在传输过程中的安全性。以下是HTTP和HTTPS的一些主要区别&#xff1a; 加密&a…

[linux运维] 利用zabbix监控linux高危命令并发送告警(基于Zabbix 6)

之前写过一篇是基于zabbix 5.4的实现文章&#xff0c;但是不太详细&#xff0c;最近已经有两个小伙伴在zabbix 6上操作&#xff0c;发现触发器没有str函数&#xff0c;所以更新一下本文&#xff0c;基于zabbix 6 0x01 来看看效果 高危指令出发问题告警&#xff1a; 发出邮件告…

如何将idea中导入的文件夹中的项目识别为maven项目

问题描述 大家经常遇到导入某个文件夹的时候&#xff0c;需要将某个子文件夹识别为maven项目 解决方案

【教程】逻辑回归怎么做多分类

目录 一、逻辑回归模型介绍 1.1 逻辑回归模型简介 1.2 逻辑回归二分类模型 1.3 逻辑回归多分类模型 二、如何实现逻辑回归二分类 2.1 逻辑回归二分类例子 2.2 逻辑回归二分类实现代码 三、如何实现一个逻辑回归多分类 3.1 逻辑回归多分类问题 3.1 逻辑回归多分类的代…

RabbitMQ-学习笔记(初识 RabbitMQ)

本篇文章学习于 bilibili黑马 的视频 (狗头保命) 同步通讯 & 异步通讯 (RabbitMQ 的前置知识) 同步通讯&#xff1a;类似打电话&#xff0c;只有对方接受了你发起的请求,双方才能进行通讯, 同一时刻你只能跟一个人打视频电话。异步通讯&#xff1a;类似发信息&#xff0c…

Hadoop3.x完全分布式环境搭建Zookeeper和Hbase

先在主节点上进行安装和配置&#xff0c;随后分发到各个从节点上。 1. 安装zookeeper 1.1 解压zookeeper并添加环境变量 1&#xff09;解压zookeeper到/usr/local文件夹下 tar -zxvf /usr/local2&#xff09;进入/usr/local文件夹将apache-zookeeper-3.8.0-bin改名为zookeep…

玩转Sass:掌握数据类型!

当我们在进行前端开发的时候&#xff0c;有时候需要使用一些不同的数据类型来处理样式&#xff0c;Sass 提供的这些数据类型可以帮助我们更高效地进行样式开发&#xff0c;本篇文章将为您详细介绍 Sass 中的数据类型。 布尔类型 在 Sass 中&#xff0c;布尔数据类型可以表示逻…