Pytorch:搭建卷积神经网络完成MNIST分类任务:

news2025/1/11 4:51:00

2023.7.18

MNIST百科:

MNIST数据集简介与使用_bwqiang的博客-CSDN博客

数据集官网:MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges

MNIST数据集获取并转换成图片格式:

数据集将按以图片和文件夹名为标签的形式保存:

 代码:下载mnist数据集并转还为图片


import os
from PIL import Image
from torchvision import datasets, transforms

# 定义数据转换
transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 标准化
])

# 下载并加载训练集和测试集
train_dataset = datasets.MNIST(root=os.getcwd(), train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root=os.getcwd(), train=False, transform=transform, download=True)

# 路径
train_path = './images/train'
test_path = './images/test'

# 将训练集中的图像保存为图片
for i in range(10):
    file_name = train_path + os.sep + str(i)
    if not os.path.exists(file_name):
        os.mkdir(file_name)

for i in range(10):
    file_name = test_path + os.sep + str(i)
    if not os.path.exists(file_name):
        os.mkdir(file_name)

for i, (image, label) in enumerate(train_dataset):
    train_label = label
    image_path = f'images/train/{train_label}/{i}.png'
    image = image.squeeze().numpy()  # 去除通道维度,并转换为 numpy 数组
    image = (image * 0.5) + 0.5  # 反标准化,将范围调整为 [0, 1]
    image = (image * 255).astype('uint8')  # 将范围调整为 [0, 255],并转换为整数类型
    Image.fromarray(image).save(image_path)

# 将测试集中的图像保存为图片
for i, (image, label) in enumerate(test_dataset):
    text_label = label
    image_path = f'images/test/{text_label}/{i}.png'
    image = image.squeeze().numpy()  # 去除通道维度,并转换为 numpy 数组
    image = (image * 0.5) + 0.5  # 反标准化,将范围调整为 [0, 1]
    image = (image * 255).astype('uint8')  # 将范围调整为 [0, 255],并转换为整数类型
    Image.fromarray(image).save(image_path)

 训练代码:


import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision.transforms as transforms
from PIL import Image

# 调动显卡进行计算
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class MyDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.names_list = []

        for dirs in os.listdir(self.root_dir):
            dir_path = self.root_dir + '/' + dirs
            for imgs in os.listdir(dir_path):
                img_path = dir_path + '/' + imgs
                self.names_list.append((img_path, dirs))

    def __len__(self):
        return len(self.names_list)

    def __getitem__(self, index):
        image_path, label = self.names_list[index]
        if not os.path.isfile(image_path):
            print(image_path + '不存在该路径')
            return None
        image = Image.open(image_path)

        label = np.array(label).astype(int)
        label = torch.from_numpy(label)

        if self.transform:
            image = self.transform(image)

        return image, label


# 定义卷积神经网络模型
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc = nn.Linear(16 * 14 * 14, 10)

    def forward(self, x):
        x = self.conv1(x)  # 卷积
        x = self.relu(x)  # 激活函数
        x = self.maxpool(x)  # 最大值池化
        x = x.view(x.size(0), -1)
        x = self.fc(x)  # 全连接层
        return x


# 加载手写数字数据集
train_dataset = MyDataset('./dataset/images/train', transform=transforms.ToTensor())
val_dataset = MyDataset('./dataset/images/val', transform=transforms.ToTensor())

# 定义超参数
batch_size = 8192  # 批处理大小
learning_rate = 0.001  # 学习率
num_epochs = 30  # 迭代次数

# 创建数据加载器
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

# 实例化模型、损失函数和优化器
model = CNN().to(device)
criterion = nn.CrossEntropyLoss()  # 损失函数
optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # 优化器

# 记录验证的次数
total_train_step = 0
total_val_step = 0

# 模型训练和验证
print("-------------TRAINING-------------")
total_step = len(train_loader)
for epoch in range(num_epochs):
    print("Epoch=", epoch)
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        output = model(images)
        loss = criterion(output, labels.long())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step = total_train_step + 1
        print("train_times:{},Loss:{}".format(total_train_step, loss.item()))

    # 测试验证
    total_val_loss = 0
    total_accuracy = 0
    with torch.no_grad():
        for i, (images, labels) in enumerate(val_loader):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels.long())

            total_val_loss = total_val_loss + loss.item()  # 计算损失值的和
            accuracy = 0

            for j in labels:  # 计算精确度的和
                if outputs.argmax(1)[j] == labels[j]:
                    accuracy = accuracy + 1

            total_accuracy = total_accuracy + accuracy

    print('Accuracy =', float(total_accuracy / len(val_dataset)))  # 输出正确率
    torch.save(model, "cnn_{}.pth".format(epoch))  # 模型保存

