ccc-pytorch-卷积神经网络实战(6)

news2025/1/21 18:06:16

文章目录

      • 一、CIFAR10 与 lenet5
      • 二、CIFAR10 与 ResNet

一、CIFAR10 与 lenet5

image-20230305193723321
第一步:准备数据集
lenet5.py

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

def main():
    batchsz = 128

    CIFAR_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_train = DataLoader(CIFAR_train, batch_size=batchsz, shuffle=True)

    CIFAR_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_test = DataLoader(CIFAR_test, batch_size=batchsz, shuffle=True)

    x,label = iter(cifar_train).next()
    print('x',x.shape,'label:',label.shape)

if __name__ =='__main__':
    main()

image-20230306214402277

第二步:确认Lenet5网络流程结构
main.py

import torch
from torch import nn
from torch.nn import functional as F

class Lenet5(nn.Module):
    def __init__(self):
        super(Lenet5, self).__init__()

        self.conv_unit = nn.Sequential(
            # x: [b, 3, 32, 32] => [b, 6, ]
            nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            #
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        )
        self.fc_unit = nn.Sequential(
            nn.Linear(2,120), # 由输出结果反推(拉直打平)
            nn.ReLU(),
            nn.Linear(120,84),
            nn.ReLU(),
            nn.Linear(84,10)
        )
        #[b,3,32,32]
        tmp = torch.randn(2, 3, 32, 32)
        out = self.conv_unit(tmp)
        #[2,16,5,5]   由输出结果得到
        print('conv out:', out.shape)


def main():
    net = Lenet5()

if __name__ == '__main__':
    main()


第三步:完善lenet5 结构并使用GPU加速
lenet5.py

import torch
from torch import nn
from torch.nn import functional as F

class Lenet5(nn.Module):
    def __init__(self):
        super(Lenet5, self).__init__()

        self.conv_unit = nn.Sequential(
            # x: [b, 3, 32, 32] => [b, 6, ]
            nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            #
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        )
        self.fc_unit = nn.Sequential(
            nn.Linear(16*5*5,120),
            nn.ReLU(),
            nn.Linear(120,84),
            nn.ReLU(),
            nn.Linear(84,10)
        )
        #[b,3,32,32]
        tmp = torch.randn(2, 3, 32, 32)
        out = self.conv_unit(tmp)
        #[b,16,5,5]
        print('conv out:', out.shape)

    def forward(self,x):
        batchsz = x.size(0)
        # [b, 3, 32, 32] => [b, 16, 5, 5]
        x = self.conv_unit(x)
        #[b, 16, 5, 5] => [b,16*5*5]
        x = x.view(batchsz,16*5*5)
        # [b, 16*5*5] => [b, 10]
        logits = self.fc_unit(x)
        pred = F.softmax(logits,dim=1)
        return logits

def main():
    net = Lenet5()
    tmp = torch.randn(2, 3, 32, 32)
    out = net(tmp)
    print('lenet out:', out.shape)

if __name__ == '__main__':
    main()

main.py

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from lenet5 import Lenet5
from    torch import nn, optim

def main():
    batchsz = 128

    CIFAR_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)
    cifar_train = DataLoader(CIFAR_train, batch_size=batchsz, shuffle=True)

    CIFAR_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)
    cifar_test = DataLoader(CIFAR_test, batch_size=batchsz, shuffle=True)

    x,label = iter(cifar_train).next()
    print('x',x.shape,'label:',label.shape)

    device = torch.device('cuda')
    model = Lenet5().to(device)

    print(model)

if __name__ =='__main__':
    main()

image-20230307210738450
第四步:计算交叉熵和准确率,完成迭代

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from lenet5 import Lenet5
from    torch import nn, optim

