Pytorch代码入门学习之分类任务(一):搭建网络框架

news2024/12/26 2:27:37

目录

一、网络框架介绍

二、导包

三、定义卷积神经网络

3.1 代码展示

3.2  定义网络的目的

3.3 Pytorch搭建网络

四、测试网络效果


一、网络框架介绍

        网络理解:

        将32*32大小的灰度图片(下述的代码中输入为32*32大小的RGB彩色图片),输入到网络中;经过第一次卷积C1,变成了6通道、28*28大小的一个特征向量;通过一次下采样S2,变成了6通道、14*14大小的一个特征向量,其宽高相当于折损了一般;经过第二次卷积C3,变成了16通道、10*10大小的一个特征向量;通过第二次下采样S4,变成了16通道、5*5大小的一个特征向量;最后三层全连接输出。

        ①Convolutious(卷积):涉及到输入、输出与很多参数的设置,需要初始化。

        ②Subsampling(下采样):该网络中使用的是最大池化下采样的方法,最大池化下采样的和维2*2大小。

        最大池化:Max Pooling,取窗口内的最大值作为输出。

        ③Full Connection(全连接):需要初始化。

二、导包

import torch  # torch基础库
import torch.nn as nn  # torch神经网络库
import torch.nn.functional as F

三、定义卷积神经网络

3.1 代码展示

class Net(nn.Module):
    # 初始化
    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = nn.Conv2d(3,6,5)
        self.conv2 = nn.Conv2d(6,16,5)
        self.fc1 = nn.Linear(16*5*5,120)
        self.fc2 = nn.Linear(120,84)
        self.fc3 = nn.Linear(84,10)

    # 前向传播
    def forward(self,x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x,(2,2))
        x = F.max_pool2d(F.relu(self.conv2(x)),2)
        x = x.view(-1,x.size()[1:].numel())
        x = F.relu(self.fc1(x))  # 进入全连接层需要进行激活函数
        x = F.relu(self.fc2(x))
        x = self.fc3(x)  # 最后一层为输出层,要输出结果,不需要进行激活
        return x

3.2  定义网络的目的

        希望网络有科学系参数,通过输入数据的训练让相关参数不断更新、梯度下降到一个合适的值,之后输入新的图片,可以进行分类或者预测。

3.3 Pytorch搭建网络

        Pytorch搭建网络通常会采用类进行管理,可取名为Net(该名字可以更换),通常需要继承nn.Model类(相当于在Net中将Model定义好的方法直接进行使用)。搭建网络通常包括两个函数:

        ①初始化函数(含有默认参数):实例化这个类的时候会自动执行的一部分,这里面放网络需要初始化的内容。

 def __init__(self)

         A. super(Net,self).__init__():在该函数中通常需要进行多继承操作,相当于把Model类里面继承的类以及全部的类的方法都继承下来,供Net去使用;

        B. nn.Conv2d(3,6,5):2d卷积核的函数,只涉及三个参数,其余参数使用默认值;第一个参数为输入的通道数,第二个参数为输出特征向量的通道数,第三个参数为卷积核大小(使用output公式进行计算 W-F+1=28,W=32,F=5 );

        Output = \frac{W-F+2P}{S} + 1:其中W是指宽高,F是指所求的ColorSize的大小,P是指Padding—像图片外面补边,让它去遍历,默认为0;S是指步长,卷积核遍历图片的步长,默认为1;

        C. nn.Linear(16*5*5,120):全连接层的初始化,涉及两个参数(输入特征的维数大小和输出特征的维数大小),全连接层需要对特征做一个拉平,将每一个特征拉平,将上一个特征向量拉为一条直线,送给全连接层;

        ②前向传播函数:需实现前向回归逻辑,相当于完成整个网络运行的逻辑,x是指输入,相当于上图中的input。

