卷积神经网络——LeNet——FashionMNIST

news2024/9/21 0:51:58

目录

  • 一、整体结构
  • 二、model.py
  • 三、model_train.py
  • 四、model_test.py

GitHub地址

一、整体结构

在这里插入图片描述

二、model.py

import torch
from torch import nn
from torchsummary import summary

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet,self).__init__()
        self.c1 = nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5,padding=2)
        self.sig = nn.Sigmoid()
        self.s2 = nn.AvgPool2d(kernel_size=2,stride=2)
        self.c3 = nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5)
        self.s4 = nn.AvgPool2d(kernel_size=2,stride=2)

        self.flatten = nn.Flatten()
        self.f5 = nn.Linear(in_features=5*5*16,out_features=120)
        self.f6 = nn.Linear(in_features=120,out_features=84)
        self.f7 = nn.Linear(in_features=84,out_features=10)

    def forward(self,x):
        x = self.sig(self.c1(x))
        x = self.s2(x)
        x = self.sig(self.c3(x))
        x = self.s4(x)
        x = self.flatten(x)
        x = self.f5(x)
        x = self.f6(x)
        x = self.f7(x)
        return x

# if __name__ =="__main__":
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#
#     model = LeNet().to(device)
#
#     print(summary(model,input_size=(1,28,28)))

三、model_train.py

# 导入所需的Python库
from torchvision.datasets import FashionMNIST
from torchvision import transforms
import torch.utils.data as Data
import torch
from torch import nn
import time
import copy
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from model import LeNet  # model.py中定义了LeNet模型
from tqdm import tqdm  # 导入tqdm库,用于显示进度条

# 定义数据加载和处理函数
def train_val_data_process():
    # 加载FashionMNIST数据集,Resize到28x28尺寸,并转换为Tensor
    train_data = FashionMNIST(root="./data",
                              train=True,
                              transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),
                              download=True)

    # 将加载的数据集分为80%的训练数据和20%的验证数据
    train_data, val_data = Data.random_split(train_data, lengths=[round(0.8 * len(train_data)), round(0.2 * len(train_data))])

    # 为训练数据和验证数据创建DataLoader,设置批量大小为32,洗牌,2个进程加载数据
    train_dataloader = Data.DataLoader(dataset=train_data,
                                       batch_size=32,
                                       shuffle=True,
                                       num_workers=2)

    val_dataloader = Data.DataLoader(dataset=val_data,
                                     batch_size=32,
                                     shuffle=True,
                                     num_workers=2)

    # 返回训练和验证的DataLoader
    return train_dataloader, val_dataloader