def main():
    batchsz = 128

    CIFAR_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)
    cifar_train = DataLoader(CIFAR_train, batch_size=batchsz, shuffle=True)

    CIFAR_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)
    cifar_test = DataLoader(CIFAR_test, batch_size=batchsz, shuffle=True)

    x,label = iter(cifar_train).next()
    print('x',x.shape,'label:',label.shape)

    device = torch.device('cuda')
    model = Lenet5().to(device)

    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(),lr=1e-3)
    print(model)

    for epoch in range(1000):

        for batchidx, (x,label) in enumerate(cifar_train):
            # [b, 3, 32, 32]
            # [b]
            x,label = x.to(device),label.to(device)
            logits = model(x)
            # logits: [b, 10]
            # label:  [b]
            loss = criteon(logits,label)
            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch,'loss:',loss.item())

        model.eval()
        with torch.no_grad(): #之后代码不需backprop

            total_correct = 0
            total_num = 0
            for x ,label in cifar_test:
                # [b, 3, 32, 32]
                # [b]
                x,label = x.to(device),label.to(device)
                logits = model(x)
                pred = logits.argmax(dim=1)
                total_correct += torch.eq(pred,label).float().sum()
                total_num += x.size(0)
            acc = total_correct / total_num
            print(epoch,acc)

if __name__ =='__main__':
    main()

image-20230307212056887
注意事项:

  • 之所以在 测试时 添加 model.eval()是因为eval()时,BN会使用之前计算好的值,并且停止使用DropOut。保证用全部训练的均值和方差

二、CIFAR10 与 ResNet

img
第一步:构建ResNet18的网络结构
ResNet.py

import torch
from torch import  nn
from torch.nn import functional as F

class ResBlk(nn.Module):

    def __init__(self,ch_in,ch_out,stride=1):

        super(ResBlk,self).__init__()
        self.conv1 = nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=stride,padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()
        if ch_out != ch_in:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=stride),
                nn.BatchNorm2d(ch_out)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        #[b, ch_in, h, w] = > [b, ch_out, h, w]
        out = self.extra(x) + out
        out = F.relu((out))
        return out

class ResNet18(nn.Module):
    def __init__(self):
        super(ResNet18, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3,64,kernel_size=3,stride=3,padding=0),
            nn.BatchNorm2d(64)
        )
        # followed 4 blocks
        # [b, 64, h, w] => [b, 128, h ,w]
        self.blk1 = ResBlk(64,128)
        # [b, 128, h, w] => [b, 256, h ,w]
        self.blk2 = ResBlk(128,256)
        # [b, 256, h, w] => [b, 512, h ,w]
        self.blk3 = ResBlk(256,512)
        # [b, 512, h, w] => [b, 1024, h ,w]
        self.blk4 = ResBlk(512,512)

        self.outlayer = nn.Linear(512*1*1,10)

    def forward(self,x):
        x = F.relu(self.conv1(x))

        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)
        print('after conv:', x.shape)
        # [b, 512, h, w] => [b, 512, 1, 1]
        x = F.adaptive_avg_pool2d(x, [1, 1])
        print('after pool:', x.shape)
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)

        return x

def main():
    blk = ResBlk(64,128,stride=2)
    tmp = torch.randn(2,64,32,32)
    out = blk(tmp)
    print('block:',out.shape)
    x = torch.randn(2,3,32,32)
    model  = ResNet18()
    out = model(x)
    print('resnet:',out.shape)

if __name__ == '__main__':
    main()

第二步:代入第一个项目的main函数中即可
main.py

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from resnet import ResNet18
from    torch import nn, optim

def main():
    batchsz = 128

    CIFAR_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)
    cifar_train = DataLoader(CIFAR_train, batch_size=batchsz, shuffle=True)

    CIFAR_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)
    cifar_test = DataLoader(CIFAR_test, batch_size=batchsz, shuffle=True)

    x,label = iter(cifar_train).next()
    print('x',x.shape,'label:',label.shape)

    device = torch.device('cuda')
    model = ResNet18().to(device)


    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(),lr=1e-3)
    print(model)

    for epoch in range(1000):

        for batchidx, (x,label) in enumerate(cifar_train):
            # [b, 3, 32, 32]
            # [b]
            x,label = x.to(device),label.to(device)
            logits = model(x)
            # logits: [b, 10]
            # label:  [b]
            loss = criteon(logits,label)
            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch,'loss:',loss.item())

        model.eval()
        with torch.no_grad(): #之后代码不需backprop

            total_correct = 0
            total_num = 0
            for x ,label in cifar_test:
                # [b, 3, 32, 32]
                # [b]
                x,label = x.to(device),label.to(device)
                logits = model(x)
                pred = logits.argmax(dim=1)
                total_correct += torch.eq(pred,label).float().sum()
                total_num += x.size(0)
            acc = total_correct / total_num
            print(epoch,acc)

