Pytorch:搭建卷积神经网络完成MNIST分类任务

news2024/10/4 2:27:47

2023.7.18

MNIST百科:

MNIST数据集简介与使用_bwqiang的博客-CSDN博客

数据集官网:MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges

MNIST数据集获取并转换成图片格式:

数据集将按以图片和文件夹名为标签的形式保存:

 代码:下载mnist数据集并转还为图片


import os
from PIL import Image
from torchvision import datasets, transforms

# 定义数据转换
transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 标准化
])

# 下载并加载训练集和测试集
train_dataset = datasets.MNIST(root=os.getcwd(), train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root=os.getcwd(), train=False, transform=transform, download=True)

# 路径
train_path = './images/train'
test_path = './images/test'

# 将训练集中的图像保存为图片
for i in range(10):
    file_name = train_path + os.sep + str(i)
    if not os.path.exists(file_name):
        os.mkdir(file_name)

for i in range(10):
    file_name = test_path + os.sep + str(i)
    if not os.path.exists(file_name):
        os.mkdir(file_name)

for i, (image, label) in enumerate(train_dataset):
    train_label = label
    image_path = f'images/train/{train_label}/{i}.png'
    image = image.squeeze().numpy()  # 去除通道维度,并转换为 numpy 数组
    image = (image * 0.5) + 0.5  # 反标准化,将范围调整为 [0, 1]
    image = (image * 255).astype('uint8')  # 将范围调整为 [0, 255],并转换为整数类型
    Image.fromarray(image).save(image_path)

# 将测试集中的图像保存为图片
for i, (image, label) in enumerate(test_dataset):
    text_label = label
    image_path = f'images/test/{text_label}/{i}.png'
    image = image.squeeze().numpy()  # 去除通道维度,并转换为 numpy 数组
    image = (image * 0.5) + 0.5  # 反标准化,将范围调整为 [0, 1]
    image = (image * 255).astype('uint8')  # 将范围调整为 [0, 255],并转换为整数类型
    Image.fromarray(image).save(image_path)

 训练代码:


import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision.transforms as transforms
from PIL import Image

# 调动显卡进行计算
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class MyDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.names_list = []

        for dirs in os.listdir(self.root_dir):
            dir_path = self.root_dir + '/' + dirs
            for imgs in os.listdir(dir_path):
                img_path = dir_path + '/' + imgs
                self.names_list.append((img_path, dirs))

    def __len__(self):
        return len(self.names_list)

    def __getitem__(self, index):
        image_path, label = self.names_list[index]
        if not os.path.isfile(image_path):
            print(image_path + '不存在该路径')
            return None
        image = Image.open(image_path)

        label = np.array(label).astype(int)
        label = torch.from_numpy(label)

        if self.transform:
            image = self.transform(image)

        return image, label


# 定义卷积神经网络模型
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc = nn.Linear(16 * 14 * 14, 10)

    def forward(self, x):
        x = self.conv1(x)  # 卷积
        x = self.relu(x)  # 激活函数
        x = self.maxpool(x)  # 最大值池化
        x = x.view(x.size(0), -1)
        x = self.fc(x)  # 全连接层
        return x


# 加载手写数字数据集
train_dataset = MyDataset('./dataset/images/train', transform=transforms.ToTensor())
val_dataset = MyDataset('./dataset/images/val', transform=transforms.ToTensor())

# 定义超参数
batch_size = 8192  # 批处理大小
learning_rate = 0.001  # 学习率
num_epochs = 30  # 迭代次数

# 创建数据加载器
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

# 实例化模型、损失函数和优化器
model = CNN().to(device)
criterion = nn.CrossEntropyLoss()  # 损失函数
optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # 优化器

# 记录验证的次数
total_train_step = 0
total_val_step = 0

# 模型训练和验证
print("-------------TRAINING-------------")
total_step = len(train_loader)
for epoch in range(num_epochs):
    print("Epoch=", epoch)
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        output = model(images)
        loss = criterion(output, labels.long())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step = total_train_step + 1
        print("train_times:{},Loss:{}".format(total_train_step, loss.item()))

    # 测试验证
    total_val_loss = 0
    total_accuracy = 0
    with torch.no_grad():
        for i, (images, labels) in enumerate(val_loader):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels.long())

            total_val_loss = total_val_loss + loss.item()  # 计算损失值的和
            accuracy = 0

            for j in labels:  # 计算精确度的和
                if outputs.argmax(1)[j] == labels[j]:
                    accuracy = accuracy + 1

            total_accuracy = total_accuracy + accuracy

    print('Accuracy =', float(total_accuracy / len(val_dataset)))  # 输出正确率
    torch.save(model, "cnn_{}.pth".format(epoch))  # 模型保存

