深度学习笔记_7经典网络模型LSTM解决FashionMNIST分类问题

news2025/1/10 20:29:05

1、 调用模型库,定义参数,做数据预处理

import numpy as np
import torch
from torchvision.datasets import FashionMNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, roc_curve, auc
import matplotlib.pyplot as plt

# 检查 GPU 可用性
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# 设置超参数
sequence_length = 28
input_size = 28  
hidden_size = 128
num_layers = 2 
num_classes = 10
batch_size = 64
learning_rate = 0.001
num_epochs = 50

# 定义数据转换操作
transform = transforms.Compose([
    transforms.RandomRotation(degrees=[-30, 30]),   # 随机旋转
    transforms.RandomHorizontalFlip(),   # 随机水平翻转
    transforms.RandomCrop(size=28, padding=4),   # 随机裁剪
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),   # 颜色抖动
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5,), (0.5,))
])

2、下载FashionMNIST训练集

# 下载FashionMNIST训练集
trainset = FashionMNIST(root='data', train=True,
                        download=True, transform=transform)

# 下载FashionMNIST测试集
testset = FashionMNIST(root='data', train=False,
                       download=True, transform=transform)

# 创建 DataLoader 对象
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)

3、定义LSTM模型

# 定义LSTM模型
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size  # LSTM隐含层神经元数
        self.num_layers = num_layers  # LSTM层数
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)  # LSTM层
        self.fc = nn.Linear(hidden_size, num_classes)  # 全连接层

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)  # 初始化状态
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        out, _ = self.lstm(x, (h0, c0))  # LSTM前向传播
        out = self.fc(out[:, -1, :])  # 只取序列最后一个时间步的输出
        return F.log_softmax(out, dim=1)  # 使用log_softmax作为输出

# 初始化模型、优化器和损失函数
model = LSTM(input_size, hidden_size, num_layers, num_classes).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# 记录训练和测试过程中的损失和准确率
train_losses = []
test_losses = []
train_accuracies = []
test_accuracies = []

conf_matrix_list = []
accuracy_list = []
error_rate_list = []
precision_list = []
recall_list = []
f1_score_list = []
roc_auc_list = []

4、 训练循环

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        data, target = data.to(device), target.to(device)  # 将数据移到 GPU 上
        data = data.view(-1, sequence_length, input_size)

        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        # 计算训练准确率
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

    # 计算平均训练损失和训练准确率
    train_loss /= len(train_loader)
    train_accuracy = 100. * correct / total
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)

    # 测试模型
    model.eval()
    test_loss = 0.0
    correct = 0
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)  # 将数据移到 GPU 上
            data = data.view(-1, sequence_length, input_size)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            all_labels.extend(target.cpu().numpy())  # 将结果移到 CPU 上
            all_preds.extend(pred.cpu().numpy())  # 将结果移到 CPU 上

    # 计算平均测试损失和测试准确率
    test_loss /= len(test_loader)
    test_accuracy = 100. * correct / len(test_loader.dataset)
    test_losses.append(test_loss)
    test_accuracies.append(test_accuracy)

    # 计算额外的指标
    conf_matrix = confusion_matrix(all_labels, all_preds)
    conf_matrix_list.append(conf_matrix)

    accuracy = accuracy_score(all_labels, all_preds)
    accuracy_list.append(accuracy)

    error_rate = 1 - accuracy
    error_rate_list.append(error_rate)

    precision = precision_score(all_labels, all_preds, average='weighted')
    recall = recall_score(all_labels, all_preds, average='weighted')
    f1 = f1_score(all_labels, all_preds, average='weighted')

    precision_list.append(precision)
    recall_list.append(recall)
    f1_score_list.append(f1)

    fpr, tpr, thresholds = roc_curve(all_labels, all_preds, pos_label=1)
    roc_auc = auc(fpr, tpr)
    roc_auc_list.append(roc_auc)

    # 打印每个 epoch 的指标
    print(f'Epoch [{epoch + 1}/{num_epochs}] -> Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')
# 打印或绘制训练后的最终指标
print(f'Final Confusion Matrix:\n{conf_matrix_list[-1]}')
print(f'Final Accuracy: {accuracy_list[-1]:.2%}')
print(f'Final Error Rate: {error_rate_list[-1]:.2%}')
print(f'Final Precision: {precision_list[-1]:.2%}')
print(f'Final Recall: {recall_list[-1]:.2%}')
print(f'Final F1 Score: {f1_score_list[-1]:.2%}')
print(f'Final ROC AUC: {roc_auc_list[-1]:.2%}')

