【深度学习】pytorch,MNIST手写数字分类

news2024/9/25 7:25:43

efficientnet_b0的迁移学习


import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import models
import matplotlib.pyplot as plt

# 定义超参数
batch_size = 240
learning_rate = 0.001
num_epochs = 10

# 数据预处理,包括调整图像大小并将单通道图像复制到三个通道
transform = transforms.Compose([
    transforms.Resize(224),  # 调整图像大小以适应EfficientNetB0
    transforms.Grayscale(num_output_channels=3),  # 将单通道图像复制到三个通道
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 使用ImageNet的均值和标准差
])

# 加载数据集
train_dataset = MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = MNIST(root='./data', train=False, transform=transform, download=True)

# 创建数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=32)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=32)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载预训练的EfficientNetB0模型并调整最后的分类层
model = models.efficientnet_b0(pretrained=True)
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 10)  # MNIST共10个类别
model.to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 用于绘图的数据
train_losses = []
test_accuracies = []

# 训练模型
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        data, target = data.to(device), target.to(device)
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        print(f"\rEpoch {epoch + 1}/{num_epochs}, Batch {batch_idx + 1}/{len(train_loader)}, Loss: {loss.item():.4f}")

    # 计算平均损失
    avg_loss = running_loss / len(train_loader)
    train_losses.append(avg_loss)

    # 测试准确率
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)  # Move test data to the correct device
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    accuracy = 100 * correct / total
    test_accuracies.append(accuracy)
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.2f}%')

# save
torch.save(model.state_dict(), 'mnist_efficientnetb0.pth')