# 定义模型训练和验证过程的函数
def train_model_process(model, train_dataloader, val_dataloader, num_epochs):
    # 设置使用CUDA如果可用
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 打印使用的设备
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    print(f'当前模型训练设备为: {dev}')

    # 初始化Adam优化器和交叉熵损失函数
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    # 将模型移动到选定的设备上
    model = model.to(device)

    # 复制模型权重用于后续更新最佳模型
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0  # 初始化最佳准确度

    # 初始化用于记录训练和验证过程中损失和准确度的列表
    train_loss_all = []
    val_loss_all = []
    train_acc_all = []
    val_acc_all = []

    # 记录训练开始时间
    start_time = time.time()

    # 迭代指定的训练轮数
    for epoch in range(1, num_epochs + 1):
        # 记录每个epoch开始的时间
        since = time.time()

        # 打印分隔符和当前epoch信息
        print("-" * 10)
        print(f"Epoch: {epoch}/{num_epochs}")

        # 初始化训练和验证过程中的损失和正确预测数量
        train_loss = 0.0
        train_corrects = 0
        val_loss = 0.0
        val_corrects = 0

        # 初始化批次计数器
        train_num = 0
        val_num = 0

        # 创建训练进度条
        progress_train_bar = tqdm(total=len(train_dataloader), desc=f'Training {epoch}', unit='batch')

        # 训练数据集的遍历
        for step, (b_x, b_y) in enumerate(train_dataloader):
            # 将数据移动到相应的设备上
            b_x = b_x.to(device)
            b_y = b_y.to(device)

            # 训练模型
            model.train()

            # 前向传播
            output = model(b_x)

            # 计算预测标签
            pre_label = torch.argmax(output, dim=1)

            # 计算损失
            loss = criterion(output, b_y)

            # 清空梯度
            optimizer.zero_grad()

            # 反向传播
            loss.backward()

            # 更新权重
            optimizer.step()

            # 累加损失和正确预测数量
            train_loss += loss.item() * b_x.size(0)
            train_corrects += torch.sum(pre_label == b_y.data)

            # 更新批次计数器
            train_num += b_x.size(0)

            # 更新训练进度条
            progress_train_bar.update(1)

        # 关闭训练进度条
        progress_train_bar.close()

        # 创建验证进度条
        progress_val_bar = tqdm(total=len(val_dataloader), desc=f'Validation {epoch}', unit='batch')

        # 验证数据集的遍历
        for step, (b_x, b_y) in enumerate(val_dataloader):
            # 将数据移动到相应的设备上
            b_x = b_x.to(device)
            b_y = b_y.to(device)

            # 评估模型
            model.eval()

            # 前向传播
            output = model(b_x)

            # 计算预测标签
            pre_label = torch.argmax(output, dim=1)

            # 计算损失
            loss = criterion(output, b_y)

            # 累加损失和正确预测数量
            val_loss += loss.item() * b_x.size(0)
            val_corrects += torch.sum(pre_label == b_y.data)

            # 更新批次计数器
            val_num += b_x.size(0)

            # 更新验证进度条
            progress_val_bar.update(1)

        # 关闭验证进度条
        progress_val_bar.close()

        # 计算并记录epoch的平均损失和准确度
        train_loss_all.append(train_loss / train_num)
        train_acc_all.append(train_corrects.double().item() / train_num)

        val_loss_all.append(val_loss / val_num)
        val_acc_all.append(val_corrects.double().item() / val_num)

        # 打印训练和验证的损失与准确度
        print(f'{epoch} Train Loss: {train_loss_all[-1]:.4f} Train Acc: {train_acc_all[-1]:.4f}')
        print(f'{epoch} Val Loss: {val_loss_all[-1]:.4f} Val Acc: {val_acc_all[-1]:.4f}')

        # 计算并打印epoch训练耗费的时间
        time_use = time.time() - since
        print(f'第 {epoch} 个 epoch 训练耗费时间: {time_use // 60:.0f}m {time_use % 60:.0f}s')

        # 若当前epoch的验证准确度为最佳,则更新最佳模型权重
        if val_acc_all[-1] > best_acc:
            best_acc = val_acc_all[-1]
            best_model_wts = copy.deepcopy(model.state_dict())

    # 训练结束,保存最佳模型权重
    torch.save(best_model_wts, 'D:/Pycharm/deepl/LeNet/weight/best_model.pth')

    # 如果当前epoch为总epoch数,则保存最终模型权重
    if epoch == num_epochs:
        torch.save(model.state_dict(), f'D:/Pycharm/deepl/LeNet/weight/{num_epochs}_model.pth')

    # 将训练过程中的统计数据整理成DataFrame
    train_process = pd.DataFrame(data={
        "epoch": range(1, num_epochs + 1),
        "train_loss_all": train_loss_all,
        "val_loss_all": val_loss_all,
        "train_acc_all": train_acc_all,
        "val_acc_all": val_acc_all
    })

    # 打印总训练时间
    consume_time = time.time() - start_time
    print(f'总耗时:{consume_time // 60:.0f}m {consume_time % 60:.0f}s')

    # 返回包含训练过程统计数据的DataFrame
    return train_process

# 定义绘制训练和验证过程中损失与准确度的函数
def matplot_acc_loss(train_process):
    # 创建图形和子图
    plt.figure(figsize=(12, 4))

    # 绘制训练和验证损失
    plt.subplot(1, 2, 1)
    plt.plot(train_process["epoch"], train_process["train_loss_all"], 'ro-', label="train_loss")
    plt.plot(train_process["epoch"], train_process["val_loss_all"], 'bs-', label="val_loss")
    plt.legend()
    plt.xlabel("epoch")
    plt.ylabel("loss")
    # 保存损失图像
    plt.savefig('./result_picture/training_loss_accuracy.png', bbox_inches='tight')

    # 绘制训练和验证准确度
    plt.subplot(1, 2, 2)
    plt.plot(train_process["epoch"], train_process["train_acc_all"], 'ro-', label="train_acc")
    plt.plot(train_process["epoch"], train_process["val_acc_all"], 'bs-', label="val_acc")
    plt.legend()
    plt.xlabel("epoch")
    plt.ylabel("accuracy")
    # 保存准确率曲线图
    plt.savefig('./result_picture/training_accuracy.png', bbox_inches='tight')
    plt.show()

if __name__ == "__main__":
    model = LeNet()

    train_dataloader, val_dataloader = train_val_data_process()
    train_process = train_model_process(model, train_dataloader, val_dataloader, num_epochs=20)

    matplot_acc_loss(train_process)

四、model_test.py

import torch
import torch.utils.data as Data
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from model import LeNet
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
# t代表test