5、绘制Loss、Accuracy曲线图, 计算混淆矩阵

import seaborn as sns
# 绘制Loss曲线图
plt.figure()
plt.plot(train_losses, label='Train Loss', color='blue')
plt.plot(test_losses, label='Test Loss', color='red')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curve')
plt.grid(True)
plt.savefig('loss_curve.png')
plt.show()


# 绘制Accuracy曲线图
plt.figure()
plt.plot(train_accuracies, label='Train Accuracy', color='red')  # 绘制训练准确率曲线
plt.plot(test_accuracies, label='Test Accuracy', color='green')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Accuracy Curve')
plt.grid(True)
plt.savefig('accuracy_curve.png')
plt.show()


# 计算混淆矩阵
class_labels = [str(i) for i in range(10)]
confusion_mat = confusion_matrix(all_labels, all_preds)
plt.figure()
sns.heatmap(confusion_mat, annot=True, fmt='d', cmap='Blues', cbar=False)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')
plt.savefig('confusion_matrix.png')
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1321244.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2000年AMC8数学竞赛中英文真题典型考题、考点分析和答案解析

今天是2023年12月19日,距离2024年的AMC8正式考试倒计时一个月。 从战争中学习战争最有效。前几天,六分成长分析了2023年、2022年、2020、2019、2018、2017的AMC8真题的典型考题、考点和详细答案解析。 今天我们不再从2016年分析,来看看更早…

pytorch文本分类(三)模型框架(DNNtextCNN)

pytorch文本分类(三)模型框架(DNN&textCNN) 原任务链接 目录 pytorch文本分类(三)模型框架(DNN&textCNN)1. 背景知识深度学习 2. DNN2.1 从感知器到神经网络2.2 DNN的基本…

避坑指南:uni-forms表单在uni-app中的实践经验

​🌈个人主页:前端青山 🔥系列专栏:uni-app篇 🔖人终将被年少不可得之物困其一生 依旧青山,本期给大家带来JavaScript篇专栏内容:uni-app中forms表单的避坑指南篇 该篇章已被前端圈子收录,点此处进入即可查看更多优质内…

Pytorch nn.Linear()的基本用法与原理详解及全连接层简介

主要引用参考: https://blog.csdn.net/zhaohongfei_358/article/details/122797190 https://blog.csdn.net/weixin_43135178/article/details/118735850 nn.Linear的基本定义 nn.Linear定义一个神经网络的线性层,方法签名如下: torch.nn.Li…

AT32F403如何扩大SRAM

配置方法 使用雅特力的ICP 进行配置(可在官网下载) (1)当连接上芯片后,点击设备操作->选择字节 (2)选择224KB SRAM (3)然后点击应用到设备,(可以点击从设备加载,来看当前的配置) (4)打开keil5魔术棒图标 ,将Target中的IRAM1第二个选项从0x10000改为0x3800。…

虚拟电厂 能源物联新方向

今年有多热?据上海市气象局官微消息,5月29日13时09分,徐家汇站气温达36.1℃,打破了百年来的当地5月份气温*高纪录。不仅如此,北京、四川、江西、湖南、广东、广西等地也频频发布高温预警。 伴随着居民用电急剧攀升&am…

4.1 媒资管理模块 - Nacos与Gateway搭建

文章目录 媒资管理模块 - 媒资项目搭建一、需求分析1.1 介绍1.2 数据模型1.3 分析网关 二、 搭建Nacos2.1 服务发现中心2.2.1 Maven2.2.2 配置Nacos 2.2 配置中心2.2.1 介绍2.2.2 Maven 坐标2.2.3 配置 content-api 工程2.2.4 配置 content-service 工程2.2.5 配置 system-api …

基础算法(5):滑动窗口

1.何为滑动窗口? 滑动窗口其实也是一种算法,主要有两类:一类是固定窗口,一类是可变窗口。固定的窗口只需要一个变量记录,而可变窗口需要两个变量。 2.固定窗口 就像上面这个图一样。两个相邻的长度为4的红色窗口&…

HTML---CSS美化网页元素

文章目录 前言一、pandas是什么&#xff1f;二、使用步骤 1.引入库2.读入数据总结 一.div 标签&#xff1a; <div>是HTML中的一个常用标签&#xff0c;用于定义HTML文档中的一个区块&#xff08;或一个容器&#xff09;。它可以包含其他HTML元素&#xff0c;如文本、图像…

探秘 AJAX:让网页变得更智能的异步技术(上)

