chap5 CNN

news2025/1/20 17:10:29

卷积神经网络(CNN)

问题描述:

利用卷积神经网络,实现对MNIST数据集的分类问题

数据集:

MNIST数据集包括60000张训练图片和10000张测试图片。图片样本的数量已经足够训练一个很复杂的模型(例如 CNN的深层神经网络)。它经常被用来作为一个新 的模式识别模型的测试用例。而且它也是一个方便学生和研究者们执行用例的数据集。除此之外,MNIST数据集是一个相对较小的数据集,可以在你的笔记本CPUs上面直接执行

题目要求

Pytorch版本的卷积神经网络需要补齐self.conv1中的nn.Conv2d()self.conv2()的参数,还需要填写x=x.view()中的内容。
训练精度应该在96%以上。

import os
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision
import torch.nn.functional as F
import numpy as np
learning_rate = 1e-4
keep_prob_rate = 0.7 #
max_epoch = 3
BATCH_SIZE = 50

DOWNLOAD_MNIST = False
if not(os.path.exists('./mnist/')) or not os.listdir('./mnist/'):
    # not mnist dir or mnist is empyt dir
    DOWNLOAD_MNIST = True


train_data = torchvision.datasets.MNIST(root='./mnist/',train=True, transform=torchvision.transforms.ToTensor(), download=DOWNLOAD_MNIST,)
train_loader = Data.DataLoader(dataset = train_data ,batch_size= BATCH_SIZE ,shuffle= True)

test_data = torchvision.datasets.MNIST(root = './mnist/',train = False)
test_x = Variable(torch.unsqueeze(test_data.test_data,dim  = 1),volatile = True).type(torch.FloatTensor)[:500]/255.
test_y = test_data.test_labels[:500].numpy()

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d( # ???
                # patch 7 * 7 ; 1  in channels ; 32 out channels ; ; stride is 1
                # padding style is same(that means the convolution opration's input and output have the same size)
                in_channels=1,
                out_channels=32,
                kernel_size=7,
                stride=1,
                padding=3,
            ),
            nn.ReLU(),        # activation function
            nn.MaxPool2d(2),  # pooling operation
        )
        self.conv2 = nn.Sequential( # ???
            # line 1 : convolution function, patch 5*5 , 32 in channels ;64 out channels; padding style is same; stride is 1
            # line 2 : choosing your activation funciont
            # line 3 : pooling operation function.
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2, stride=1),
            nn.ReLU(),
            nn.AvgPool2d(2),
        )
        self.out1 = nn.Linear( 7*7*64 , 1024 , bias= True)   # full connection layer one

        self.dropout = nn.Dropout(keep_prob_rate)
        self.out2 = nn.Linear(1024,10,bias=True)



    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(-1, 7*7*64)  # flatten the output of coonv2 to (batch_size ,32 * 7 * 7)    # ???
        out1 = self.out1(x)
        out1 = F.relu(out1)
        out1 = self.dropout(out1)
        out2 = self.out2(out1)
        output = F.softmax(out2)
        return output


def test(cnn):
    global prediction
    y_pre = cnn(test_x)
    _,pre_index= torch.max(y_pre,1)
    pre_index= pre_index.view(-1)
    prediction = pre_index.data.numpy()
    correct  = np.sum(prediction == test_y)
    return correct / 500.0