def t_data_process():
    test_data = FashionMNIST(root="./data",
                             train=False,
                              transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),
                              download=True)

    test_dataloader = Data.DataLoader(dataset=test_data,
                                       batch_size=1,
                                       shuffle=True,
                                       num_workers=0)

    return test_dataloader


def t_model_process(model, test_dataloader):
    if model is not None:
        print('Successfully loaded the model.')

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = model.to(device)

    # 初始化参数
    test_corrects = 0.0
    test_num = 0
    all_preds = []  # 存储所有预测标签
    all_labels = []  # 存储所有实际标签

    # 只进行前向传播,不计算梯度
    with torch.no_grad():
        for test_x, test_y in test_dataloader:
            test_x = test_x.to(device)
            test_y = test_y.to(device)

            # 设置模型为验证模式
            model.eval()
            # 前向传播得到一个batch的结果
            output = model(test_x)
            # 查找最大值对应的行标
            pre_lab = torch.argmax(output, dim=1)

            # 收集预测和实际标签
            all_preds.extend(pre_lab.tolist())
            all_labels.extend(test_y.tolist())

            # 计算准确率
            test_corrects += torch.sum(pre_lab == test_y.data)

            # 将所有的测试样本进行累加
            test_num += test_x.size(0)

    # 计算准确率
    test_acc = test_corrects.double().item() / test_num
    print(f'测试的准确率:{test_acc}')

    # 绘制混淆矩阵
    conf_matrix = confusion_matrix(all_labels, all_preds)
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')
    plt.xlabel('Predicted labels')
    plt.ylabel('True labels')
    plt.title('Confusion Matrix')
    plt.show()
    plt.savefig('./result_picture/Confusion_Matrix.png', bbox_inches='tight')