&#x1f90d; 前端开发工程师&#xff08;主业&#xff09;、技术博主&#xff08;副业&#xff09;、已过CET6 &#x1f368; 阿珊和她的猫_CSDN个人主页 &#x1f560; 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 &#x1f35a; 蓝桥云课签约作者、已在蓝桥云…

如何编写好的测试用例?

对于软件测试工程师来说&#xff0c;设计测试用例和提交缺陷报告是最基本的职业技能。是非常重要的部分。一个好的测试用例能够指示测试人员如何对软件进行测试。在这篇文章中&#xff0c;我们将介绍测试用例设计常用的几种方法&#xff0c;以及如何编写高效的测试用例。 一、…

iPhone 17Pro/Max或升级4800万像素长焦镜头,配备自研Wi-Fi 7芯片。

iPhone 16未至&#xff0c;关于iPhone 17系列的相关消息就已经放出&#xff0c;到底是谁走漏了风声。 海通国际证券技术分析师Jeff Pu近日发布报告称&#xff0c;苹果将为2025年推出的iPhone 17ProMax配备4800万像素的长焦镜头。经调查&#xff0c;该分析师认为提升iPhone拍摄方…

如何在华为云上购买ECS及以镜像的方式部署华为云欧拉操作系统 (HCE OS)

写在前面 工作中遇到&#xff0c;简单整理博文内容为 华为云开发者认证 实验笔记https://edu.huaweicloud.com/certificationindex/developer/9bf91efb086a448ab4331a2f53a4d3a1理解不足小伙伴帮忙指正 对每个人而言&#xff0c;真正的职责只有一个&#xff1a;找到自我。然后在…

Nginx快速入门:Nginx应用场景、安装与部署(一)

1. Nginx简介 Nginx 是一个高性能的 HTTP 和反向代理服务器&#xff0c;也是一个非常流行的开源 Web 服务器软件。它是由俄罗斯程序员 Igor Sysoev 开发的&#xff0c;最初是为了解决在高并发场景下的C10k 问题&#xff08;即一个服务器进程只能处理 10,000 个并发连接&#x…

早期的OCR是怎么识别图片上的文字的?

现在的OCR技术融合了人工智能技术&#xff0c;通过深度学习&#xff0c;无论是识别的准确率还是效果都非常不错&#xff0c;那您知道在早期的OCR是通过什么技术来实现的吗&#xff1f;如果您不知道&#xff0c;那么&#xff0c;就让我来告诉您&#xff1a;它主要是基于字符的几…

DiffUtil + RecyclerView 在 Kotlin中的使用

很惭愧, 做了多年的Android开发还没有使用过DiffUtil这样解放双手的工具。 文章目录 1 DiffUtil 用来解决什么问题?2 DiffUtil 是什么?3 DiffUtil的使用4 参考文章 1 DiffUtil 用来解决什么问题? 先举几个实际开发中的例子帮助我们感受下: 加载内容流时,第一次加载了ABC,…

数据分析思维导图

参考&#xff1a; https://zhuanlan.zhihu.com/p/567761684?utm_id0 1、数据分析步骤地图 2、数据分析基础知识地图 3、数据分析技术知识地图 4、数据分析业务流程 5、数据分析师能力体系 6、数据分析思路体系 7、电商数据分析核心主题 8、数据科学技能书知识地图 9、数据挖掘…

文章解读与仿真程序复现思路——电力系统自动化EI\CSCD\北大核心《基于碳捕集-电转气的矿区综合能源系统协同优化调度》

这个标题涉及到碳捕集、电力转化为气体&#xff08;可能是指电力转化为氢气等&#xff09;、矿区综合能源系统以及协同优化调度等概念。让我们逐步解读&#xff1a; 碳捕集&#xff08;Carbon Capture&#xff09;&#xff1a; 这指的是通过不同技术手段捕获和隔离工业过程中产…

输电线路定位:精确导航,确保电力传输安全

在现代社会中&#xff0c;电力作为生活的基石&#xff0c;其安全稳定运行至关重要。而输电线路作为电力传输的重要通道&#xff0c;其故障定位和修复显得尤为重要。恒峰智慧科技将为您介绍一种采用分布式行波测量技术的输电线路定位方法&#xff0c;以提高故障定位精度&#xf…

新版Android Studio Logcat 筛选日志

下载了新版的Android Studio&#xff0c;android-studio-2022.3.1.21-mac_arm&#xff0c;记录一下新版本AS的logcat过滤日志条件 1. 按照包名过滤 1.1 过滤当前包名的日志 package:mine 1.2 过滤其他包名日志 package:com.example.firstemptyapplication 2. 按照日志等级过滤…