Pytorch深度学习实践笔记11(b站刘二大人)

news2024/9/19 16:54:54

🎬个人简介:一个全栈工程师的升级之路!
📋个人专栏:pytorch深度学习
🎀CSDN主页 发狂的小花
🌄人生秘诀:学习的本质就是极致重复!

《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili​

目录

1 1x1 卷积

2 卷积实现

3 代码


1 1x1 卷积


(1)用来对通道数进行降维或升维,保持Feature Map长宽不变,减少计算量
(2)实现跨通道信息的融合
(3)可以保持输入和输出网络结构不变的同时,融合特征
 

1x1卷积(Conv 1*1)的作用​

2 卷积实现


多层网络组成




concat算子连接:



 




3 代码
 

import torch
import torch.nn as nn
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
 
# prepare dataset
 
batch_size = 64
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) # 归一化,均值和方差
 
train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
 
# design model using class
class InceptionA(nn.Module):
    def __init__(self, in_channels):
        super(InceptionA, self).__init__()
        self.branch1x1 = nn.Conv2d(in_channels, 16, kernel_size=1)
 
        self.branch5x5_1 = nn.Conv2d(in_channels, 16, kernel_size=1)
        self.branch5x5_2 = nn.Conv2d(16, 24, kernel_size=5, padding=2)
 
        self.branch3x3_1 = nn.Conv2d(in_channels, 16, kernel_size=1)
        self.branch3x3_2 = nn.Conv2d(16, 24, kernel_size=3, padding=1)
        self.branch3x3_3 = nn.Conv2d(24, 24, kernel_size=3, padding=1)
 
        self.branch_pool = nn.Conv2d(in_channels, 24, kernel_size=1)
 
    def forward(self, x):
        branch1x1 = self.branch1x1(x)
 
        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)
 
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)
        branch3x3 = self.branch3x3_3(branch3x3)
 
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)
 
        outputs = [branch1x1, branch5x5, branch3x3, branch_pool]
        return torch.cat(outputs, dim=1) # b,c,w,h  c对应的是dim=1
 
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(88, 20, kernel_size=5) # 88 = 24x3 + 16
 
        self.incep1 = InceptionA(in_channels=10) # 与conv1 中的10对应
        self.incep2 = InceptionA(in_channels=20) # 与conv2 中的20对应
 
        self.mp = nn.MaxPool2d(2)
        self.fc = nn.Linear(1408, 10) 
 
 
    def forward(self, x):
        in_size = x.size(0)
        x = F.relu(self.mp(self.conv1(x)))
        x = self.incep1(x)
        x = F.relu(self.mp(self.conv2(x)))
        x = self.incep2(x)
        x = x.view(in_size, -1)
        x = self.fc(x)
 
        return x
 
model = Net()
 
# construct loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
 
# training cycle forward, backward, update
 
 
def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        optimizer.zero_grad()
 
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
 
        running_loss += loss.item()
        if batch_idx % 300 == 299:
            print('[%d, %5d] loss: %.3f' % (epoch+1, batch_idx+1, running_loss/300))
            running_loss = 0.0
 
 
def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('accuracy on test set: %d %% ' % (100*correct/total))
 
 
if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()



引入残差解决梯度消失,上一节已经讲过,构建更深的网络




代码:

import torch
import torch.nn as nn
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
 
# prepare dataset
 
batch_size = 64
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) # 归一化,均值和方差
 
train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
 
# design model using class
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.channels = channels
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
 
    def forward(self, x):
        y = F.relu(self.conv1(x))
        y = self.conv2(y)
        return F.relu(x + y)
 
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5) # 88 = 24x3 + 16
 
        self.rblock1 = ResidualBlock(16)
        self.rblock2 = ResidualBlock(32)
 
        self.mp = nn.MaxPool2d(2)
        self.fc = nn.Linear(512, 10) # 暂时不知道1408咋能自动出来的
 
 
    def forward(self, x):
        in_size = x.size(0)
 
        x = self.mp(F.relu(self.conv1(x)))
        x = self.rblock1(x)
        x = self.mp(F.relu(self.conv2(x)))
        x = self.rblock2(x)
 
        x = x.view(in_size, -1)
        x = self.fc(x)
        return x
 