def forward(self,x)

        A. F.relu(x):relu激活函数,激活之后网络具有非线性的分离能力;

        B. tensor[batch,channel,H,W]: channel是指通道数,例如RBG三通道这些概念、H是指高,W是指宽,batch是指有几批这样的数据;

        C. F.max_pool2d(x,(2,2)):最大池化下采样对x进行处理;

        D. x.view(-1,x.size()[1:].numel()):进行拉平、展平之后给全连接层,对当前的输入数据x进行一个形式转换,输入行和列,这里所对应的列等于self.fc1 = nn.Linear(16*5*5,120)这里所对应的行,为x.size切片之后数据的乘积;行信息根据批次信息自动生成,-1让程序自动生成这个行;为什么要切1,对于tensor信息来说,将batch切掉,channel、H、W相乘等于16*5*5;

       注意: Pytorch处理的都是张量(张量是神经网络所使用的主要数据结构)数据。

四、测试网络效果

        相当于打印网络初始化部分,也可以与网络结构相对应检查一下。

net = Net()
print(net)

        参考:Pytorch逐行代码入门学习_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1133254.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

论多段图的最短路径问题(我认为本质上还是暴力枚举法)

比如说这道题:我向前推进 从0到11的最短路径 按照图可以分5段,v1 是第一阶段 0,v2是第二段 有1,2,3,4 从0开始,路径为0,所以m(1,0)0&#xff1b…

单片机核心/RTOS必备 (ARM汇编)

ARM汇编概述 一开始,ARM公司发布两类指令集: ARM指令集,这是32位的,每条指令占据32位,高效,但是太占空间。Thumb指令集,这是16位的,每条指令占据16位,节省空间。 要节…

reqable(小黄鸟)+雷电抓包安卓APP

x 下载证书保存到雷电模拟器根目录(安装位置) 在根目录打开cmd执行命令 F:\Program\leidian\LDPlayer9>adb push reqable-ca.crt /system/etc/security/cacerts/364618e0.0 reqable-ca.crt: 1 file pushed, 0 skipped. 0.8 MB/s (1773 bytes in 0.002s)如果是powershell…

Hadoop3.0大数据处理学习4(案例:数据清洗、数据指标统计、任务脚本封装、Sqoop导出Mysql)

案例需求分析 直播公司每日都会产生海量的直播数据,为了更好地服务主播与用户,提高直播质量与用户粘性,往往会对大量的数据进行分析与统计,从中挖掘商业价值,我们将通过一个实战案例,来使用Hadoop技术来实…

Node编写更新用户头像接口

目录 定义路由和处理函数 验证表单数据 ​编辑 实现更新用户头像的功能 定义路由和处理函数 向外共享定义的更新用户头像处理函数 // 更新用户头像的处理函数 exports.updateAvatar (req, res) > {res.send(更新成功) } 定义更新用户头像路由 // 更新用户头像的路由…

Liunx-Kubernetes安装

安装Kubernetes Kubernetes有多种部署方式,目前主流的方式有kubeadm、minikube、二进制包 minikube:一个用于快速搭建单节点kubernetes的工具kubeadm:一个用于快速搭建kubernetes集群的工具二进制包:从官网下载每个组件的二进制…

浏览器多开,数据之间相互不干扰

方法很简单 在浏览器快捷方式中,快捷键点开属性,在目标中添加--user-data-dirD:\chrome\1

【蓝桥每日一题]-贪心(保姆级教程 篇1)#拼数 #合并果子 #凌乱yyy

目录 题目: 拼数 思路: 题目: 合并果子 思路: 题目:凌乱yyy 思路: 题目:拼数 思路: 思路很简单。举个例子:对于a321,b32。我们发现ab32132,ba32321,那么…

TypeScript学习 | 泛型

简介 泛型是指在定义函数、接口或类的时候,不预先指定具体的类型,而在使用的时候再指定类型的一种特性 作用 可以保证类型安全的前提下,让函数、接口或类与多种类型一起工作,从而实现复用 基本使用 举个例子: 创…