def train(cnn):
    optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate )
    loss_func = nn.CrossEntropyLoss()
    for epoch in range(max_epoch):
        for step, (x_, y_) in enumerate(train_loader):
            x ,y= Variable(x_),Variable(y_)
            output = cnn(x)  
            loss = loss_func(output,y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if step != 0 and step % 20 ==0:
                print("=" * 10,step,"="*5,"="*5, "test accuracy is ",test(cnn) ,"=" * 10 )

if __name__ == '__main__':
    cnn = CNN()
    train(cnn)

训练结果为:

========== 20 ===== ===== test accuracy is  0.224 ==========
========== 40 ===== ===== test accuracy is  0.362 ==========
========== 60 ===== ===== test accuracy is  0.402 ==========
========== 80 ===== ===== test accuracy is  0.51 ==========
========== 100 ===== ===== test accuracy is  0.608 ==========
========== 120 ===== ===== test accuracy is  0.624 ==========
========== 140 ===== ===== test accuracy is  0.708 ==========
========== 160 ===== ===== test accuracy is  0.684 ==========
========== 180 ===== ===== test accuracy is  0.738 ==========
========== 200 ===== ===== test accuracy is  0.766 ==========
========== 220 ===== ===== test accuracy is  0.778 ==========
========== 240 ===== ===== test accuracy is  0.796 ==========
========== 260 ===== ===== test accuracy is  0.802 ==========
========== 280 ===== ===== test accuracy is  0.81 ==========
========== 300 ===== ===== test accuracy is  0.812 ==========
========== 320 ===== ===== test accuracy is  0.82 ==========
========== 340 ===== ===== test accuracy is  0.848 ==========
========== 360 ===== ===== test accuracy is  0.83 ==========
========== 380 ===== ===== test accuracy is  0.852 ==========
========== 400 ===== ===== test accuracy is  0.852 ==========
========== 420 ===== ===== test accuracy is  0.856 ==========
========== 440 ===== ===== test accuracy is  0.874 ==========
========== 460 ===== ===== test accuracy is  0.85 ==========
========== 480 ===== ===== test accuracy is  0.874 ==========
========== 500 ===== ===== test accuracy is  0.864 ==========
========== 520 ===== ===== test accuracy is  0.858 ==========
========== 540 ===== ===== test accuracy is  0.884 ==========
========== 560 ===== ===== test accuracy is  0.872 ==========
========== 580 ===== ===== test accuracy is  0.9 ==========
========== 600 ===== ===== test accuracy is  0.88 ==========
========== 620 ===== ===== test accuracy is  0.886 ==========
========== 640 ===== ===== test accuracy is  0.882 ==========
========== 660 ===== ===== test accuracy is  0.886 ==========
========== 680 ===== ===== test accuracy is  0.876 ==========
========== 700 ===== ===== test accuracy is  0.882 ==========
========== 720 ===== ===== test accuracy is  0.886 ==========
========== 740 ===== ===== test accuracy is  0.894 ==========
========== 760 ===== ===== test accuracy is  0.894 ==========
========== 780 ===== ===== test accuracy is  0.9 ==========
========== 800 ===== ===== test accuracy is  0.898 ==========
========== 820 ===== ===== test accuracy is  0.912 ==========
========== 840 ===== ===== test accuracy is  0.894 ==========
========== 860 ===== ===== test accuracy is  0.898 ==========
========== 880 ===== ===== test accuracy is  0.888 ==========
========== 900 ===== ===== test accuracy is  0.896 ==========
========== 920 ===== ===== test accuracy is  0.888 ==========
========== 940 ===== ===== test accuracy is  0.91 ==========
========== 960 ===== ===== test accuracy is  0.908 ==========
========== 980 ===== ===== test accuracy is  0.918 ==========
========== 1000 ===== ===== test accuracy is  0.906 ==========
========== 1020 ===== ===== test accuracy is  0.908 ==========
========== 1040 ===== ===== test accuracy is  0.906 ==========
========== 1060 ===== ===== test accuracy is  0.914 ==========
========== 1080 ===== ===== test accuracy is  0.908 ==========
========== 1100 ===== ===== test accuracy is  0.906 ==========
========== 1120 ===== ===== test accuracy is  0.906 ==========
========== 1140 ===== ===== test accuracy is  0.924 ==========
========== 1160 ===== ===== test accuracy is  0.918 ==========
========== 1180 ===== ===== test accuracy is  0.904 ==========
========== 20 ===== ===== test accuracy is  0.924 ==========
========== 40 ===== ===== test accuracy is  0.908 ==========
========== 60 ===== ===== test accuracy is  0.92 ==========
========== 80 ===== ===== test accuracy is  0.91 ==========
========== 100 ===== ===== test accuracy is  0.926 ==========
========== 120 ===== ===== test accuracy is  0.91 ==========
========== 140 ===== ===== test accuracy is  0.922 ==========
========== 160 ===== ===== test accuracy is  0.932 ==========
========== 180 ===== ===== test accuracy is  0.932 ==========
========== 200 ===== ===== test accuracy is  0.93 ==========
========== 220 ===== ===== test accuracy is  0.94 ==========
========== 240 ===== ===== test accuracy is  0.918 ==========
========== 260 ===== ===== test accuracy is  0.934 ==========
========== 280 ===== ===== test accuracy is  0.93 ==========
========== 300 ===== ===== test accuracy is  0.934 ==========
========== 320 ===== ===== test accuracy is  0.934 ==========
========== 340 ===== ===== test accuracy is  0.93 ==========
========== 360 ===== ===== test accuracy is  0.944 ==========
========== 380 ===== ===== test accuracy is  0.938 ==========
========== 400 ===== ===== test accuracy is  0.92 ==========
========== 420 ===== ===== test accuracy is  0.936 ==========
========== 440 ===== ===== test accuracy is  0.948 ==========
========== 460 ===== ===== test accuracy is  0.934 ==========
========== 480 ===== ===== test accuracy is  0.938 ==========
========== 500 ===== ===== test accuracy is  0.916 ==========
========== 520 ===== ===== test accuracy is  0.916 ==========
========== 540 ===== ===== test accuracy is  0.928 ==========
========== 560 ===== ===== test accuracy is  0.936 ==========
========== 580 ===== ===== test accuracy is  0.942 ==========
========== 600 ===== ===== test accuracy is  0.922 ==========
========== 620 ===== ===== test accuracy is  0.94 ==========
========== 640 ===== ===== test accuracy is  0.94 ==========
========== 660 ===== ===== test accuracy is  0.96 ==========
========== 680 ===== ===== test accuracy is  0.938 ==========
========== 700 ===== ===== test accuracy is  0.936 ==========
========== 720 ===== ===== test accuracy is  0.94 ==========
========== 740 ===== ===== test accuracy is  0.946 ==========
========== 760 ===== ===== test accuracy is  0.946 ==========
========== 780 ===== ===== test accuracy is  0.948 ==========
========== 800 ===== ===== test accuracy is  0.95 ==========
========== 820 ===== ===== test accuracy is  0.948 ==========
========== 840 ===== ===== test accuracy is  0.95 ==========
========== 860 ===== ===== test accuracy is  0.94 ==========
========== 880 ===== ===== test accuracy is  0.956 ==========
========== 900 ===== ===== test accuracy is  0.944 ==========
========== 920 ===== ===== test accuracy is  0.948 ==========
========== 940 ===== ===== test accuracy is  0.95 ==========
========== 960 ===== ===== test accuracy is  0.944 ==========
========== 980 ===== ===== test accuracy is  0.94 ==========
========== 1000 ===== ===== test accuracy is  0.946 ==========
========== 1020 ===== ===== test accuracy is  0.952 ==========
========== 1040 ===== ===== test accuracy is  0.952 ==========
========== 1060 ===== ===== test accuracy is  0.944 ==========
========== 1080 ===== ===== test accuracy is  0.956 ==========
========== 1100 ===== ===== test accuracy is  0.96 ==========
========== 1120 ===== ===== test accuracy is  0.948 ==========
========== 1140 ===== ===== test accuracy is  0.942 ==========
========== 1160 ===== ===== test accuracy is  0.948 ==========
========== 1180 ===== ===== test accuracy is  0.944 ==========
========== 20 ===== ===== test accuracy is  0.952 ==========
========== 40 ===== ===== test accuracy is  0.96 ==========
========== 60 ===== ===== test accuracy is  0.948 ==========
========== 80 ===== ===== test accuracy is  0.954 ==========
========== 100 ===== ===== test accuracy is  0.948 ==========
========== 120 ===== ===== test accuracy is  0.948 ==========
========== 140 ===== ===== test accuracy is  0.958 ==========
========== 160 ===== ===== test accuracy is  0.942 ==========
========== 180 ===== ===== test accuracy is  0.948 ==========
========== 200 ===== ===== test accuracy is  0.952 ==========
========== 220 ===== ===== test accuracy is  0.952 ==========
========== 240 ===== ===== test accuracy is  0.95 ==========
========== 260 ===== ===== test accuracy is  0.966 ==========
========== 280 ===== ===== test accuracy is  0.96 ==========
========== 300 ===== ===== test accuracy is  0.956 ==========
========== 320 ===== ===== test accuracy is  0.96 ==========
========== 340 ===== ===== test accuracy is  0.956 ==========
========== 360 ===== ===== test accuracy is  0.956 ==========
========== 380 ===== ===== test accuracy is  0.954 ==========
========== 400 ===== ===== test accuracy is  0.96 ==========
========== 420 ===== ===== test accuracy is  0.966 ==========
========== 440 ===== ===== test accuracy is  0.96 ==========
========== 460 ===== ===== test accuracy is  0.954 ==========
========== 480 ===== ===== test accuracy is  0.968 ==========
========== 500 ===== ===== test accuracy is  0.958 ==========
========== 520 ===== ===== test accuracy is  0.958 ==========
========== 540 ===== ===== test accuracy is  0.962 ==========
========== 560 ===== ===== test accuracy is  0.968 ==========
========== 580 ===== ===== test accuracy is  0.958 ==========
========== 600 ===== ===== test accuracy is  0.952 ==========
========== 620 ===== ===== test accuracy is  0.95 ==========
========== 640 ===== ===== test accuracy is  0.964 ==========
========== 660 ===== ===== test accuracy is  0.962 ==========
========== 680 ===== ===== test accuracy is  0.96 ==========
========== 700 ===== ===== test accuracy is  0.962 ==========
========== 720 ===== ===== test accuracy is  0.964 ==========
========== 740 ===== ===== test accuracy is  0.958 ==========
========== 760 ===== ===== test accuracy is  0.96 ==========
========== 780 ===== ===== test accuracy is  0.972 ==========
========== 800 ===== ===== test accuracy is  0.962 ==========
========== 820 ===== ===== test accuracy is  0.968 ==========
========== 840 ===== ===== test accuracy is  0.964 ==========
========== 860 ===== ===== test accuracy is  0.96 ==========
========== 880 ===== ===== test accuracy is  0.964 ==========
========== 900 ===== ===== test accuracy is  0.96 ==========
========== 920 ===== ===== test accuracy is  0.96 ==========
========== 940 ===== ===== test accuracy is  0.97 ==========
========== 960 ===== ===== test accuracy is  0.956 ==========
========== 980 ===== ===== test accuracy is  0.966 ==========
========== 1000 ===== ===== test accuracy is  0.964 ==========
========== 1020 ===== ===== test accuracy is  0.964 ==========
========== 1040 ===== ===== test accuracy is  0.97 ==========
========== 1060 ===== ===== test accuracy is  0.974 ==========
========== 1080 ===== ===== test accuracy is  0.962 ==========
========== 1100 ===== ===== test accuracy is  0.97 ==========
========== 1120 ===== ===== test accuracy is  0.974 ==========
========== 1140 ===== ===== test accuracy is  0.978 ==========
========== 1160 ===== ===== test accuracy is  0.976 ==========
========== 1180 ===== ===== test accuracy is  0.974 ==========

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1719496.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Visual Studio Code使用(C++项目新建,运行)

VS Code 直接在官网下载安装。 接下来安装插件,下图是C所需的对应插件 1.新建项目 VS Code下载安装完成后,直接进入欢迎页: 在访达/文件夹中新建一个文件夹,欢迎页点击【打开】,选择刚刚新建的文件夹。点击第一个图…

MT8781安卓核心板_MTK联发科Helio G99核心板规格参数

MT8781安卓核心板采用先进的台积电6纳米级芯片生产工艺,配备高性能Arm Cortex-A76处理器和Arm Mali G57 GPU,加上LPDDR4X内存和UFS 2.2存储,在处理速度和数据访问速度上都有着出色的表现。 MT8781还支持120Hz显示器,无需额外的DSC…

vue3学习(六)

前言 接上一篇学习笔记,今天主要是抽空学习了vue的状态管理,这里学习的是vuex,版本4.1。学习还没有学习完,里面有大坑,难怪现在官网出的状态管理用Pinia。 一、vuex状态管理知识点 上面的方式没有写全,还有…

QT软件界面的设计与启动方法

新书上架~👇全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我👆,收藏下次不迷路┗|`O′|┛ 嗷~~ 目录 一、Q T界面设计的重要性 二、QT软件的启动与配置 三、QT软件的启动路径设置 四、QT软件启…

Linux--EXT2文件系统

参考资料: linux之EXT2文件系统--理解block/block group/索引结点inode/索引位图_一个块组中索引节点表和数据块区最多占用字节-CSDN博客 linux环境: Linux version 5.15.146.1-microsoft-standard-WSL2 (root65c757a075e2) (gcc (GCC) 11.2.0, GNU ld…

Llama改进之——分组查询注意力

引言 今天介绍LLAMA2模型引入的关于注意力的改进——分组查询注意力(Grouped-query attention,GQA)1。 Transformer中的多头注意力在解码阶段来说是一个性能瓶颈。多查询注意力2通过共享单个key和value头,同时不减少query头来提升性能。多查询注意力可能导致质量下…

C++双层Vector容器详解

双层Vector容器 关于C中二维vector使用 双层vector的运用细节 插入元素 //正确的插入方式 vector<vector<int> > A; //A.push_back里必须是vector vector<int> B; B.push_back(0); B.push_back(1); B.push_back(2); A.push_back(B); B.clear(); B.push_back…

AI边缘计算盒子在智慧交通的应用

方案背景 随着经济增长&#xff0c;交通出行需求大幅增长&#xff0c;但道路建设增长缓慢&#xff0c;交通供需矛盾日益显著&#xff0c;中心城区主要道路高峰时段交通拥堵严重&#xff0c;道路交通拥堵逐渐常态化&#xff0c;成为制约城市可持续发展的重要因素之一。 痛点问题…

python移位操作符(左移位操作符<<、右移位操作符>>)(允许开发者对整数进行位操作,乘2或除2)(左移操作、右移操作)(位掩码操作|=)

文章目录 Python 中的移位操作符详解移位操作符简介左移位操作符 (<<)语法和使用示例代码输出 右移位操作符 (>>)语法和使用示例代码输出 移位操作符的应用场景快速乘除运算&#xff1a;使用移位操作符代替传统的乘法和除法运算&#xff0c;可以提高计算速度。位掩…

3位新加坡华人交易员分享:交易策略、风险管理与心态

交易与投资似乎是一对“双胞胎”,它们都是金融市场中获得收益的重要途径。 区别在于投资者购买自以为长期将有出色业绩的资产组合&#xff0c;并且长期持有这些资产组合&#xff0c;交易者依靠交易技巧借助资产工具价格瞬息波动在短期内产生利润。交易资产的手段有&#xff0c…

MySQL统计字符长度:CHAR_LENGTH(str)

对于SQL表&#xff0c;用于计算字符串中字符数的最佳函数是 CHAR_LENGTH(str)&#xff0c;它返回字符串 str 的长度。 另一个常用的函数 LENGTH(str) 在这个问题中也适用&#xff0c;因为列 content 只包含英文字符&#xff0c;没有特殊字符。否则&#xff0c;LENGTH() 可能会返…

unicloud 云对象

背景和优势 20年前&#xff0c;restful接口开发开始流行&#xff0c;服务器编写接口&#xff0c;客户端调用接口&#xff0c;传输json。 现在&#xff0c;替代restful的新模式来了。 云对象&#xff0c;服务器编写API&#xff0c;客户端调用API&#xff0c;不再开发传输json…

AI图书推荐:使用GitHub Copilot和ChatGPT辅助的Python编程

使用Python编写计算机程序变得更加简单了&#xff01;使用像GitHub Copilot和ChatGPT这样的AI辅助编码工具&#xff0c;将你的想法快速转化为应用程序。人工智能已经改变了我们编写计算机程序的方式。有了像Copilot和ChatGPT这样的工具&#xff0c;你可以用简单的英语描述你想要…

QT5:调用qt键盘组件实现文本框输入

目录 一、环境与目标 二、Qt VirtualKeyboard 1.勾选Qt VirtualKeyboard 2.ui设计流程 3.注意事项及问题点 三、参考代码 参考博客 一、环境与目标 qt版本&#xff1a;5.12.7 windows 11 下的 Qt Designer &#xff08;已搭建&#xff09; 目标&#xff1a;创建一个窗…

【Nacos源码分析01-服务注册与集群间数据是同步】

文章目录 了解CAPBASE理论Nacos支持CP还是AP集群数据同步实现集群数据一致性源码 了解CAP CAP理论的核心观点是&#xff0c;一个分布式系统无法同时完全满足一致性、可用性和分区容错性这三个特性。具体而言&#xff0c;当发生网络分区时&#xff0c;系统必须在一致性和可用性之…

【Vue】v-for中的key

文章目录 一、引入问题二、分析问题 一、引入问题 语法&#xff1a; key属性 "唯一值" 作用&#xff1a;给列表项添加的唯一标识。便于Vue进行列表项的正确排序复用。 为什么加key&#xff1a;Vue 的默认行为会尝试原地修改元素&#xff08;就地复用&#xff09;…

华媒舍:10种欧洲地区媒体发稿推广技巧

1.了解欧洲地区媒体自然环境必须掌握欧洲地区媒体的发稿推广方法&#xff0c;首先要对欧洲地区媒体自然环境有一定的了解。包含不一样国家的主力媒体&#xff0c;他的阅读者人群、销售市场遮盖及其报导风格等。仅有熟悉媒体自然环境&#xff0c;才能更好的制订营销推广策略。 …

【Unity Shader入门精要 第11章】让画面动起来(一)

1. Unity Shader中的时间变量 Shader控制这物体的显示&#xff0c;当向Shader中引入时间变量后&#xff0c;就可以让物体的显示效果随时间发生变化&#xff0c;以实现动画效果。 Unity中常见的时间变量如下表&#xff1a; 变量类型描述_Timefloat4(t/20, t, 2t, 3t)&#xf…

Visual Studio 2022创建dll并调用

需求&#xff1a; 创建A项目&#xff0c;有函数和类&#xff0c;将A项目生成DLL动态链接库 创建B项目&#xff0c;使用A项目生成的dll和lib相关文件 正常项目开发.h用于函数声明&#xff0c;.cpp用于函数实现&#xff0c;但是项目开发往往不喜欢将.cpp函数实现的代码发给别人&…

git使用流程与规范

原文网址&#xff1a;git代码提交流程与规范-CSDN博客 简介 本文git提交流程与规范是宝贵靠谱的经验&#xff0c;它能解决如下问题&#xff1a; 分支差距过大&#xff0c;导致合代码无数的冲突合完代码后发现代码丢失分支不清晰&#xff0c;无法追溯问题合代码耗时很长&…