# # 模型评估
# with torch.no_grad():
#     correct = 0
#     total = 0
#     for images, labels in test_loader:
#         outputs = model(images)
#         _, predicted = torch.max(outputs.data, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum().item()

测试代码:

import torch
from torchvision import transforms
import torch.nn as nn
import os
from PIL import Image

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  # 判断是否有GPU


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc = nn.Linear(16 * 14 * 14, 10)

    def forward(self, x):
        x = self.conv1(x)  # 卷积
        x = self.relu(x)  # 激活函数
        x = self.maxpool(x)  # 最大值池化
        x = x.view(x.size(0), -1)
        x = self.fc(x)  # 全连接层
        return x


model = torch.load('cnn.pth')  # 加载模型

path = "./dataset/images/test/"  # 测试集

imgs = os.listdir(path)

test_num = len(imgs)
print(f"test_dataset_quantity={test_num}")

for img_name in imgs:
    img = Image.open(path + img_name)

    test_transform = transforms.Compose([transforms.ToTensor()])

    img = test_transform(img)
    img = img.to(device)
    img = img.unsqueeze(0)
    outputs = model(img)  # 将图片输入到模型中
    _, predicted = outputs.max(1)

    pred_type = predicted.item()
    print(img_name, 'pred_type:', pred_type)

分类正确率不错:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/766404.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

二、DDL-4.表操作-修改删除

一、修改 1、往表中添加字段 e.g.为emp表增加一个新的字段“昵称”为nickname,类型为varchar(20) alter table emp add nickname varchar(20) comment 昵称; 2、修改表中字段 e.g.将emp表的nickname字段修改为username,类型为varchar(30) alter table e…

TCP/IP网络编程 第十六章:关于IO流分离的其他内容

分离I/O流 两次I/O流分离 我们之前通过2种方法分离过IO流,第一种是第十章的“TCPI/O过程(Routine)分离”。这种方法通过调用fork函数复制出1个文件描述符,以区分输入和输出中使用的文件描述符。虽然文件描述符本身不会根据输入和输…

基于主从博弈的主动配电网阻塞管理的论文复现——附Matlab代码

目录 文章摘要: 编程思路: 研究背景: 基于主从博弈的电网阻塞管理: 算例介绍: Matlab运行结果展示: Matlab代码数据分享: 文章摘要: 随着需求侧灵活性资源在配电网中的渗透率…

SQLite编程操作

一、打开/创建数据库的C接口 ①sqlite3_open ( const char * filename , sqlite3 ** ppDb ) 打开一个指向 SQLite 数据库文件的连接,返回一个用于其他 SQLite 程序的数据库连接对 象。 ②sqlite3_close(sqlite3*) 关闭之前调用 sqlite3_open() 打开的数据…

⛳ Git安装与配置

Git安装配置目录 ⛳ Git安装与配置🏭 一,git的安装🎨 1,下载git👣 2,下载完成之后,双击安装即可。💻 3,更改安装目录(没有中文且没有空格)&#x…

3本期刊被剔除,7月SCIE/SSCI目录已更新 (附2023WOS历次更新目录)~

2023年7月17日,科睿唯安更新了Web of Science核心期刊目录。 此次更新后SCIE期刊目录共包含9498本期刊,SSCI期刊目录共包含3557本期刊。此次SCIE & SSCI期刊目录更新,与上次更新(2023年6月)相比,有4本S…

Shell之循环语句 —— WhileUntil 实验

While While循环语句:满足条件才会执行循环,不满足就结束,用于不知道循环次数,需要主动结束循环或者达到条件循环的场景 While的结构 while(条件判断)——do —— 命令序列 —— done 如:用whi…

Python实现HBA混合蝙蝠智能算法优化卷积神经网络分类模型(CNN分类算法)项目实战

说明:这是一个机器学习实战项目(附带数据代码文档视频讲解),如需数据代码文档视频讲解可以直接到文章最后获取。 1.项目背景 蝙蝠算法是2010年杨教授基于群体智能提出的启发式搜索算法,是一种搜索全局最优解的有效方法…

qiankun框架vue3主应用和子应用生产环境打包部署nginx

