pytorch深度学习基础(十)——常用线性CNN模型的结构与训练

news2025/1/12 9:39:13

线性CNN模型的结构与训练

  • 引入包
  • LeNet
    • 模型结构
    • 模型构建
  • AlexNet
    • 模型结构
    • 模型构建
  • VGG
    • 模型结构
  • 模型构建
  • 加载数据集
  • 累加器
  • 精度
  • 训练

引入包

import torch
from torch import nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

LeNet

LeNet是一个较为简单的线性卷积模型,最早是应用于手写数字识别,提到手写数字识别就不得不提到mnist数据集,而为了与后续其他较强的模型进行对比,我们采用fashion_mnist数据集为例。

模型结构

模型结构如下
在这里插入图片描述

模型构建

因为LeNet是线性模型,所以我们直接使用torch.nn.Sequential构建模型

net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),  # (1, 28, 28)->(6, 28, 28)
                    nn.AvgPool2d(kernel_size=2, stride=2),  # (6, 28, 28)->(6, 14, 14)
                    nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),  # (6, 14, 14)->(16, 10, 10)
                    nn.AvgPool2d(kernel_size=2, stride=2),  # (16, 10, 10)->(16, 5, 5)
                    nn.Flatten(),
                    nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),  # (16 * 5 * 5, 120)
                    nn.Linear(120, 84), nn.Sigmoid(),  # (120, 84)
                    nn.Linear(84, 10)
                   )

我们先构建一个伪数据输入模型,来观察它的每一层的结构

X= torch.randn((1, 1, 28, 28))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape: \t',X.shape)

在这里插入图片描述

AlexNet

模型结构

AlexNet也是一个线性模型,相当于算是LeNet的加强版,模型结构如下
在这里插入图片描述

模型构建

net = nn.Sequential(nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=1),nn.ReLU(),  # (3, 224, 224)->(96, 54, 54)
                    nn.MaxPool2d(kernel_size=3, stride=2),  # (96, 26, 26)
                    nn.Conv2d(96, 256, kernel_size=5, padding=2), nn.ReLU(),  # (256, 26, 26)
                    nn.MaxPool2d(kernel_size=3, stride=2),  # (256, 12, 12)
                    nn.Conv2d(256, 384, kernel_size=3, padding=1), nn.ReLU(),  # (384, 12, 12)
                    nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.ReLU(),  # (384, 12, 12)
                    nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(),  # (256, 12, 12)
                    nn.MaxPool2d(kernel_size=3, stride=2),  # (256, 5, 5)
                    nn.Flatten(),  # 256 * 5 * 5
                    nn.Linear(6400, 4096), nn.ReLU(),
                    nn.Dropout(0.5),
                    nn.Linear(4096, 4096), nn.ReLU(),
                    nn.Dropout(0.5),
                    nn.Linear(4096, 10)
                   )

观察

X = torch.randn((1, 3, 224, 224))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__, '\t', X.shape)

在这里插入图片描述

VGG

模型结构

VGG虽然仍然是线性模型,但是他引入了块的结构,模型结构依然是线性,只不过构建的方式更为简便,即不用一层层的去构建,直接通过构建块来构建

模型构建

vgg块

def vgg_block(num_conv, in_channels, out_channels):
    layers = []
    for _ in range(num_conv):
        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        layers.append(nn.ReLU())
        in_channels = out_channels
        
    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*layers)

构建vgg11模型

def vgg11(in_channel):
    conv_archs = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))
    conv_blks = []
    for (num_conv, out_channel) in conv_archs:
        conv_blks.append(vgg_block(num_conv, in_channel, out_channel))
        in_channel = out_channel
    
    return nn.Sequential(*conv_blks, nn.Flatten(),
                         nn.Linear(out_channel*7*7, 4096), nn.ReLU(),
                         nn.Dropout(0.5),
                         nn.Linear(4096, 4096), nn.ReLU(),
                         nn.Linear(4096, 10)
                        )
net = vgg11(1)

加载数据集

可以直接使用torchvision.datasets中的FashionMNIST直接加载数据集,其中需要注意的是读入是图像的数据一定要使用transforms.ToTensor将数据转换成torch模型支持的数据类型,即tensor类型

def load_fashion_mnist(batch_size, resize=None):
    tran = [transforms.ToTensor()]
    if resize:
        tran.insert(0, transforms.Resize(resize))
    tran = transforms.Compose(tran)
    mnist_train = datasets.FashionMNIST(root='./data', train=True, transform=tran, download=True)
    mnist_test = datasets.FashionMNIST(root='./data', train=False, transform=tran, download=True)
    return (DataLoader(mnist_train, batch_size, shuffle=True),
            DataLoader(mnist_test, batch_size, shuffle=False)
           )

