卷积神经网络(六)---实现 cifar10 分类

news2024/11/28 4:26:31

        cifar10 数据集有60000张图片,每张图片的大小都是 32x32 的三通道的彩色图,一共是10种类别、每种类别有6000张图片,如图4.27所示。

图 4.27  cifar数据集

        使用前面讲过的残差结构来处理 cifar10 数据集,可以实现比较高的准确率。

        首先进行图像增强,使用前面介绍的增强方式。

train_transform = transforms.Compose([
    transforms.Scale(40),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

        注意只对训练图片进行图像增强,提高其泛化能力,对于测试集,仅对其中心化,不做其他的图像增强。

        下面先定义好 resnet 的基本模块。

def conv3x3(in_channels, out_channels, stride=1):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False
    )


# Residual Block
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.bn2(out)
        if self.downsample:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

        和前面介绍的内容一样,先定义残差模块,再将残差模块拼接起来,注意其中的维度变化。

class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 16
        self.conv = conv3x3(3, 16)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self.make_layer(block, 16, layers[0])
        self.layer2 = self.make_layer(block, 32, layers[0], 2)
        self.layer3 = self.make_layer(block, 64, layers[0], 2)
        self.avg_pool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64, num_classes)

    def make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if (stride != 1) or (self.in_channels != out_channels):
            downsample = nn.Sequential(
                conv3x3(self.in_channels, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels)
            )
        layers = [block(self.in_channels, out_channels, stride, downsample)]
        self.in_channels = out_channels
        for i in range(1, blocks):
            layers.append(block(out_channels, out_channels))
        return nn.Sequential(*layers)

    def foward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

        最后在 cifar10 的数据集上跑100个 epoch,实现66.61%的训练集准确率,68%的验证集准确率,因为这里只跑了100次,所以还有一定的提升空间。同时使用更深的残差和更多的训练技巧能实现更好的实验结果,如图4.28所示。

        因为这里我是按照自己的想法写的普通版本的 cifar10 分类识别,所以准确率最后并不是很高,如果有人读懂了上面的方法,可以进行试一试。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1961673.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

配置本地开发服务器代理请求以及登录模块开发(二)