【TGRS 2023】RingMo: A Remote Sensing Foundation ModelWith Masked Image Modeling

RingMo: A Remote Sensing Foundation Model With Masked Image Modeling, TGRS 2023 论文:https://ieeexplore.ieee.org/stamp/stamp.jsp?tp&arnumber9844015 代码:https://github.com/comeony/RingMo MindSpore/RingMo-Framework (gitee.com) …

解决:vscode和jupyter远程连接无法创建、删除文件的问题(permission denied)

目录 问题:vscode和jupyter远程连接服务器无法创建、删除文件的问题原因:代码文件的权限不够解决方法:1.ls -l查看目录所在组,权限2.chown修改拥有者和所在组 问题:vscode和jupyter远程连接服务器无法创建、删除文件的…

【兔子王赠书第3期】《案例学Python(进阶篇)》

文章目录 前言推荐图书本书特色本书目录本书样章本书读者对象粉丝福利丨评论免费赠书尾声 前言 随着人工智能和大数据的蓬勃发展,Python将会得到越来越多开发者的喜爱和应用。因为Python语法简单,学习速度快,大家可以用更短的时间掌握这门语…

Spring学习笔记—JDK动态代理

✅作者简介:大家好,我是Leo,热爱Java后端开发者,一个想要与大家共同进步的男人😉😉 🍎个人主页:Leo的博客 💞当前专栏: Spring专栏 ✨特色专栏: M…

Unity中Shader的ShaderLOD

文章目录 前言一、ShaderLOD的使用步骤1、ShaderLOD使用在不同的SubShader中,用于区分SubShader所对应的配置2、在 C# 中使用 Shader.globalMaximumLOD 赋值来选择不同的 SubShader,以达到修改配置对应Shader的效果3、在设置LOD时,是需要和程序讨论统一 …

WebGL笔记:矩阵的变换之平移的实现

矩阵的变换 变换 变换有三种状态:平移、旋转、缩放。当我们变换一个图形时,实际上就是在移动这个图形的所有顶点。解释 webgl 要绘图的话,它是先定顶点的,就比如说我要画个三角形,那它会先把这三角形的三个顶点定出来…

为什么需要山洪灾害监测预警系统?

在山洪高发地区,安装山洪灾害监测预警系统能够通过实时监测,预警山洪信息,对于保障我们的生命财产安全具有重要意义。 监测山洪不仅需要对山体进行监测,还要监测降雨量以及水位上升情况。山洪灾害监测预警系统是由GNSS监测站和水…

linux安装node(含npm命令) 并配置淘宝镜像源

1. 下载压缩包 wget https://nodejs.org/dist/v16.14.0/node-v16.14.0-linux-x64.tar.xz # node14 https://nodejs.org/dist/v14.15.4/node-v14.15.4-linux-x64.tar.xz # 推荐将压缩包放置到/usr/local/node文件夹中安装 mv node-v16.14.0-linux-x64.tar.xz /usr/local/node …

LeetCode217——存在重复元素

LeetCode217——存在重复元素 1.题目描述: 给你一个整数数组 nums 。如果任一值在数组中出现 至少两次 ,返回 true ;如果数组中每个元素互不相同,返回 false 。 2.Result01(暴力解) public static boolean containsDuplicate(in…

SRAM与DRAM的区别

目录 SRAM 特点 应用场景 DRAM 特点 应用场景 SRAM和DRAM的区别 SRAM SRAM(静态随机存取存储器)是一种用于存储和检索数据的类型的计算机内存。SRAM的存储单元通过触发器(flip-flop)实现,它们可以保持数据的状态…

语雀崩溃7个小时的原因是什么??

1 语雀是什么 语雀是蚂蚁集团旗下的在线文档编辑与协同工具,使用了“结构化知识库管理”,形式上类似书籍的目录。用户量在千万级别,是非常强大的。身边有不少朋友是付费会员,有许多公司也付费在使用语雀作为知识库进行文档的存储…