batch_size = 256
train_iter, test_iter = load_fashion_mnist(batch_size)

累加器

定义一个累加器,用来累积,在后边训练的时候用于统计精度和损失,这个是参照的李沐的动手学深度学习,刚开始的时候我也是按照这个这样写的,但是后边根据个人的使用习惯已经被替换掉了,暂时这个还是跟李沐老师的符合一下。

class Accumulator:
    def __init__(self, n):
        self.data = [0.0] * n
        
    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]
        
    def reset(self):
        self.data = [0.0] * len(self.data)
        
    def __getitem__(self, idx):
        return self.data[idx]

精度

def accuracy(y_hat, y):
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat =  torch.argmax(y_hat, axis=1)
        
    cmp = y_hat.type(y.dtype)==y
    return float(cmp.type(y.dtype).sum())

其中y_hat是预测标签,y是真实标签。
这里需要强调的是,y作为真实标签,每个标签是一个数字,而y_hat既可以是一个数字,也可以是神经网络输出的不同类别的概率,为了统一形式,当输入的y_hat不为数字时,我们使用torch.argmax获取类别概率最大的索引,即最大的值。

训练

def train(net, name, train_ter, test_iter, num_epochs, lr, device):
    def init_weights(m):
        if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
            nn.init.xavier_uniform_(m.weight)
    
    net.apply(init_weights)
    print("device in : ", device)
    net = net.to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss = nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        print("epoch {}/{}".format(epoch+1, num_epochs))
        metric = Accumulator(3)
        net.train()
        print('training...')
        for X, y in tqdm(train_iter, ncols=50, postfix="{}".format(name)):
            optimizer.zero_grad()
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            l.backward()
            optimizer.step()
            with torch.no_grad():
                metric.add(l * X.shape[0], accuracy(y_hat, y), X.shape[0])
            
            train_l = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]
        tqdm.write('train loss:{} , train acc:{}'.format(train_l, train_acc))
        metric.reset()
        net.eval()
        print('validating...')
        for X, y in (test_iter):
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            with torch.no_grad():
                l = loss(y_hat, y)
                metric.add(l * X.shape[0], accuracy(y_hat, y), X.shape[0])
            
        test_l = metric[0] / metric[2]
        test_acc = metric[1] / metric[2]
        metric.reset()
        
        tqdm.write('test loss:{} , test acc:{}'.format(test_l, test_acc))

train(net, 'LeNet', train_iter, test_iter, 30, 0.95, 'cuda')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/179232.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于蜣螂算法的极限学习机(ELM)分类算法-附代码

基于蜣螂算法的极限学习机(ELM)分类算法 文章目录基于蜣螂算法的极限学习机(ELM)分类算法1.极限学习机原理概述2.ELM学习算法3.分类问题4.基于蜣螂算法优化的ELM5.测试结果6.参考文献7.Matlab代码摘要:本文利用蜣螂算法对极限学习机进行优化,并用于分类问…

【华为上机真题】连续字母长度

🎈 作者:Linux猿 🎈 简介:CSDN博客专家🏆,华为云享专家🏆,Linux、C/C、云计算、物联网、面试、刷题、算法尽管咨询我,关注我,有问题私聊! &…

C语言--指针初阶

目录什么是指针?指针变量指针类型指针类型的意义在数组中举例野指针概念野指针成因如何规避野指针指针运算指针-整数指针关系运算指针-指针应用(求字符串长度)结语什么是指针? 在计算机科学中,指针(Pointer)是编程语言…

直接在Notepad++中运行GO

1.Windows上安装Go语言开发包参考链接:http://c.biancheng.net/view/3992.html1.1.下载Go语言开发包可以在Go语言官网 (https://golang.google.cn/dl/) 下载Windows 系统下的Go语言开发包,如下图所示。这里我们下载的是64 位的开发…

深度学习:轻量级神经网络MobileNet 从v1 到v2

深度学习:轻量级神经网络MoblieNet 从v1 到 v2MobileNet V1前言深度可分离卷积传统卷积Depth Wise ConvPoint Wise Conv性能对比MobileNet V2前言主要改进Inverted Residuals BlockResidual BlockExpansion LayerReLU6Linear Activation Function小结实验MobileNet …

大年初二、初三—— 牛客网刷题经验分享~

2023年大年初二、初三 —— 牛客网刷题经验分享~😎大年初初二、初三 —— 牛客网刷题经验分享~😎)前言🙌牛客网——基础语法【循环输出图形篇】🙌BC98 线段图案 🙌BC99 正方形图案 🙌BC100 直角三角形图案 …

计算机毕业设计选题推荐之Springboot校园篮球足球竞赛预约平台-Vue