model = Net()
 
# construct loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
 
# training cycle forward, backward, update
 
 
def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        optimizer.zero_grad()
 
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
 
        running_loss += loss.item()
        if batch_idx % 300 == 299:
            print('[%d, %5d] loss: %.3f' % (epoch+1, batch_idx+1, running_loss/300))
            running_loss = 0.0
 
 
def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('accuracy on test set: %d %% ' % (100*correct/total))
 
 
if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()

🌈我的分享也就到此结束啦🌈
如果我的分享也能对你有帮助,那就太好了!
若有不足,还请大家多多指正,我们一起学习交流!
📢未来的富豪们:点赞👍→收藏⭐→关注🔍,如果能评论下就太惊喜了!
感谢大家的观看和支持!最后,☺祝愿大家每天有钱赚!!!欢迎关注、关注!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1703755.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Bug:Linux用户拥有r权限但无法打开文件【Linux权限体系】

Bug:Linux用户拥有r权限但无法打开文件【Linux权限体系】 0 问题描述&解决 问题描述: 通过go编写了一个程序,产生的/var/log/xx日志文件发现普通用户无权限打开 - 查看文件权限发现该文件所有者、所有者组、其他用户均有r权限 - 查看该日…

AI数学知识

AI数学知识 1、线性代数相关(矩阵)1、什么是秩2、奇异值分解3、特征值分解和奇异值分解4、低秩分解 回归分类知识点2、概率论相关1、先验概率和后验概率2、条件概率、全概率公式、贝叶斯公式、联合概率3、最大似然估计4、贝叶斯公式和最大似然估计5、伯努…

深入解读力扣154题:寻找旋转排序数组中的最小值 II(多种方法及详细ASCII图解)

❤️❤️❤️ 欢迎来到我的博客。希望您能在这里找到既有价值又有趣的内容,和我一起探索、学习和成长。欢迎评论区畅所欲言、享受知识的乐趣! 推荐:数据分析螺丝钉的首页 格物致知 终身学习 期待您的关注 导航: LeetCode解锁100…

【VTKExamples::Utilities】第十三期 SaveSceneToFile

很高兴在雪易的CSDN遇见你 VTK技术爱好者 QQ:870202403 公众号:VTK忠粉 前言 本文分享VTK样例SaveSceneToFile,希望对各位小伙伴有所帮助! 感谢各位小伙伴的点赞关注,小易会继续努力分享,一起进步…

高项案例分析知识点总结

文章目录 纠错题计算题进度估算成本管理立项管理版本管理组合管理知识产权信息技术计算题运筹学 纠错题 人:人员经验、能力、数量、缺少培训;自己一个人完成需求和计划不正确流程:先做什么,后做什么,流程是否正确。是…

【UE Slate】 虚幻引擎Slate开发快速入门

目录 0 引言1 Slate框架1.0 控件布局1.1 SWidget1.1.1 SWidget的主要作用1.1.2 SWidget的关键方法1.1.3 使用SWidget创建自定义控件1.1.4 结论 1.2 SCompoundWidget1.2.1 SCompoundWidget的主要作用1.2.2 SCompoundWidget的使用示例1.2.3 SCompoundWidget的关系1.2.4 总结 1.3 …

闲话 .NET(7):.NET Core 能淘汰 .NET FrameWork 吗?

前言 虽然说,目前 .NET FrameWork 上的大部分类都已经移植到 .NET Core 上,而且 .NET FrameWork 也已经停止了更新,未来必然是 .NET Core 的天下,但要说现在 .NET Core 就能淘汰 .NET FrameWork,我觉得为时尚早&#…

C++学习笔记(21)——继承

目录 1. 继承的概念及定义1.1 继承的概念1.2 继承定义1.2.1 定义格式1.2.2 继承关系和访问限定符1.2.3 继承基类成员访问方式的变化 继承的概念总结: 2. 基类和派生类对象赋值转换3.继承中的作用域4.派生类的默认成员函数知识点:派生类中6个默认成员函数…

利用java8 的 CompletableFuture 优化 Flink 程序,性能提升 50%