# # 模型评估
# with torch.no_grad():
#     correct = 0
#     total = 0
#     for images, labels in test_loader:
#         outputs = model(images)
#         _, predicted = torch.max(outputs.data, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum().item()

测试代码:

import torch
from torchvision import transforms
import torch.nn as nn
import os
from PIL import Image

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  # 判断是否有GPU


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc = nn.Linear(16 * 14 * 14, 10)

    def forward(self, x):
        x = self.conv1(x)  # 卷积
        x = self.relu(x)  # 激活函数
        x = self.maxpool(x)  # 最大值池化
        x = x.view(x.size(0), -1)
        x = self.fc(x)  # 全连接层
        return x


model = torch.load('cnn.pth')  # 加载模型

path = "./dataset/images/test/"  # 测试集

imgs = os.listdir(path)

test_num = len(imgs)
print(f"test_dataset_quantity={test_num}")

for img_name in imgs:
    img = Image.open(path + img_name)

    test_transform = transforms.Compose([transforms.ToTensor()])

    img = test_transform(img)
    img = img.to(device)
    img = img.unsqueeze(0)
    outputs = model(img)  # 将图片输入到模型中
    _, predicted = outputs.max(1)

    pred_type = predicted.item()
    print(img_name, 'pred_type:', pred_type)

分类正确率不错:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/769696.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

青岛大学_王卓老师【数据结构与算法】Week05_13_队列的顺序表示和实现1_学习笔记

本文是个人学习笔记,素材来自青岛大学王卓老师的教学视频。 一方面用于学习记录与分享, 另一方面是想让更多的人看到这么好的《数据结构与算法》的学习视频。 如有侵权,请留言作删文处理。 课程视频链接: 数据结构与算法基础…

DeepSpeed系列篇1:零门槛上手DeepSpeed实战(服务器部署及训练过程详解SFT)

1、建立虚拟环境 conda create -n dsnew python3.10 2、安装pytorch conda install pytorch torchvision torchaudio pytorch-cuda11.7 -c pytorch -c nvidia 3、安装deepspeed pip install deepspeed 4、下载DeepSpeedExamples并安装依赖 https://github.com/microsof…

一文了解DDD分层架构演进

1.3 分层架构演进 1.3.1 传统四层架构 将领域模型和业务逻辑分离出来,并减少对基础设施、用户界面甚至应用层逻辑的依赖,因为它们不属业务逻辑。将一个夏杂的系统分为不同的层,每层都应该具有良好的内聚性,并且只依赖于比其自身更…

arduinoIDE2.1.1最新版升级开发板(esp32-2.0.3升级2.0.10)方法总结(esp8266升级通用)