,本系统分为用户和管理员两个角色,其中用户可以在线注册登陆,查看平台公告,查看篮球比赛介绍,在线预约参加篮球比赛。管理员可以对用户信息,比赛项目,比赛分类,平台公告信息等进行管…

Linux中如何给普通用户提权

引言: 北京时间2023/1/26/11:00 ,看到这个日期,我第一时间想到的是还有十几天就要开学啦!开学我是向往的,但是我并不怎么向往开学的考试,比如什么毛概和什么信息技术,可能是我深知自己在这些课…

实现自己的数据库一

一 前言从上篇原创文章到现在又是新的一年,今天是2023年的大年初三,先在这里祝各位亲爱的老铁们新年快乐,身体健康,在新的一年里更帅气、更漂亮,都能完成自己的小目标。一直以来,我对数据存储还是比较感兴趣…

卓有成效的用例设计方法

持续坚持原创输出,点击蓝字关注我吧用例设计作为测试工程师的立身之本,是衡量测试工程师综合素质的重要参考,时间是测试工作中重要的测试资源,通过设计高质量的测试用例可以有效地提升测试效率。本文旨在介绍测试工作中常用的五种…

恶意代码分析实战 18 64位

18.1 Lab21-01 当你不带任何参数运行程序时会发生什么? 当你运行这个程序却没带任何参数,它会立即退出。 根据你使用的IDAPro的版本,main函数可能没有被自动识别,你如何识别对main函数的调用? main函数有三个参数入…

NodeJS 中 Express 之中间件

NodeJS 中 Express 之中间件参考描述中间件next()一个简单的中间件函数使用全局中间件局部中间件共享注意事项位置next()分类错误级中间件内置中间件express.urlencoded()express.json()第三方中间件参考 项目描述哔哩哔哩黑马程序员搜索引擎Bing 描述 项目描述Edge109.0.151…

【web前端】盒子模型

border 边框 content 内容 padding内边距 margin外边距 1.边框 border 边框粗细 用px作为单位 border-style : solid 实线的 dashed虚线的 dotted 点的 边框的符合写法: 那三个没有先后顺序 边框可以分开写 表格的细线边框 border-collapse …

【编程入门】开源记事本(微信小程序版)

背景 前面已输出多个系列: 《十余种编程语言做个计算器》 《十余种编程语言写2048小游戏》 《17种编程语言10种排序算法》 《十余种编程语言写博客系统》 《十余种编程语言写云笔记》 本系列对比云笔记,将更为简化,去掉了网络调用&#xff0…

20230126英语学习

Your Dog’s Behavior Is a Product of Their Genes 狗狗做什么,基因来决定 这篇好难,字基本都认识,但它不认识我~ “Identification of the genes behind dog behavior has historically been challenging,” says first author Emily Dut…

【计算机网络(考研版)】第一站:计算机网络概述(一)

目录 一、计算机网络的概念 1.计算机网络的定义 2.计算机网络的组成 3.计算机网络的功能 4.计算机网络的分类 二、计算机网络的性能指标 1.速率 2.带宽 3.时延 4.时延带宽积 5.往返时间 6.利用率 三、计算机网络的体系结构 1.体系结构 2.协议 3.服务 4.接口&a…

活动星投票优秀支书网络评选微信的投票方式线上免费投票

“优秀支书”网络评选投票_多人投票流程顺序_视频投票图文投票_微信比赛投票小程序近些年来,第三方的微信投票制作平台如雨后春笋般络绎不绝。随着手机的互联网的发展及微信开放平台各项基于手机能力的开放,更多人选择微信投票小程序平台,因为…

最详细、最仔细、最清晰的几道python习题及答案(建议收藏哦)

名字:阿玥的小东东 学习:python。c 主页:没了 今天阿玥带大家来看看更详细的python的练习题 目录 1. 在python中, list, tuple, dict, set有什么区别, 主要应用在什么样的场景? 2. 静态函数, 类函数, 成员函数、属性函数的区别? 2.1静态…

Unix\Linux多线程复健(二)线程同步

线程同步 并非让线程并行,而是有先后的顺序执行,当有一个线程对内存操作时,其他线程不可以对这个内存地址操作 线程之间的分工合作 线程的优势之一:能够通过全局变量共享信息 临界区:访问某一共享资源的代码片段&#…

【JavaEE初阶】第六节.多线程 (基础篇 )线程安全问题(下篇)

前言 一、内存可见性 二、内存可见性的解决办法 —— volatile关键字 三、wait 和notify 关键字 3.1 wait() 方法 3.2 notify() 方法 3.3 notify All() 方法 3.4 wait 和 sleep 的对比 总结 前言 本节内容接上小节有关线程安全问题;本节内容我们将介绍有关…