pytorch Neural Networks学习笔记

news2025/4/19 10:04:48

在这里插入图片描述

(1)输入图像,1×32×32,通道数1,高32,宽32
(2)卷积层1,滤波器的shape为6×1×5×5,滤波器个数6,通道数1,高5,宽5。卷积层1的输出为6×28×28。
(3)relu层,输出和输入的shape相同
(4)池化层,滤波器的大小为2×2,stride为2×2,输出为6×14×14。
(5)卷积层2,滤波器的shape为16×6×5×5,滤波器个数16,通道数6,高5,宽5。卷积层2的输出为16×10×10。
(6)relu层,输出和输入的shape相同
(7)池化层,滤波器的大小为2×2,stride为2×2,输出为16×5×5。reshape为1×400后输入到全连接层1。
(8)全连接层1,权重的shape为400×120,输出为1×120
(9)relu层,输出和输入的shape相同
(10)全连接层2,权重的shape为120×84,输出为1×84
(11)relu层,输出和输入的shape相同
(12)全连接层3,权重的shape为84×10,输出为1×10

  这个网络的结构如下:

input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d
      -> flatten -> linear -> relu -> linear -> relu -> linear
      -> MSELoss
      -> loss

  Let’s define this network:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5*5 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square, you can specify with a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1) # flatten all dimensions except the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()
print(net)
Net(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

  这里写代码测试一下flatten函数的用法:

>>> data=[[[1,2,3],[4,5,6]], [[7,8,9],[10,11,12]]]
>>> x_data = torch.tensor(data)
>>> print(f"x_data: \n {x_data} \n")
x_data:
 tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]]])

>>> print(f"Shape of x_data: {x_data.shape}")
Shape of x_data: torch.Size([2, 2, 3])
>>> x_data = torch.flatten(x_data, 1)
>>> print(f"x_data: \n {x_data} \n")
x_data:
 tensor([[ 1,  2,  3,  4,  5,  6],
        [ 7,  8,  9, 10, 11, 12]])

>>> print(f"Shape of x_data: {x_data.shape}")
Shape of x_data: torch.Size([2, 6])

  该网络有2个卷积层和3个全连接层,有5个权重参数和5个偏置参数,共10个参数。

params = list(net.parameters())
print(len(params))
for i in range(len(params)):
    print(params[i].size())

  以上代码打印了每个参数的shape,如下:

10
torch.Size([6, 1, 5, 5])
torch.Size([6])
torch.Size([16, 6, 5, 5])
torch.Size([16])
torch.Size([120, 400])
torch.Size([120])
torch.Size([84, 120])
torch.Size([84])
torch.Size([10, 84])
torch.Size([10])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1601262.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

网络靶场实战-加密固件分析

背景 在漏洞挖掘过程中,想要对二进制程序改进分析,我们就需要获取目标设备的文件系统,这样才能更好的逆向分析设备程序的运行逻辑,从而发现其中的漏洞点。然而并非所有设备固件中的文件系统都可以让我们轻易获取,对于…

什么是上位机?入门指南

什么是上位机? 上位机(SCADA,Supervisory Control and Data Acquisition)是一种软件系统,用于监控和控制工业过程中的设备。它通常与传感器、执行器和其他自动化设备一起工作,以实时地监视过程状态、收集数…

7 pytorch DataLoader, TensorDataset批数据训练方法

前言 本文主要介绍pytorch里面批数据的处理方法,以及这个算法的效果是什么样的。具体就是要弄明白这个批数据选取的算法是在干什么,不会涉及到网络的训练。 from torch.utils.data import DataLoader, TensorDataset主要实现就是上面的数据集和数据载入…

详解IP安全:【IPSec协议簇 - AH协议 - ESP协议 - IKE协议】