# 绘制损失函数和准确率图
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Training Loss')
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(test_accuracies, label='Test Accuracy')
plt.title('Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()

plt.show()

训练10轮,测试准确率很猛:

Epoch 10/10, Loss: 0.0087, Test Accuracy: 99.60%

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1542408.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【数据结构】五分钟自测主干知识(十)

上一节,我们讲述了二叉树的概念,二叉树又有什么基本操作呢?今天我们来讲述二叉树的应用~ 话不多说,书继上回 5.3二叉树的遍历及应用 二叉树由三个基本部分组成:根结点(D),左子树&a…

ZooKeeper 的常见应用场景

数据发布与订阅 发布与订阅即所谓的配置管理,顾名思义就是将数据发布到ZooKeeper节点上,供订阅者动态获取数据,实现配置信息的集中式管理和动态更新。例如全局的配置信息,地址列表等就非常适合使用。 数据发布/订阅的一个常见的…

Spring Boot:基础配置

Spring Boot 全局配置文件application.propertiesapplication.yml全局配置文件的优先级 从全局配置文件中获取数据的注解从外部属性文件中获取数据的注解全局配置文件的配置项通用配置项数据源配置项JPA 配置项日志配置项配置文件特定配置项Profile 特定配置项 配置类配置文件中…

【Emgu CV教程】10.4、轮廓之多边形近似拟合

文章目录 一、什么叫轮廓的多边形近似拟合二、轮廓的多边形近似拟合函数三、简单应用1.原始素材2.代码3.运行结果 一、什么叫轮廓的多边形近似拟合 轮廓一般都是光滑的曲线,多边形近似拟合的意思就是,利用少量的点组成的折线,近似逼近原始多…

AIGC实战——Transformer模型

AIGC实战——Transformer模型 0. 前言1. T52. GPT-3 和 GPT-43. ChatGPT小结系列链接 0. 前言 我们在 GPT (Generative Pre-trained Transformer) 一节所构建的 GPT 模型是一个解码器 Transformer,它逐字符地生成文本字符串,并使用因果掩码只关注输入字…

力扣98---验证二叉搜索树

题目描述: 给你一个二叉树的根节点 root ,判断其是否是一个有效的二叉搜索树。 有效 二叉搜索树定义如下: 节点的左 子树 只包含 小于 当前节点的数。节点的右子树只包含 大于 当前节点的数。所有左子树和右子树自身必须也是二叉搜索树。 …

计算联合体union的大小

一:联合类型的定义 联合也是一种特殊的自定义类型,这种类型定义的变量也包含一系列的成员,特征是这些成员公用同一块空间(所以联合也叫共用体) 比如:共用了 i 这个较大的空间 二: 联合的特点 …

YoloV8改进策略:Block改进|PKINet

摘要 PKINet是面向遥感旋转框的主干,网络包含了CAA、PKI等模块,给我们改进卷积结构的模型带来了很多启发。本文使用PKINet的Block替换YoloV8的Block,实现涨点。改进方法是我独创首发,给写论文没有思路的同学提供改进思路,欢迎大家订阅! 论文:《Poly Kernel Inception …

需求:实现一个类似打印的效果(文字一个字一个字的输出)

实现效果: 需求:最近接到这么一个需求,ai机器人回复的问题,后端是通过websocket每隔一段事件返回数据,前端拿到数据后直接渲染,现在需要做到一个效果,后端返回的结果前端需要一个一个文字的输出…

Unity Canvas的三种模式

一、简介: Canvas的Render Mode一共有三种模式:Screen Space -OverLay、Screen Space-Camera、World Space Screen Space - Overlay(屏幕空间 - 覆盖): 这是最简单的 Canvas 渲染模式。UI 元素在这个模式下将渲染在屏…

使用amd架构的计算机部署其他架构的虚拟机(如:arm)

1 下载quem模拟器 https://qemu.weilnetz.de/w64/2 QEMU UEFI固件文件下载(引导文件) 推荐使用:https://releases.linaro.org/components/kernel/uefi-linaro/latest/release/qemu64/QEMU_EFI.fd3 QEMU 安装 安装完成之后,需要将安装目录添加到环境变…

flutter3_douyin:基于flutter3+dart3短视频直播实例|Flutter3.x仿抖音

flutter3-dylive 跨平台仿抖音短视频直播app实战项目。 全新原创基于flutter3.19.2dart3.3.0getx等技术开发仿抖音app实战项目。实现了类似抖音整屏丝滑式上下滑动视频、左右滑动切换页面模块,直播间进场/礼物动效,聊天等模块。 运用技术 编辑器&#x…

C语言字节对齐关键字#pragma pack(n)的使用

0 前言 在进行嵌入式开发的过程中,我们经常会见到对齐操作。这些对齐操作有些是为了便于实现指针操作,有些是为了加速对内存的访问。因此,学习如何使用对齐关键字是对于嵌入式开发是很有必要的。 1 对齐规则 1.0 什么叫做对齐 众所周知&a…

微服务(基础篇-003-Nacos集群搭建)

目录 Nacos集群搭建 1.集群结构图 2.搭建集群 2.1.初始化数据库 2.2.下载nacos 2.3.配置Nacos 2.4.启动 2.5.nginx反向代理 2.6.优化 视频地址: 06-Nacos配置管理-nacos集群搭建_哔哩哔哩_bilibilihttps://www.bilibili.com/video/BV1LQ4y127n4?p29&…

操作系统究竟是什么?在计算机体系中扮演什么角色?

操作系统究竟是什么?在计算机体系中扮演什么角色? 一、操作系统概念二、操作系统如何管理软硬件资源2.1 何为管理者2.2 操作系统如何管理硬件 三、系统调用接口作用四、用户操作接口五、广义操作系统和狭义操作系统 一、操作系统概念 下面是来自百度百科…

Springboot做分组校验

目录 分组校验 Insert分组 Upload分组 测试接口 测试结果 添加测试 更新测试 顺序校验GroupSequence 自定义分组校验 自定义分组表单 CustomSequenceProvider 测试接口 测试结果 Type类型为A Type类型为B 总结: 前文提到了做自定义的校验注解&#xff…

React高阶组件(HOC)

高阶组件的基本概念 高阶组件(HOC,Higher-Order Components)不是组件,而是一个函数,它会接收一个组件作为参数并返回一个经过改造的新组件: const EnhancedComponent higherOrderComponent(WrappedCompo…

小游戏-扫雷

扫雷大多人都不陌生,是一个益智类的小游戏,那么我们能否用c语言来编写呢, 我们先来分析一下扫雷的运行逻辑, 首先,用户在进来时需要我们给与一个菜单,以供用户选择, 然后我们来完善一下&#…

解决方案Please use Oracle(R) Java(TM) 11, OpenJDK(TM) 11 to run Neo4j.

文章目录 一、现象二、解决方案 一、现象 当安装好JDK跟neo4j,用neo4j.bat console来启动neo4却报错: 部分报错信息: Starting Neo4j. WARNING! You are using an unsupported Java runtime. Please use Oracle Java™ 11, OpenJDK™ 11 t…

Rust下载安装、卸载、版本切换、创建项目(包含指定版本的)

先声名一下,下面所说的版本号为xxxxx-x86_64-unknown-linux-gnu中xxxxx的部分。 下载安装 下载最新版本的Rust: curl --proto https --tlsv1.2 -sSf https://sh.rustup.rs | sh info: downloading installer重启shell 或者 按照提示 执行命令让环境变…