首先下载nginx,进行最小化配置 用vscode 打开nginx.conf文件 在http模块的server模块里进行配置 listen 字段监听端口号 http的默认端口号是80(nginx的端口号可以随便写) server_name字段 是ip地址 lochhost就是127.0.0.1 lacation 字段 是在浏览器的地址栏http协议ip地址…

C++类和对象——类的基础

目录 类的引入类的定义类的访问限定符和封装对象的实例化类对象的大小this指针 类的引入 在C语言中,结构体中只能定义变量 但是在C中,结构体不仅可以定义变量,还可以定义函数 下面就是C中的一个结构体: struct Stack {void init(…

【Linux系统 学习笔记】Linux线程互斥 线程安全 可重入 不可重入 死锁

目录 Linux 线程互斥进程线程间互斥相关背景和概念互斥量互斥量的接口互斥量实现原理探究 可重入与线程安全概念常见的线程不安全的情况常见的线程安全的情况常见不可重入的情况常见可重入的情况可重入与线程安全联系可重入与线程安全区别 死锁死锁四个必要条件避免死锁 Linux …

【代码随想录13】前 K 个高频元素

题目 给定一个非空的整数数组,返回其中出现频率前 k 高的元素。 示例 1: 输入: nums [1,1,1,2,2,3], k 2输出: [1,2] 示例 2: 输入: nums [1], k 1输出: [1] 提示: 你可以假设给定的 k 总是合理的,且 1 ≤ k ≤ 数组中不相同的元素…

黑客学习笔记(自学)

一、首先,什么是黑客? 黑客泛指IT技术主攻渗透窃取攻击技术的电脑高手,现阶段黑客所需要掌握的远远不止这些。 二、为什么要学习黑客技术? 其实,网络信息空间安全已经成为海陆空之外的第四大战场,除了国…

C#(六十)之Convert类 和 Parse方法的区别

Convert数据类型转换类,从接触C#开始,就一直在用,这篇日志坐下深入的了解。 Convert类常用的类型转换方法 方法 说明 Convert.ToInt32() 转换为整型(int) Convert.ToChar() 转换为字符型(char) Convert.ToString() 转换为字符串型(st…

优化CSS重置过程:探索CSS层叠技术的应用与优势

目录 下面是正文~~ CSS重置方法 方法的结合 合并方法的问题 通用移除样式 顺序很重要 CSS 优先级 我们的CSS特异性冲突 CSS Layers 来拯救 Sass 预处理器支持 浏览器支持 总结 这篇文章介绍了一种名为CSS层叠的技术,用于优化CSS重置过程。它解释了CSS重…

网络安全(黑客技术)最全面的学习笔记

学网络安全如何成为一名黑客呢? 整合了全知识点及学习框架,本篇零基础依然适用! 本篇涵盖内容及其全面,强烈推荐收藏! 一、学习网络安全会遇到什么问题呢? 1、学习基础内容多时间长 学习基础语言太多&…

基于MATLAB的无人机遥感数据预处理与农林植被性状估算教程

详情点击链接:基于MATLAB的无人机遥感数据预处理与农林植被性状估算前言 遥感技术作为一种空间大数据手段,能够从多时、多维、多地等角度,获取大量的农情数据。数据具有面状、实时、非接触、无伤检测等显著优势,是智慧农业必须采…

初中级PHP程序员如何进阶学习?

如果你是一个以PHP为主的开发人员,只会依赖现成的框架进行增删改查,想提高自己又不知道从何下手,你可以花点时间研究一下我这个开源项目:酷瓜云课堂,这个项目以PHPJS 为主,负责主要的业务逻辑,部…

基于遗传算法的新能源电动汽车充电桩与路径选择MATLAB程序

主要内容: 根据城市间的距离,规划新能源汽车的行驶路径。要求行驶距离最短。 部分代码: %% 加载数据 %%遗传参数 load zby;%个城市坐标位置 NIND50; %种群大小 MAXGEN200; Pc0.9; %交叉概率 Pm0.2; %变异概率 GGAP0.…

初识Redis——Redis概述、安装、基本操作

目录 一、NoSQL介绍 1.1什么是NoSQL 1.2为什么会出现NoSQL技术 1.3NoSQL的类别 1.4传统的ACID是什么 1.5 CAP 1.5.1 经典CAP图 1.5.4 什么是BASE 二、Redis概述 2.1 什么是Redis 2.2 Redis能干什么 2.3 Redis的特点 2.4 Redis与memcached对比 2.5 Redis的安装 2.6 Docker安装 三…