【深度学习入门篇 ⑥】PyTorch搭建卷积神经网络

news2025/1/23 12:04:03

【🍊易编橙:一个帮助编程小伙伴少走弯路的终身成长社群🍊】

大家好,我是小森( ﹡ˆoˆ﹡ ) ! 易编橙·终身成长社群创始团队嘉宾,橙似锦计划领衔成员、阿里云专家博主、腾讯云内容共创官、CSDN人工智能领域优质创作者 。


卷积神经网络是深度学习在计算机视觉领域的突破性成果,在计算机视觉领域,往往我们输入的图像都很大,使用全连接网络的话,计算的代价较高;另外图像也很难保留原有的特征,导致图像处理的准确率不高。

卷积神经网络(CNN)是含有卷积层的神经网络,卷积层的作用就是用来自动学习、提取图像的特征。

CNN网络主要有三部分构成:卷积层、池化层和全连接层构成,其中卷积层负责提取图像中的局部特征;池化层用来大幅降低参数量级(降维);全连接层类似人工神经网络的部分,用来输出想要的结果。

像素和通道的理解

我们使用 matplotlib 库来实际理解图像知识:

import numpy as np
import matplotlib.pyplot as plt


def func1():

    img = np.zeros([200, 200])
    print(img)
    plt.imshow(img, cmap='gray', vmin=0, vmax=255)   # imshow显示图像
    plt.show()

    img = np.full([255, 255], 255)
    print(img)
    plt.imshow(img, cmap='gray', vmin=0, vmax=255)
    plt.show()


#  图像的通道
def func2():

    img = plt.imread('QQ.png')
    # 修改数据的维度
    img = np.transpose(img, [2, 0, 1])

    # 打印所有通道
    for channel in img:
        print(channel)
        plt.imshow(channel)
        plt.show()


    # 修改透明度
    img[2] = 0.05
    img = np.transpose(img, [1, 2, 0])
    plt.imshow(img)
    plt.show()


if __name__ == '__main__':
    func1()
    func2()

💯输出:

  • 图像是由像素点组成的,像素值的范围 [0, 255] 值越小表示亮度越小,值越大,表名亮度值越大。一个全0的图像就是一副全黑图像。 一个复杂的图像则是由多个通道组合在一起形成的。 

卷积层

卷积包含一维卷积,二维卷积,三维卷积,在这里以二维卷积为主,如果明白了二维卷积,就知道其他维卷积是怎么回事了

二维卷积

我们看一下卷积核的计算过程,也就是卷积核是如何提取特征的:

  1. input 表示输入的图像
  2. filter 表示卷积核, 也叫做滤波器
  3. input 经过 filter 的得到输出为最右侧的图像,该图叫做特征图

卷积运算本质上就是在滤波器和输入数据的局部区域间做点积。

按照上面的计算方法可以得到最终的特征图为:

Padding 

通过上面的卷积计算过程,我们发现最终的特征图比原始图像小很多,如果想要保持经过卷积后的图像大小不变, 可以在原图周围添加 padding 来实现。

Stride

Stride指定了卷积核在遍历输入特征图时,每次移动的距离。

格式:

torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode=‘zeros’, device=None, dtype=None)

其中

  • in_channels: 输入通道数
  • out_channels: 输出通道数(卷积核数量)
  • kernel_size: 卷积核大小
  • stride: 卷积步长
  • padding: 边缘补零
  • dilation: 扩散卷积
  • group: 分组卷积
  • bias: 是否带有偏置
import torch
import torch.nn as nn
#使用方形卷积核,以及相同的步长
m = nn.conv2d(16,33,3, stride=2)
#使用非方形的卷积核,以及非对称的步长和补零
m = nn. Conv2d(16,33,(3,5), stride=(2,1), padding=(4,2))
#使用非方形的卷积核,以及非对称的步长,补零和膨胀系数
m = nn.Conv2d(16,33,(3,5), stride=(2,1),padding=(4,2), dilation=(3,1))input = torch.randn(20,16,50,100)
output = m(input)
print(output.shape)

 输出:

torch.Size([20,33,26,100])

卷积层提取案例

我们接下来对下面的图片进行特征提取:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt


# 显示图像
def show(img):

    # 输入: (Height, Width, Channel)
    plt.imshow(img)
    plt.axis('off')
    plt.show()