if __name__ =='__main__':
    main()

网络结构如下:
image-20230308192349803
迭代准确率和交叉熵计算如下:
image-20230308193023625
其他需要注意的地方:

  • 并不是ResNet的paper中流程完全相同,但是十分类似
  • 可以对数据进行数据增强和归一化等操作进一步提升效果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/397445.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于嵌入式libxml2的ARM64平台的移植(aarch64)

由于libxml在移植过程中依赖于zlib的库文件,因此本节内容包含zlib(V1.2.13)的移植libxml2(V2.10.3)的移植两部分组成。 (一)zlib的移植(基于arm64) 1、在github上下载zlib的最新源码压缩包&am…

【C++的OpenCV】第十课-OpenCV图像常用操作(七):直方图和直方图同等化(直方图均衡化)

🎉🎉🎉欢迎各位来到小白piao的学习空间!\color{red}{欢迎各位来到小白piao的学习空间!}欢迎各位来到小白piao的学习空间!🎉🎉🎉 💖💖&#x1f496…

看完书上的链表还不会实现?不进来看看?

1.1链表的概念定义:链表是一种物理存储上非连续,数据元素的逻辑顺序通过链表中的指针链接次序,实现的一种线性存储结构。特点:链表由一系列节点组成,节点在运行时动态生成 (malloc),…

【react】类组件

React类创建组件&#xff0c;通过继承React内置的Component来实现的 class MyComponent extends React.Component{render() {console.log(this)// render是放在哪里的 —— 类(即&#xff1a;MyComponent)的原型对象上&#xff0c;供实例使用return <h2>我是用函数定义的…

python实现波士顿房价预测

波士顿房价预测 目标 这是一个经典的机器学习回归场景&#xff0c;我们利用Python和numpy来实现神经网络。该数据集统计了房价受到13个特征因素的影响&#xff0c;如图1所示。 对于预测问题&#xff0c;可以根据预测输出的类型是连续的实数值&#xff0c;还是离散值&#xff…

QGraphicsItem的简单自定义图形项

QGraphicsItem的继承重写序言重点函数QRectF boundingRect() constQPainterPath shape() constvoid paint(QPainter *painter, const QStyleOptionGraphicsItem *option, QWidget *widget 0)序言 学习途中记录一下&#xff0c;可谓是精华点 重点函数 QRectF boundingRect()…

农产品销售系统/商城,可运行

文章目录项目介绍一、项目功能介绍1、用户模块主要功能包括&#xff1a;2、商家模块主要功能包括&#xff1a;3、管理员模块主要功能包括&#xff1a;二、部分页面展示1、用户模块部分功能页面展示2、商家模块部分功能页面展示3、管理员模块部分功能页面展示三、部分源码四、底…

蓝牙 - 设备类型设置: Class of Device

在电脑或手机上&#xff0c;搜寻和连接蓝牙设备时&#xff0c;不同的蓝牙设备显示的图标是不同的&#xff0c;比如搜到或连接上的设备是一个蓝牙键盘&#xff0c;显示的就会是键盘图标&#xff0c;如果搜索到的设备是一个手柄&#xff0c;显示的就是一个手柄图标。 显示的图标是…

进程(操作系统408)

进程的概念和特征 概念&#xff1a; 进程的多个定义&#xff1a; 进程是程序的一次执行过程 进程是一个程序及其数据在处理机上顺序执行时所发生的活动 进程时具有独立功能的程序在一个数据集合上运行的过程&#xff0c;它是系统进行资源分配和调度的一个独立单位 上面所说…

JVM的基本知识