目录 IP安全概述 IPSec协议簇 IPSec的实现方式 AH(Authentication Header,认证头) ESP(Encapsulating Security Payload,封装安全载荷) IKE(Internet Key Exchange,因特网密钥…

Sourcetree安装使用(补个笔记)

Sourcetree介绍 Sourcetree是一款免费的Git图形化客户端,它由Atlassian开发,提供了跨平台的支持,可运行在Windows和Mac操作系统上。Sourcetree可以让开发者更方便地使用Git来管理代码,不需要在命令行中输入复杂的Git命令&#xf…

C++ | Leetcode C++题解之第31题下一个排列

题目&#xff1a; 题解&#xff1a; class Solution { public:void nextPermutation(vector<int>& nums) {int i nums.size() - 2;while (i > 0 && nums[i] > nums[i 1]) {i--;}if (i > 0) {int j nums.size() - 1;while (j > 0 && …

HarmonyOS NEXT星河版之实战商城App瀑布流(含加载更多)

文章目录 一、目标二、开撸2.1 声明商品对象2.2 mock数据2.3 主页面2.4 加载更多2.5 完整代码 三、小结 一、目标 二、开撸 2.1 声明商品对象 export interface GoodsItem {title: stringimageUrl: string }2.2 mock数据 export const mockGoodsList: GoodsItem[] [{title:…

1688上怎么找跨境电商低价优质货源?妙手ERP帮您搞定!

1688是跨境卖家常用的货源网站&#xff0c;但很多从1688上拿货的卖家都会遇到这样的问题&#xff1a;有时候找到的货源比同行售价还高&#xff0c;好不容易找到低价的货源&#xff0c;但质量不好、供应也不稳定。 实际上&#xff0c;1688平台上有很多超低价的货源&#xff0c;…

Ubuntu系统安装APITable多维表格平台结合内网穿透实现公网访问

文章目录 前言1. 部署APITable2. cpolar的安装和注册3. 配置APITable公网访问地址4. 固定APITable公网地址 前言 vika维格表作为新一代数据生产力平台&#xff0c;是一款面向 API 的智能多维表格。它将复杂的可视化数据库、电子表格、实时在线协同、低代码开发技术四合为一&am…

机器人码垛机的技术特点与应用

随着科技的飞速发展&#xff0c;机器人技术正逐渐渗透到各个行业领域&#xff0c;其中&#xff0c;机器人码垛机在物流行业的应用尤为引人瞩目。它不仅提高了物流效率&#xff0c;降低了成本&#xff0c;更在改变传统物流模式的同时&#xff0c;为行业发展带来了重大的变革。 一…

2024蓝桥杯每日一题(最大公约数)

备战2024年蓝桥杯 -- 每日一题 Python大学A组 试题一&#xff1a;公约数 试题二&#xff1a;最大公约数 试题三&#xff1a;等差数列 试题四&#xff1a;最大比例 试题五&#xff1a;Hankson的趣味题 试题一&#xff1a;公约数 【题目描述】 …

cesium加载高层级离线影像地图瓦片(天地图、19级Arcgis)

实际加载效果如图&#xff1a; 1、下载离线地图瓦片方式&#xff08;多种任选其一&#xff0c;个人倾向于Qgis工具&#xff09;&#xff1a; 方式1、采用第三方下载工具如&#xff1a;91卫图、水经注、全能电子地图下载器、bigemap等等。&#xff08;这些有的下载层级不够&…

服务器数据恢复—xfs文件系统节点、目录项丢失的数据恢复案例

服务器数据恢复环境&#xff1a; EMC某型号存储&#xff0c;该存储内有一组由12块磁盘组建的raid5阵列&#xff0c;划分了两个lun。 服务器故障&#xff1a; 管理员为服务器重装操作系统后&#xff0c;发现服务器的磁盘分区发生改变&#xff0c;原来的sdc3分区丢失。由于该分区…

【C语言】每日一题,快速提升(2)!

&#x1f525;博客主页&#x1f525;&#xff1a;【 坊钰_CSDN博客 】 欢迎各位点赞&#x1f44d;评论✍收藏⭐ 题目&#xff1a;杨氏矩阵 有一个数字矩阵&#xff0c;矩阵的每行从左到右是递增的&#xff0c;矩阵从上到下是递增的&#xff0c;请编写程序在这样的矩阵中查找某个…

Ubuntu 部署ChatGLM3大语言模型

Ubuntu 部署ChatGLM3大语言模型 ChatGLM3 是智谱AI和清华大学 KEG 实验室联合发布的对话预训练模型。 源码&#xff1a;https://github.com/THUDM/ChatGLM3 部署步骤 1.服务器配置 Ubuntu 20.04 8核(vCPU) 32GiB 5Mbps GPU NVIDIA T4 16GB 硬盘 100GiB CUDA 版本 12.2.2/…

从汇编代码理解数组越界访问漏洞

数组越界访问漏洞是 C/C 语言中常见的缺陷&#xff0c;它发生在程序尝试访问数组元素时未正确验证索引是否在有效范围内。通常情况下&#xff0c;数组的索引从0开始&#xff0c;到数组长度减1结束。如果程序尝试访问小于0或大于等于数组长度的索引位置&#xff0c;就会导致数组…

Python 物联网入门指南(八)

原文&#xff1a;zh.annas-archive.org/md5/4fe4273add75ed738e70f3d05e428b06 译者&#xff1a;飞龙 协议&#xff1a;CC BY-NC-SA 4.0 第三十章&#xff1a;制作机械臂 最后&#xff0c;我们终于到达了大多数人自本书开始以来就想要到达的地方。制作一个机械臂&#xff01;在…

pip如何查看Python某个包已发行所有版本号?

以matplotlib包为例子&#xff0c; pip install matplotlib6666 6666只是胡乱输入的一个数&#xff0c;反正输入任意一个不像版本号的数字都可以&#xff5e; matplotlib所有版本号如下&#xff0c; 0.86, 0.86.1, 0.86.2, 0.91.0, 0.91.1, 1.0.1, 1.1.0, 1.1.1, 1.2.0, 1.2.1…

【问题处理】银河麒麟操作系统实例分享,oom分析

1.问题现象描述 服务器数据库被oomkill掉&#xff0c;但是mem查看只占用了不到60%。 2.问题分析 2.1.oom现象分析 从下面的日志信息&#xff0c;可以看到chmod进程是在内核采用GFP_KERNEL|__GFP_COMP分配order3也就是2的3次方&#xff0c;8个连续页的时候&#xff0c;因进入…

如何使用Git-Secrets防止将敏感信息意外上传至Git库

关于Git-Secrets Git-secrets是一款功能强大的开发安全工具&#xff0c;该工具可以防止开发人员意外将密码和其他敏感信息上传到Git库中。 Git-secrets首先会扫描提交的代码和说明&#xff0c;当与用户预先配置的正则表达式模式匹配时&#xff0c;便会阻止此次提交。该工具的优…