# 单个多通道卷积核
def func1():

    img = plt.imread('QQ.png')
    show(img)

    conv = nn.Conv2d(in_channels=3, out_channels=1, kernel_size=3, stride=1, padding=1)
    img = torch.tensor(img).permute(2, 0, 1).float()  # 转换为float类型以匹配默认的tensor类型
    img = img.unsqueeze(0)
    new_img = conv(img)
    new_img = new_img.squeeze(0).permute(1, 2, 0)

    show(new_img.detach().numpy())


# 多个多通道卷积核
def func2():


    img = plt.imread('QQ.png')
    show(img)

    conv = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)
    img = torch.tensor(img).permute(2, 0, 1).float()  # 转换为float类型
    img = img.unsqueeze(0)

    new_img = conv(img)
    new_img = new_img.squeeze(0).permute(1, 2, 0)

    # 打印三个特征图
    show(new_img[:, :, 0].unsqueeze(2).detach().numpy())
    show(new_img[:, :, 1].unsqueeze(2).detach().numpy())
    show(new_img[:, :, 2].unsqueeze(2).detach().numpy())


if __name__ == '__main__':
    func1()
    func2()

 输出:

转置卷积 :就是卷积的逆操作,也称为逆卷积、反卷积。

torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode=‘zeros’, device=None, dtype=None)
  • 输入:(𝑁,𝐶𝑖𝑛,𝐻𝑖𝑛,𝑊𝑖𝑛)或者(𝐶𝑖𝑛,𝐻𝑖𝑛,𝑊𝑖𝑛)
  • 输出:(𝑁,𝐶𝑜𝑢𝑡,𝐻𝑜𝑢𝑡,𝑊𝑜𝑢𝑡)或者(𝐶𝑜𝑢𝑡,𝐻𝑜𝑢𝑡,𝑊𝑜𝑢𝑡)
import torch.nn as nnimport torch
#使用长宽一致的卷积核以及相同的步长
m = nn.ConvTranspose2d( 16,33,3, stride=2)#使用长宽不一致的卷积核,步长,以及补零
m = nn.ConvTranspose2d(16,33,(3,5), stride=(2,1), padding=(4,2))
input = torch.randn(20,16,50,100)
output = m( input)
#可以直接指明输出的尺寸大小
input = torch.randn(1,16,12,12)
downsample = nn.conv2d(16,16,3, stride=2, padding=1)
upsample = nn.ConvTranspose2d(16,16,3, stride=2, padding=1)
h = downsample( input)
print(h.size())
output = upsample(h,output_size=input.size( ))
print(output.size())

输出:

torch.Size([1,16,6,6])
torch.Size([1,16,12,12])

案例:搭建全卷积网络结构

import torch
import torch.nn as nn
import torch.nn.functional as F


class FCN(nn. Module) :
    def __init__(self, num_class):
        super(FCN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3)

        self.unsample1 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3)
        self.unsample2 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3)
        self.unsample3 = nn.ConvTranspose2d(in_channels=32, out_channels=num_class, kernel_size=3)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.unsample1(x))
        x = F.relu(self.unsample2(x))
        x = F.relu(self.unsample3(x))
        return x

num_class = 10
model = FCN(num_class)

print(model)

案例:搭建卷积+全连接的网络结构

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        # 第一层卷积
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)  # 添加padding以避免尺寸减小
        self.pool1 = nn.MaxPool2d(2, 2)  # 第一个池化层
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # 添加padding
        self.pool2 = nn.MaxPool2d(2, 2)  # 第二个池化层

        self.flatten = nn.Flatten(start_dim=1)
        # 计算fc1的输入特征数:64 * (28/2/2) * (28/2/2) = 64 * 7 * 7
        self.fc1 = nn.Linear(64 * 7 * 7, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)  # 应用池化层
        x = self.conv2(x)
        x = self.pool2(x)  # 应用另一个池化层
        x = self.flatten(x)  # 展平
        x = self.fc1(x)
        x = self.fc2(x)
        return x


num_class = 10
model = ConvNet(num_class)

batch_size = 4
input_tensor = torch.randn(batch_size, 3, 28, 28)

output = model(input_tensor)
print(output)

输出:

