深度学习笔记_6经典预训练网络LeNet-18解决FashionMNIST数据集

news2024/11/25 6:38:39

1、 调用模型库,定义参数,做数据预处理

import numpy as np
import torch
from torchvision.datasets import FashionMNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, roc_curve, auc
import matplotlib.pyplot as plt
from torchvision import models

# 检查 GPU 可用性
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# 设置超参数
train_batch_size = 64
test_batch_size = 64
learning_rate = 0.001
num_epochs = 50

# 定义数据转换操作
transform = transforms.Compose([
    transforms.RandomRotation(degrees=[-30, 30]),   # 随机旋转
    transforms.RandomHorizontalFlip(),   # 随机水平翻转
    transforms.Resize((224, 224)),  # 调整图像大小
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),   # 颜色抖动
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5,), (0.5,))
])

2、下载FashionMNIST训练集

# 下载FashionMNIST训练集
trainset = FashionMNIST(root='data', train=True,
                        download=True, transform=transform)

# 下载FashionMNIST测试集
testset = FashionMNIST(root='data', train=False,
                       download=True, transform=transform)

# 创建 DataLoader 对象
train_loader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(testset, batch_size=test_batch_size, shuffle=False)

3、使用预训练的ResNet-18模型

# 使用预训练的ResNet-18模型
model = models.resnet18(pretrained=True)
# 修改最后一层,使其适应FashionMNIST的输出类别数
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model.fc = nn.Linear(model.fc.in_features, 10)
model = model.to(device)

# 冻结预训练模型的参数
for param in model.parameters():
    param.requires_grad = False

# 只训练模型的最后一层
for param in model.fc.parameters():
    param.requires_grad = True