JVM JVM是java的虚拟机,是一个十分复杂的东西,所以掌握的要求比较高.本文主要是研究JVM的三大话题 JVM内存划分JVM类加载JVM的垃圾回收 JVM内存划分 java程序要执行的时候,JVM会先申请一块空间,这里就涉及到JVM的内存划分 堆 : 放的是new 出来的对象栈: 放的是方法之间的调…

rabbitmq集群-镜像模式

上文参考&#xff1a; rabbitmq集群-普通模式 1. 什么是镜像模式 它和普通集群最大的区别在于 Queue 数据和原数据不再是单独存储在一台机器上&#xff0c;而是同时存储在多台机器上。也就是说每个 RabbitMQ 实例都有一份镜像数据&#xff08;副本数据&#xff09;。每次写入…

3月8号作业

题目&#xff1a;题目一&#xff1a;vmlinux可执行文件如何产生题目二&#xff1a;整理内核编译流程&#xff1a;uImage&#xff0c;zImage,Image,vmlinux之间的关系答案一&#xff1a;在内核源码目录下vi Makefile&#xff0c;搜索vmlinux目标&#xff0c;vmlinux: scripts/li…

MongoDB学习(java版)

MongoDB概述 结构化数据库 ​ 结构化数据库是一种使用结构化查询语言&#xff08;SQL&#xff09;进行管理和操作的数据库&#xff0c;它们的数据存储方式是基于表格和列的。结构化数据库要求数据预先定义数据模式和结构&#xff0c;然后才能存储和查询数据。结构化数据库通常…

Android Camera SDK NDK NDK_vendor介绍

Android Camera JNI NDK NDK_vendor介绍前言主要有哪几种interface&#xff1f;Android SDKCamera API 1Camera API 2小结Android NDKNDK InterfaceNDK Vendor Interface小结Camera VTS Testcase总结Reference前言 本篇博客是想介绍Android camera从application layer到camera…

谷歌插件Fetch在不同页面之间Cookie携带情况详解

content script 和 script inject 表现情况 在碰到content script 注入和用script标签注入一样&#xff0c;即使服务端有写入Cookie到域名下在该tab标签应用下也不会被保存&#xff0c;所以在发送时也无法自动携带&#xff0c;所以通过content script和<script>这种方式…

微信小程序第二节 —— 自定义组件。

&#x1f449;微信小程序第一节 —— 自定义顶部、底部导航栏以及获取胶囊体位置信息。 一、前言 &#x1f4d6;&#x1f4d6;&#x1f4d6;书接上回 &#xff0c;dai ga hou啊&#xff01;我是 &#x1f618;&#x1f618;&#x1f618;是江迪呀。在进行微信小程序开发中&am…

多维数组的地址,通过指针引用多维数组详解

通过指针引用一维数组可以参考这篇文章&#xff1a; 通过指针引用数组的几种方法的原理和差异&#xff1b;以及利用指针引用数组元素的技巧_juechen333的博客-CSDN博客一个数组包含若干元素&#xff0c;每个数组元素都占用存储单元&#xff0c;所以他们都有相应的地址&#xf…

《Ansible模块篇:debug模块详解》

一、简介 平时我们在使用ansible执行playbook时&#xff0c;经常可能会遇到一些错误&#xff0c;有的时候不知道问题在哪里 &#xff0c;这个时候可以使用-vvv参数打印出来详细信息&#xff0c;不过很多时候-vvv参数里很多东西并不是我们想要的&#xff0c;这时候就可以使用官方…

python第四天作业~函数练习

目录 作业4、判断以下哪些不能作为标识符 A、a B、&#xffe5;a C、_12 D、$a12 E、false F、False 作业5&#xff1a; 输入数&#xff0c;判断这个数是否是质数&#xff08;要求使用函数 for循环&#xff09; 作业6&#xff1a;求50~150之间的质数是…

ReentrantLock 源码解读

一、ReentrantLock ReentrantLock 是 java JUC 中的一个可重入锁&#xff0c;在上篇文章讲解 AQS 源码的时候提到 ReentrantLock 锁是基于 AQS 实现的&#xff0c;那是如何使用的 AQS 呢&#xff0c;本篇文章一起带大家看下 ReentrantLock 的源码。 在 AQS 中&#xff0c;如果…