tensor([[ 0.0986, -0.1008, -0.0225, -0.1896, -0.1659,  0.0817, -0.0684, -0.0195,
         -0.1648,  0.0578],
        [ 0.0241, -0.0391,  0.0014, -0.1261, -0.0593,  0.0679, -0.1342, -0.0396,
         -0.2054,  0.1309],
        [ 0.0549, -0.0116, -0.0471, -0.1747, -0.0148,  0.1378, -0.2085,  0.0004,
         -0.1579,  0.1637],
        [ 0.0553, -0.1103,  0.1054, -0.0782, -0.1624, -0.0047, -0.2090,  0.0089,
         -0.2294,  0.0865]], grad_fn=<AddmmBackward0>)

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1933543.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

‍我想我大抵是疯了,我喜欢上了写单元测试

前言 大家好我是聪。相信有不少的小伙伴喜欢写代码&#xff0c;但是对于单元测试这些反而觉得多此一举&#xff0c;想着我都在接口文档测过了&#xff01;还要写什么单元测试&#xff01;写不了一点&#xff01;&#xff01; 由于本人也是一个小小程序猿&#x1f649;&#xf…

Unity | Shader基础知识(第十八集:Stencil应用-透视立方盒子)

目录 一、前言 二、场景布置 三、 shader部分 1.图片的部分 2.图片部分纯净代码 3.遮罩部分复习 4.深度写入 ZWrite 5.颜色遮罩ColorMask 6.遮罩纯净代码 四、场景中shader使用 五、作者的碎碎念 一、前言 因为这个内容稍微有点多&#xff0c;我尽力讲清楚了&#x…

VAE论文阅读

在网上看到的VAE解释&#xff0c;发现有两种版本&#xff1a; 按照原来论文中的公式纯数学推导&#xff0c;一般都是了解生成问题的人写的&#xff0c;对小白很不友好。按照实操版本的&#xff0c;非常简单易懂&#xff0c;比如苏神的。但是却忽略了论文中的公式推导&#xff…

jquery中pdf在页面的显示和导出

jquery中pdf在页面的显示和导出 01 显示pdf01 .pdf结尾在线接口显示到页面 &#xff08;pdf.js库怎么安装及使用&#xff09;&#xff1a;只显示一页02 如何用PDF.JS显示整个PDF (而不仅仅是一页)&#xff1f;03 jQuery实现在线预览PDF文件(通过a标签链接跳转)&#xff1a; 02 …

【网络安全】PostMessage:分析JS实现XSS

未经许可&#xff0c;不得转载。 文章目录 前言示例正文 前言 PostMessage是一个用于在网页间安全地发送消息的浏览器 API。它允许不同的窗口&#xff08;例如&#xff0c;来自同一域名下的不同页面或者不同域名下的跨域页面&#xff09;进行通信&#xff0c;而无需通过服务器…

【STM32 HAL库】全双工DMA双buffer的I2S使用

1、配置I2S 我们的有效数据是32位的&#xff0c;使用飞利浦格式。 2、配置DMA **这里需要注意&#xff1a;**i2s的DR寄存器是16位的&#xff0c;如果需要发送32位的数据&#xff0c;是需要写两次DR寄存器的&#xff0c;所以DMA的外设数据宽度设置16位&#xff0c;而不是32位。…

ArrayLis练习