一、arduinoIDE 升级最新版 2.1.1方法 1.1.通过IDE2.x直接升级(推荐,速度还可以) 1.2.官网下载安装包覆盖升级(地址https://www.arduino.cc/en/software) 1.3 ESP8266升级方法雷同可参考(原理一样,最新好像是3.1.2) https://github.com/esp8266/Arduino/releases http…

C++入门知识点

目录 命名空间 命名空间定义 命名空间使用 法一:加命名空间名称及作用域限定符:: 法二:使用using部分展开(授权)某个命名空间中的成员 法三:使用using对整个命名空间全部展开(授权…

Windows修改mysql服务的root密码

目录 步骤1、停止mysql服务2、使用命令行启动mysql服务,跳过密码验证3、密码置空4、关闭命令行启动的mysql服务并正常启动5、修改root密码 参考 步骤 1、停止mysql服务 以管理员身份打开终端,输入指令net stop mysql停止MySQL服务,停止服务…

PBOOTCMS登录请求发生错误,您可按照如下方式排查: 1、试着删除根目录下runtime目录,刷新页面重试;2、检查系统会话文件存储目录是否具有写入权限;

PBOOTCMS后台登录请求发生错误,您可按照如下方式排查: 1、试着删除根目录下runtime目录,刷新页面重试;2、检查系统会话文件存储目录是否具有写入权限; 以上提示其实就是,表单提交校验失败,请刷新后重试的提…

U盘文件修复怎么做?简单3步,快速修复u盘文件!

“很离谱!由于有些文件存在错误,我想将这些错误文件修复,但在操作过程中,不知为什么所有的数据都被删除了。U盘文件修复应该怎么做呀?是不是我的操作方法有误呢?” U盘使用时间长了之后,很可能会…

【分布鲁棒、状态估计】分布式鲁棒优化电力系统状态估计研究[几种算法进行比较](Matlab代码实现)

💥1 概述 文献来源: 摘要: 能源市场的自由化、可再生能源的渗透、先进的计量能力以及对情境感知的需求,都要求进行系统范围的电力系统状态估计(PSSE)。然而,由于互联的复杂性、实时监测中的通信…

MySQL八股学习记录6-日志from小林coding

MySQL八股学习记录6-日志from小林coding MySQL日志分类undo logBuffer Poolredo logbinlogredo log 和undo log有什么区别主从复制是如何实现update语句执行过程为什么需要两阶段提交 MySQL日志分类 undo log:InnoDB存储引擎层生成的日志,实现事务中的原子性,主要用于事务回滚…

学习记录——SpectFormer、DilateFormer、ShadowFormer、MISSFormer

SpectFormer: Frequency and Attention is what you need in a Vision Transformer, arXiv2023 频域混合注意力SpectFormer 2023 论文:https://arxiv.org/abs/2304.06446 代码:https://badripatro.github.io/SpectFormers/ 摘要视觉变压器已经成功地应用…

No.2(4)——双指针解决柱子间最大面积

已知现在有几根柱子成有序排列,求出两根柱子之间围成面积的最大值。 不难想到,只需要将每两个柱子之间的面积计算一次并找出最大值,即可找到答案,但采用双指针法可以有效降低重复计算:从数组的两侧开始移动左右两个指针…

Elasticsearch SQL 详解

Elasticsearch SQL 是一个 X-Pack 组件,允许用户使用类似 SQL 的语法在 ES 中进行查询。用户可以在 REST、JDBC、命令行中使用 SQL 在 ES 执行数据检索和数据聚合操作。ES SQL 有以下几个特点: 本地集成,SQL 模块是 ES 自己构建的&#xff0…

数据库| 中国研究数据服务平台

数据哪里查,查不到,怎么办? 今天分享一个数据库|中国研究数据服务平台(CNRDS) 中国研究数据服务平台(Chinese Research Data Services,简称CNRDS),是上海经禾信息技术有…

HCIA|详解Telnet协议

一、前言 今天翻到了之前写的Telnet协议的实验,由于该篇文章创作于开始写作的初期,文章结构简单、布局潦草,但实验内容是完整的,因此本篇文章将对Telnet技术进行详解,希望能够对大家提供帮助。在本文中,将从…

ECharts is not Loaded -- echarts里china.json与china.js有何区别

echarts官方提示他们的地图json测绘不符合中国官方标准不提供下载 如下图 china.json china.js 可以很明显的看出地图山东与辽宁部分堆到一起的情况, 接下来换成china.js vue项目,要引入china.js,直接import引入会报错:ECharts is not Loa…

【SCI征稿】IEEE旗下中科院1区(TOP),有关计算机的广泛领域研究

期刊简介: 出版社:IEEE 影响因子:IF(2022)10.5-11.0 期刊分区:JCR1区,中科院1区(TOP) 检索情况:SCIE&EI 双检 自引率:13.20% 国人占比&…

在Illustrator中创建 3D 冰淇淋模型对象

推荐: NSDT场景编辑器助你快速搭建可二次开发的3D应用场景 一旦你学会了如何在Illustrator中制作一个对象3D,你可以前往Envato Elements,在那里你可以找到大量的3D设计来激发你的灵感。这个基于订阅的市场拥有超过 2,000 个 Illus…

国外广告联盟和国内广告联盟的优劣势是什么

国外广告联盟和国内广告联盟在一些方面存在一些差异和优劣势。以下是对比它们的一些常见优劣势: 一、国外广告联盟优势: 1、国际资源:国外广告联盟拥有更广泛的国际媒体资源,能够帮助广告主拓展全球市场,进一步提高国…

Arduino安装ESP32下载失败的解决方法

Arduino安装ESP32时,经常下载失败 解决办法: 1.复制命令行中的提示信息到记事本,找到下载地址 2.打开浏览器,在地址栏中贴粘下载地址,回车开始下载 3.将下载的包复制到C:\Users\Administrator\AppData\Local\Arduino…