项目初始化完成之后,准备开始进行项目的开发,首先配置好开发环境作为整个项目的基础 一、配置代理 1、config/proxy.ts配置代理 export default {// 如果需要自定义本地开发服务器 请取消注释按需调整dev: {// localhost:8000/api/** -> https://p…

Seata 入门与实战

一、什么是 Seata Seata 是一款开源的分布式事务解决方式,致力于提供高性能和简单易用的分布式事务服务。Seata 为用户提供了 AT、TCC、SAGA 和 XA 事务模式,为用户打造一站式的分布式事务解决方案。 二、Seata 组成 事务协调者(Transacti…

什么是Shell?怎么编写和执行Shell脚本?

大家好呀!今天来简单介绍一下Shell基础,Shell介于内核与用户之间,是一个命令解释器,负责命令的解释。简单理解,Shell既是一个程序也是一种脚本语言。 1、shell介绍 1.1 概述 shell介于内核与用户之间,是一个…

索引结构—B+Tree索引、Hash索引、Full-Text(全文)索引、R-Tree(空间)索引

一、概述 在数据库系统中,索引是一种用于加快数据检索的数据结构。不同的索引结构适用于不同的查询场景和数据特性。索引按照不同角度可以划分不同类型的索引。按照数据结构可以划分BTree索引、Hash索引、FULL TEXT(全文)索引、R-Tree&#…

python inf是什么意思

INF / inf:这个值表示“无穷大 (infinity 的缩写)”,即超出了计算机可以表示的浮点数的范围(或者说超过了 double 类型的值)。例如,当用 0 除一个整数时便会得到一个1.#INF / inf值;相应的,如果…

卡码网KamaCoder 103. 水流问题

题目来源&#xff1a;103. 水流问题 C题解&#xff1a;从边界往高处走&#xff0c;走过的地方做标记。第一组边界跟第二组边界能走到的地方取交集。 代码来源代码随想录。&#xff08;虽然思路一样&#xff0c;但人家代码写得比我好哇&#xff09; #include <iostream>…

pyinstaller带浏览器一起打包playwright 独立运行exe

前置条件 没有安装自带环境&#xff0c;则 playwright install 安装了自带的浏览器 查看playwright的浏览器的位置 playwright install --dry-run 打开此文件夹可以看到 新建一个多层级目录playwright\driver\package.local-browsers 然后复制chromium-1124到playwright\dr…

听说它可以让代码更优雅

一提到静态代码检查工具这个词应该比较好理解&#xff0c;所谓静态代码检查工具就是检查静态代码的工具&#xff0c;完美~ 言归正传&#xff0c;相信很多程序员朋友都听说过静态代码检查工具这个概念&#xff0c;它可能是我们IDE里的某一个插件&#xff0c;可能是计算机中的一…

比 faster-whisper 至少快10倍的音视频转换文字

背景介绍 前两天我自己玩玩搞搞一个音频转文字服务&#xff0c;基于 faster-whisper&#xff0c;本想着这个已经是很快的了&#xff0c;没想到还有比它更快的&#xff0c;今天就来介绍使用一下。 FunClip&#xff0c;是阿里巴巴推出的一个智能视频剪辑工具&#xff0c;它结合…

计算机毕业设计选题推荐-某炼油厂盲板管理系统-Java/Python项目实战

✨作者主页&#xff1a;IT研究室✨ 个人简介&#xff1a;曾从事计算机专业培训教学&#xff0c;擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Python…

[Bugku] web-CTF靶场详解!!!

平台为“山东安信安全技术有限公司”自研CTF/AWD一体化平台&#xff0c;部分赛题采用动态FLAG形式&#xff0c;避免直接抄袭答案。 平台有题库、赛事预告、工具库、Writeup库等模块。 ------------------------------- Simple_SSTI_1 启动环境&#xff1a; 页面提示传入参数f…

【Qt】QLCDNumberQProgressBarQCalendarWidget

目录 QLCDNumber 倒计时小程序 相关属性 QProgressBar 进度条小程序 相关设置 QLCDNumber QLCDNumber是Qt框架中用于显示数字或计数值的小部件。通常用于显示整数值&#xff0c;例如时钟、计时器、计数器等 常用属性 属性说明intValueQLCDNumber显示的初始值(int类型)va…

【全面介绍下Gitea,什么是Gitea?】

&#x1f308;个人主页: 程序员不想敲代码啊 &#x1f3c6;CSDN优质创作者&#xff0c;CSDN实力新星&#xff0c;CSDN博客专家 &#x1f44d;点赞⭐评论⭐收藏 &#x1f91d;希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出指正&#xff0c;让我们共…

这几个高级爬虫软件和插件真的强!

亮数据&#xff08;Bright Data&#xff09; 亮数据是一款强大的数据采集工具&#xff0c;以其全球代理IP网络和强大数据采集技术而闻名。它能够轻松采集各种网页数据&#xff0c;包括产品信息、价格、评论和社交媒体数据等。 网站&#xff1a;https://get.brightdata.com/we…

ubuntu安装并配置flameshot截图软件

参考&#xff1a;flameshot key-bindins 安装 sudo apt install flameshot自定义快捷键 Settings->Keyboard->View and Customize Shortcuts->Custom Shortcuts&#xff0c;输入该快捷键名称&#xff08;自定义&#xff09;&#xff0c;然后输入command&#xff08;…

RFID物流智能锁在物流锁控领域的意义与应用

在当今全球化和电子商务迅速发展的时代&#xff0c;物流行业作为经济的重要支撑&#xff0c;面临着日益增长的安全、效率和管理需求。物流锁控作为保障货物在运输过程中安全与完整的关键环节&#xff0c;传统的机械锁和简单电子锁已经难以满足现代物流复杂多变的业务场景。 一、…

前缀表达式(波兰式)和后缀表达式(逆波兰式)的计算方式

缀是指操作符。 1. 前缀表达式&#xff08;波兰式&#xff09; &#xff08;1&#xff09;不需用括号&#xff1b; &#xff08;2&#xff09;不用考虑运算符的优先级&#xff1b; &#xff08;3&#xff09;操作符置于操作数的前面。&#xff08;如 3 2 &#xff09; 1.1 中…

3.5.3、查找和排序算法-插入类排序和选择类排序

术语说明 稳定&#xff1a;如果a原本在b前面&#xff0c;而ab,排序之后a仍然在b的前面&#xff1b; 不稳定&#xff1a;如果a原本在b的前面&#xff0c;而ab,排序之后a可能会出现在b的后面&#xff1b; 例如&#xff1a;数组{1,2,3,3,4,7,6}。如果排序后&#xff0c;两个3的位…

【嵌入式之RTOS】死锁问题详解

目录 一、什么是死锁 二、产生死锁的四个必要条件 三、避免死锁的方法 四、实际应用中的考虑 一、什么是死锁 死锁&#xff08;Deadlock&#xff09;是多任务或多线程环境中一个常见的问题&#xff0c;尤其是在实时操作系统&#xff08;RTOS&#xff09;中&#xff0c;如果…

kvm虚拟化平台部署

kvm虚拟化平台部署 kvm概念简介 kvm自linux2.6版本以后就整合到内核中&#xff0c;因此可以看做是一个原生架构. kvm虚拟化架构 硬件底层提供物理层面的硬件支持 linux&#xff08;host&#xff09;&#xff0c;就相当于这个架构中的宿主机&#xff0c;上面运行了多个虚拟机。…