代码呈现 import java.util.ArrayList;public class ArrayListTest {public static void main(String[] args) {//创建集合ArrayList<String> list new ArrayList();//添加元素list.add("A");list.add("B");list.add("C");list.add(&quo…

222.买卖股票的最佳时机(力扣)

代码解决 class Solution { public:int maxProfit(vector<int>& prices) {// 初始化最小买入价为第一个价格int min1 prices[0];// 初始化最大利润为0int max1 0;// 从第二天开始遍历价格数组for (int i 1; i < prices.size(); i) {// 计算当前价卖出的利润&a…

C++:智能指针shared_ptr、unique_ptr、weak_ptr的概念、用法即它们之间的关系

智能指针 (1)概述 A.Why&#xff08;C为什么引入智能指针&#xff09; C引入智能指针的根本原因就是解决手动管理动态内存所带来的问题&#xff0c;手动管理动态内存常见的问题如下&#xff1a;内存泄漏、悬挂指针、释放操作未定义等 内存泄漏问题&#xff1a; 当程序用光了它…

React的usestate设置了值后马上打印获取不到最新值

我们在使用usestate有时候设置了值后&#xff0c;我们想要更新一些值&#xff0c;这时候&#xff0c;我们要想要马上获取这个值去做一些处理&#xff0c;发现获取不到&#xff0c;这是为什么呢&#xff1f; 效果如下&#xff1a; 1、原因如下 在React中,当你使用useState钩子…

线程安全(七)ReentrantLock 简介、Condition 条件变量、锁的工作原理、synchronized 与 Lock 的区别

目录 一、ReentrantLock 简介1.1 Reentrant 的特性:1.2 基本语法1.3 ReentrantLock 的主要方法:1.4 lock()、tryLock()、lockInterruptibly() 的区别:二、Condition 条件变量2.1 什么是 Condition 条件变量?2.2 Condition 的核心方法:2.3 Condition 使用示例1:等待与唤醒…

PJA1介导的焦亡抑制是鼻咽癌产生耐药性的驱动因素

引用信息 文 章&#xff1a;PJA1-mediated suppression of pyroptosis as a driver of docetaxel resistance in nasopharyngeal carcinoma. 期 刊&#xff1a;Nature Communications&#xff08;影响因子&#xff1a;14.7&#xff09; 发表时间&#xff1a;2024年6月2…

LLaMA-Factory

文章目录 一、关于 LLaMA-Factory项目特色性能指标 二、如何使用1、安装 LLaMA Factory2、数据准备3、快速开始4、LLaMA Board 可视化微调5、构建 DockerCUDA 用户&#xff1a;昇腾 NPU 用户&#xff1a;不使用 Docker Compose 构建CUDA 用户&#xff1a;昇腾 NPU 用户&#xf…

变阻器与电位器有什么区别?

变阻器和电位器都是可以改变电阻值的电子元件&#xff0c;它们在电路中的作用和调节方式有一定的相似性&#xff0c;但它们之间还是存在一些区别的。 1. 结构上的区别&#xff1a;变阻器主要由固定电阻体和可动滑片组成&#xff0c;通过滑动滑片来改变电阻体的电阻值。而电位器…

数据库(创建数据库和表)

目录 一&#xff1a;创建数据库 二&#xff1a;创建表 2.1&#xff1a;创建employees表 2.2&#xff1a;创建orders表 2.3&#xff1a;创建invoices表 一&#xff1a;创建数据库 mysql> create database mydb6_product; Query OK, 1 row affected (0.01 sec) mysql&g…

linux centos limits.conf 修改错误,无法登陆问题修复 centos7.9

一、问题描述 由于修改/etc/security/limits.conf这个文件中的值不当&#xff0c;重启后会导致其账户无法远程登录&#xff0c;本机登录。 如改成这样《错误示范》&#xff1a; 会出现&#xff1a; 二、解决 现在知道是由于修改limits.conf文件不当造成的&#xff0c;那么就…

智慧农业新纪元:解锁新质生产力,加速产业数字化转型

粮食安全乃国家之根本&#xff0c;“浙江作为农业强省、粮食生产重要省份&#xff0c;在维护国家粮食安全大局中肩负着重大使命。浙江粮食产业经济年总产值已突破4800亿元&#xff0c;稳居全国前列&#xff0c;然而&#xff0c;同样面临着规模大而不强、质量效益有待提升、数字…

JVM高频面试点

文章目录 JVM内存模型程序计数器Java虚拟机栈本地方法栈Java堆方法区运行时常量池 Java对象对象的创建如何为对象分配内存 对象的内存布局对象头实例数据对齐填充 对象的访问定位 垃圾收集器找到垃圾引用计数法可达性分析&#xff08;根搜索法&#xff09; 引用概念的扩充回收方…

字符数组的魅力:C语言字符数组与字符串编程实践

1.概念 字符数组&#xff0c;数组元素是char(字符型)的数组&#xff0c;它可以是一维数组&#xff0c;也可以是二维数组。 2.定义的时候赋值 char ch1[]{c,h,i,n,a}; char ch2[]{"china"}; //相当于 char ch2[] "china"; 元素个数为6&#xff0c;默认会…

探索Linux世界 —— shell与权限的相关知识

一、shell以及其运行原理 1、什么是shell Linux严格意义上说的是一个操作系统&#xff0c;我们称之为“核心&#xff08;kernel&#xff09;“ &#xff0c;但我们一般用户&#xff0c;不能直接使用kernel。而是通过kernel的“外壳”程序&#xff0c;也就是所谓的shell&#x…