if __name__=="__main__":
    # 加载模型
    model = LeNet()

    print('loading model')
    # 加载权重
    model.load_state_dict(torch.load('D:/Pycharm/deepl/LeNet/weight/best_model.pth'))

    # 加载测试数据
    test_dataloader = t_data_process()

    # 加载模型测试的函数
    t_model_process(model,test_dataloader)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = model.to(device)

    classes = ['T-shirt/top','Trouser','Pullover','Dress','coat','Sandal','Shirt','Sneaker','Bag','Ankle boot']
    with torch.no_grad():
        for b_x,b_y in test_dataloader:
            b_x = b_x.to(device)
            b_y = b_y.to(device)

            model.eval()

            output = model(b_x)
            pre_lab = torch.argmax(output,dim=1)
            result = pre_lab.item()
            label = b_y.item()

            print(f'预测值:{classes[result]}',"-----------",f'真实值:{classes[label]}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1911661.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java--instanceof和类型转换

1.如图,Object,Person,Teacher,Student四类的关系已经写出来了,由于实例化的是Student类,因此,与Student类存在关系的类在使用instanceof时都会输出True,而无关的都会输出False&…

Vatee万腾平台:创新科技,驱动未来

在科技日新月异的今天,每一个创新的火花都可能成为推动社会进步的重要力量。Vatee万腾平台,作为科技创新领域的佼佼者,正以其卓越的技术实力、前瞻性的战略眼光和不懈的探索精神,驱动着未来的车轮滚滚向前。 Vatee万腾平台深知&am…

STM32实战篇:按键控制LED

按键控制LED 功能要求 有两个按键,分别控制两个LED灯。当按键按下后,灯的亮暗状态改变。实物如下图所示: 由图可知,按键一端直接接地,故另一端所对应IO引脚的输入模式应该为上拉输入模式。 实现代码 #include "…

WEB安全基础:网络安全常用术语

一、攻击类别 漏洞:硬件、软件、协议,代码层次的缺陷。 后⻔:方便后续进行系统留下的隐蔽后⻔程序。 病毒:一种可以自我复制并传播,感染计算机和网络系统的恶意软件(Malware),它能损害数据、系统功能或拦…

接口测试(3)

接口自动化 # 获取图片验证码import requestsresponse requests.get(url"http://kdtx-test.itheima.net/api/captchaImage")print(response.status_code) print(response.text) import requestsurl "http://kdtx-test.itheima.net/api/login" header_da…

【自动驾驶/机器人面试C++八股精选】专栏介绍

目录 一、自动驾驶和机器人技术发展前景二、C在自动驾驶和机器人领域的地位三、专栏介绍四、订阅需知 一、自动驾驶和机器人技术发展前景 随着人工智能、机器学习、传感器技术和计算能力的进步,自动驾驶和机器人的技术水平不断提升,使得它们更加智能、可…

国际网课平台Udemy上的亚马逊云科技AWS免费高分课程和创建、维护EC2动手实践

亚马逊云科技(AWS)是全球云行业最🔥火的云平台,在全球经济形势不好的大背景下,通过网课学习亚马逊云科技AWS基础备考亚马逊云科技AWS证书,对于找工作或者无背景转行做AWS帮助巨大。欢迎大家关注小李哥,及时了解世界最前…

Clickhouse的联合索引

Clickhouse 有了单独的键索引,为什么还需要有联合索引呢?了解过mysql的兄弟们应该都知道这个事。 对sql比较熟悉的兄弟们估计看见这个联合索引心里大概有点数了,不过clickhouse的联合索引相比mysql的又有些不一样了,mysql 很遵循最…

信息技术课上的纪律秘诀:营造有序学习环境

信息技术课是学生们探索数字世界的乐园,但同时也是课堂纪律管理的挑战场。电脑、网络、游戏等元素可能分散学生的注意力,影响学习效果。本文将分享一些有效的策略,帮助教师在信息技术课上维持课堂纪律,确保教学活动顺利进行。 制…

C++笔试强训3

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 一、选择题1-5题6-10题 二、编程题题目一题目二 一、选择题 1-5题 如图所示,如图所示p-3指向的元素是6,printf里面的是%s,从6开…

BurpSuite抓IOS设备HTTPS流量

一、简述: Burp 这个工具做过 web 安全的人都应该用过,是个非常强大的抓包工具。在 PC 的浏览器上直接配置代理就行了,本篇文章就来介绍一下如何用 Burp 抓 IOS 设备上的流量,很多文章都介绍过怎么抓包,但是很多坑都没…

计算机网络之WPAN 和 WLAN

上一篇文章内容:无线局域网 1.WPAN(无线个人区域网) WPAN 是以个人为中心来使用的无线个人区域网,它实际上就是一个低功率、小范围、低速率和低价格的电缆替代技术。 (1) 蓝牙系统(Bluetooth) &#…

汇川CodeSysPLC教程03-2-14 与HMI通信

硬件连接 PLC与HMI连接采用何种连接方式,通常是参考双方支持哪些接口。PLC(可编程逻辑控制器)与HMI(人机界面)之间的通讯方式主要有以下几种: 串行通讯(Serial Communication)&…

redis学习(007 实战:黑马点评:登录)

黑马程序员Redis入门到实战教程,深度透析redis底层原理redis分布式锁企业解决方案黑马点评实战项目 总时长 42:48:00 共175P 此文章包含第25p-第p34的内容 文章目录 短信登录功能session 共享问题 短信登录功能 接口编写 这里是Result的封装 过滤器在拦截器的外层…

ISO/OSI七层模型

ISO:国际标准化/ OSI:开放系统互联 七层协议必背图 1.注意事项: 1.上三层是为用户服务的,下四层负责实际数据传输。 2.下四层的传输单位: 传输层; 数据段(报文) 网络层: 数据包(报…

【MATLAB源码-第232期】基于matlab的 (204,188) RS编码解码仿真,采用QPSK调制输出误码率曲线。

操作环境: MATLAB 2022a 1、算法描述 Reed-Solomon码(RS码)是一类广泛应用于数字通信和存储系统中的纠错码,尤其在光盘、卫星通信和QR码等领域有着重要作用。RS码是一种非二进制的纠删码,由Irving S. Reed和Gustave…

vue缓存页面,当tab切换时保留原有的查询条件

需求: 切换tab时,查询条件不变 路由页面: 单个页面上加这句话:

一文带你彻底搞懂什么是责任链模式!!

文章目录 什么是责任链模式?详细示例SpingMVC 中的责任链模式使用总结 什么是责任链模式? 在我们日常生活中,经常会出现一种场景:一个请求需要经过多个对象的处理才能得到最终的结果。比如,一个请假申请,需…

vue打包terser压缩去除控制台打印和断点

情况一: 1、vue-cli搭建 代码压缩工具terser在vue-cli里面是自动支持的,所以直接在vue.config.js里面加入下面配置: const {defineConfig} require(vue/cli-service) module.exportsdefineConfig({transpileDependencies:true,terser:{te…

火灾疏散逃生3d消防模拟演练教学系统助您轻松打造专业的消防培训基地

消防VR仿真教学系统将真实世界的消防挑战带入虚拟的训练环境,为您打造了一个前所未有的消防培训体验。在这里,您可以模拟现实中难以搭建的复杂场景,如工厂火灾、地下室逃生、人员密集场所的紧急疏散等。 深圳VR公司华锐视点研发的消防VR仿真教…