# 初始化优化器和损失函数
optimizer = optim.Adam(model.fc.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

4、 训练循环

# 记录训练和测试过程中的损失和准确率
train_losses = []
test_losses = []
train_accuracies = []
test_accuracies = []
conf_matrix_list = []
accuracy_list = []
error_rate_list = []
precision_list = []
recall_list = []
f1_score_list = []
roc_auc_list = []

# 训练循环
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        data, target = data.to(device), target.to(device)  # 将数据移到 GPU 上
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        # 计算训练准确率
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

    # 计算平均训练损失和训练准确率
    train_loss /= len(train_loader)
    train_accuracy = 100. * correct / total
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)

    # 测试模型
    model.eval()
    test_loss = 0.0
    correct = 0
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)  # 将数据移到 GPU 上
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            all_labels.extend(target.cpu().numpy())  # 将结果移到 CPU 上
            all_preds.extend(pred.cpu().numpy())  # 将结果移到 CPU 上

    # 计算平均测试损失和测试准确率
    test_loss /= len(test_loader)
    test_accuracy = 100. * correct / len(test_loader.dataset)
    test_losses.append(test_loss)
    test_accuracies.append(test_accuracy)

    # 计算额外的指标
    conf_matrix = confusion_matrix(all_labels, all_preds)
    conf_matrix_list.append(conf_matrix)

    accuracy = accuracy_score(all_labels, all_preds)
    accuracy_list.append(accuracy)

    error_rate = 1 - accuracy
    error_rate_list.append(error_rate)

    precision = precision_score(all_labels, all_preds, average='weighted')
    recall = recall_score(all_labels, all_preds, average='weighted')
    f1 = f1_score(all_labels, all_preds, average='weighted')

    precision_list.append(precision)
    recall_list.append(recall)
    f1_score_list.append(f1)

    fpr, tpr, thresholds = roc_curve(all_labels, all_preds, pos_label=1)
    roc_auc = auc(fpr, tpr)
    roc_auc_list.append(roc_auc)

    # 打印每个 epoch 的指标
    print(f'Epoch [{epoch + 1}/{num_epochs}] -> Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

5、绘制Loss、Accuracy曲线图, 计算混淆矩阵

import seaborn as sns
# 绘制Loss曲线图
plt.figure()
plt.plot(train_losses, label='Train Loss', color='blue')
plt.plot(test_losses, label='Test Loss', color='red')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curve')
plt.grid(True)
plt.show()

# 绘制Accuracy曲线图
plt.figure()
plt.plot(train_accuracies, label='Train Accuracy', color='red')  # 绘制训练准确率曲线
plt.plot(test_accuracies, label='Test Accuracy', color='green')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Accuracy Curve')
plt.grid(True)
plt.show()

# 计算混淆矩阵
confusion_mat = confusion_matrix(all_labels, all_preds)
class_labels = [str(i) for i in range(10)]
plt.figure()
sns.heatmap(confusion_mat, annot=True, fmt='d', cmap='Blues', cbar=False)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')
plt.savefig('confusion_matrix.png')
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1319781.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

眼镜正确清洗方式有哪些?超声波眼镜清洗机推荐

随着人们对健康的重视,眼镜已经成为了日常生活中的必需品。然而,眼镜的清洗却常常被忽视。正确的清洗方式不仅可以保护眼睛健康,还可以延长眼镜的使用寿命。那么,眼镜的正确清洗方式有哪些呢?经常去眼镜店清洗眼镜的朋…

【PHP入门】1.2-常量与变量

-常量与变量- PHP是一种动态网站开发的脚本语言,动态语言特点是交互性,会有数据的传递,而PHP作为“中间人”,需要进行数据的传递,传递的前提就是PHP能自己存储数据(临时存储) 1.2.1变量基本概…

未来LED全彩显示屏的发展趋势研究

随着LED产品性能的不断提升,全彩 LED 显示屏在亮度、颜色改善和白平衡方面已经达到了比较理想的效果,完全可以满足户外全天候的环境条件。由于全彩 LED 显示屏在价格性能比上的优势,未来数年内有望逐渐取代传统的灯箱、霓虹灯、磁翻板等产品&…

如何在Eclipse中安装WindowBuilder插件,详解过程

第一步:找到自己安装eclipse的版本,在Help-关于eclipse里面,即Version 第二步:去下面这个网站找到对应的 link(Update Site),这一步很重要,不然版本下载错了之后还得删除WindowBuil…

2023大湾区汽车创新大会在深圳坪山开幕

12月15日,2023大湾区汽车创新大会在深圳坪山开幕。 本次大会是由广东省科学技术厅、深圳市发展和改革委员会、深圳市工业和信息化局、中共深圳市新能源和智能网联汽车产业链委员会、坪山区人民政府指导,北京理工大学深圳汽车研究院、广东省大湾区新能源汽…

关于“Python”的核心知识点整理大全27

目录 10.5 小结 第11 章 测试代码 11.1 测试函数 name_function.py 函数get_formatted_name()将名和姓合并成姓名,在名和姓之间加上一个空格,并将它们的 首字母都大写,再返回结果。为核实get_formatted_name()像期望的那样工…

SwitchHosts - 管理、切换多个 hosts 方案的工具

一、hosts文件 简单的说,hosts文件是用于本地dns服务的,采用ip 域名的格式写在一个文本文件当中,Hosts是一个没有扩展名的系统文件,可以用记事本等工具打开,其作用就是将一些常用的网址域名与其对应的IP地址建立一个关…

机器学习——数据划分

【说明】文章内容来自《机器学习入门——基于sklearn》,用于学习记录。若有争议联系删除。 1、数据划分 在机器学习中,通常将数据集划分为训练集和测试集。训练集用于训练数据,生成机器学习模型;测试集用于评估学习模型的泛化性能…

如何让.NET应用使用更大的内存

我一直在思考为何Redis这种应用就能独占那么大的内存空间而我开发的应用为何只有4GB大小左右,在此基础上也问了一些大佬,最终还是验证下自己的猜测。 操作系统限制 主要为32位操作系统和64位操作系统。 每个进程自身还分为了用户进程空间和内核进程空…

安全算法(二):共享密钥加密、公开密钥加密、混合加密和迪菲-赫尔曼密钥交换

安全算法(二):共享密钥加密、公开密钥加密、混合加密和迪菲-赫尔曼密钥交换 本章介绍了共享密钥加密、公开密钥加密,和两种加密方法混合使用的混合加密方法;最后介绍了迪菲-赫尔曼密钥交换。 加密数据的方法可以分为…

人工智能工程师

据悉:为进一步贯彻落实中共中央印发《关于深化人才发展体制机制改革的意见》和国务院印发《关于“十四五”数字经济发展规划》等有关工作的部署求,深入实施人才强国战略和创新驱动发展战略,加强全国数字化人才队伍建设,持续推进人…

Axure之交互与情节与一些实例

目录 一.交互与情节简介 二.ERP登录页到主页的跳转 三.ERP的菜单跳转到各个页面的跳转 四.省市联动 五.手机下拉加载 今天就到这里了,希望帮到你哦!!! 一.交互与情节简介 "交互"通常指的是人与人、人与计算机或物体…

lseek()函数的原型及使用方法,超详细

对于所有打开的文件都有一个当前文件偏移量(current file offset),文件偏移量通常是一个非负整数,用于表明文件开始处到文件当前位置的字节数。 读写操作通常开始于当前文件偏移量的位置,并且使其增大,增量为读写的字节数。文件被…

苹果M系列芯片安装Notepad-- 详细教程(亲测14以上系统也可用)

目录 1. 介绍2. 前言说明3. 安装使用教程3.1 下载3.2 安装3.3 打开3.4 最终效果 4. 主体功能一览5. 其他信息 1. 介绍 鉴于某些Notepad竞品作者的不当言论,Notepad–的意义在于:减少一点错误言论,减少一点自以为是。 Notepad–的目标&#xf…

Xpath注入

这里学习一下xpath注入 xpath其实是前端匹配树的内容 爬虫用的挺多的 XPATH注入学习 - 先知社区 查询简单xpath注入 index.php <?php if(file_exists(t3stt3st.xml)) { $xml simplexml_load_file(t3stt3st.xml); $user$_GET[user]; $query"user/username[name&q…

【网络安全技术】传输层安全——SSL/TLS

一、TLS位置及架构 TLS建立在传输层TCP/UDP之上&#xff0c;应用层之下。 所以这可以解决一个问题&#xff0c;那就是为什么抓不到HTTP和SMTP包&#xff0c;因为这两个在TLS之上&#xff0c;消息封上应用层的头&#xff0c;下到TLS层&#xff0c;TLS层对上层消息整个做了加密&…

SpringBoot 3.2.0 版本 mysql 依赖下载错误

最近想尝试一下最新的 SpringBoot 项目&#xff0c;于是将自己的开源项目进行了一些升级。 JDK 版本从 JDK8 升级至 JDK17。SpringBoot 版本从 SpringBoot 2.7.3 升级到 SpringBoot 3.2.0 其中 JDK 的升级比较顺利&#xff0c;毕竟 JDK 的旧版本兼容性一直非常好。 但是在升级…

STM32_通过Ymodem协议进行蓝牙OTA升级固件教程

目录标题 前言1、OTA升级的重要性和应用场景2、理论基础2.1、单片机的启动流程2.2、什么是IAP&#xff1f;2.3、什么是OTA&#xff1f;2.4、什么是BootLoader&#xff1f;2.5、Ymodem协议是什么&#xff1f;2.6、IAP是如何实现的&#xff1f; 3、具体操作3.1、软硬件工具准备3.…

链表基础知识(二、双向链表头插、尾插、头删、尾删、查找、删除、插入)

目录 一、双向链表的概念 二、 双向链表的优缺点分析​与对比 2.1双向链表特点&#xff1a; 2.2双链表的优劣&#xff1a; 2.3循环链表的优劣 2.4 顺序表和双向链表的优缺点分析​ 三、带头双向循环链表增删改查实现 3.1SList.c 3.2创建一个新节点、头节点 3.3头插 3.…

互联网加竞赛 python+opencv+深度学习实现二维码识别

0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; pythonopencv深度学习实现二维码识别 &#x1f947;学长这里给一个题目综合评分(每项满分5分) 难度系数&#xff1a;3分工作量&#xff1a;3分创新点&#xff1a;3分 该项目较为新颖&…