你好,我是 shengjk1,多年大厂经验,努力构建 通俗易懂的、好玩的编程语言教程。 欢迎关注!你会有如下收益: 了解大厂经验拥有和大厂相匹配的技术等 希望看什么,评论或者私信告诉我! 文章目录 一…

【吊打面试官系列】Java高并发篇 - ConcurrentHashMap 的并发度是什么?

大家好,我是锋哥。今天分享关于 【ConcurrentHashMap 的并发度是什么?】面试题,希望对大家有帮助; ConcurrentHashMap 的并发度是什么? ConcurrentHashMap 的并发度就是 segment 的大小,默认为 16, 这意味着最多同时…

.DFS.

DFS 全称为Depth First Search,中文称为深度优先搜索。 这是一种用于遍历或搜索树或图的算法,其思想是: 沿着每一条可能的路径一个节点一个节点地往下搜索, 直到路径的终点,然后再回溯,直到所有路径搜索完为止。 DFS俗…

ComfyUI简单介绍

🍓什么是ComfyUI ComfyUI是一个为Stable Diffusion专门设计的基于节点的图形用户界面,可以通过各种不同的节点快速搭建自己的绘图工作流程。 软件打开之后是长这个样子: 同时软件本身是github上的一个开源项目,开源地址为&#…

I.MX6ULL Linux C语言开发环境搭建(点灯实验)

系列文章目录 I.MX6ULL Linux C语言开发 I.MX6ULL Linux C语言开发 系列文章目录一、前言二、硬件原理分析三、构建步骤一、 C语言运行环境构建二、软件编写三、链接脚本 四、实验程序编写五、编译下载验证 一、前言 汇编语言编写 LED 灯实验,但是实际开发过程中汇…

Python实现国密GmSSL

Python实现国密GmSSL 前言开始首先安装生成公钥与私钥从用户证书中读取公钥读取公钥生成签名验证签名加密解密 遇到的大坑参考文献 前言 首先我是找得到的gmssl库,经过实操,发现公钥与密钥不能通过pem文件得到,就是缺少导入pem文件的api。这…

maven的下载以及配置的详细教程(附网盘下载地址)

文章目录 下载配置IDEA内部使用配置 下载 1.百度网盘下载 链接: https://pan.baidu.com/s/1LD9wOMFalLL49XUscU4qnQ?pwd1234 提取码: 1234 2.解压即可 配置 1.打开安装文件下conf下的settings.xml文件,我的如下 2.修改配置信息(目的是为了修改本地…

【技术分享】Maven常用配置

一、Maven简介 (一)为什么使用 Maven 由于 Java 的生态非常丰富,无论你想实现什么功能,都能找到对应的工具类,这些工具类都是以 jar 包的形式出现的,例如 Spring,SpringMVC、MyBatis、数据库驱…

MQ本地消息事务表

纯技术方案水文特此记录 MQ本地消息事务表解决了什么问题? MQ本地事务表方案解决了本地事务与消息发送的原子性问题,即:事务发起方在本地事务执行成功后消息必须发出去,否则就丢弃消息。实现本地事务和消息发送的原子性&#xf…

系统安全扫描扫出了:可能存在 CSRF 攻击怎么办

公司的H5在软件安全测试中被检查出可能存在 CSRF 攻击,网上找了一堆解决方法,最后用这种方式解决了。 1、问题描述 CSRF 是 Cross Site Request Forgery的缩写(也缩写为也就是在用户会话下对某个 CGI 做一些 GET/POST 的事,RIVTSTCNNARGO一这…

香橙派AIpro初体验,详解如何安装Home Assistant Supervised

香橙派AIpro(OrangePi AIpro)开发版,定位是一块AI开发板,搭载的是华为昇腾310(Ascend310)处理器。 没想到,这几年的发展,AI开发板也逐渐铺开,记得之前看到华为发布昇腾3…

挑战你的数据结构技能:复习题来袭【3】

chap3 练习1 一. 单选题 1. (单选题)栈和队列具有相同的() A. 抽象数据类型B. 逻辑结构C. 存储结构D. 运算 答案: B:逻辑结构 答案分析:逻辑结构都属于线性结构,只是它们对数据的运算不同。 2. (单选题)栈是() A. 顺序存储的线性结构B…