多重感知机MLP:Mnist

news2024/11/24 17:56:14

文章目录

    • 网络结构
    • 代码
      • common_utils.py
      • network.py
      • provider.py
      • train.py
      • test.py
      • visual.py
    • 实验
      • 训练结果
      • 测试结果
      • 可视化

网络结构

输入过程输出
28*28Flatten784
784Linear300
300Linear100
100Linear10

代码

文件结构:
在这里插入图片描述

common_utils.py

用来输出日志文件

# common_utils.py
import logging


def create_logger(log_file=None, rank=0, log_level=logging.INFO):
    logger = logging.getLogger(__name__)
    logger.setLevel(log_level if rank == 0 else 'ERROR')
    formatter = logging.Formatter('[%(asctime)s  %(filename)s %(lineno)d '
                                  '%(levelname)5s]  %(message)s')
    console = logging.StreamHandler()
    console.setLevel(log_level if rank == 0 else 'ERROR')
    console.setFormatter(formatter)
    logger.addHandler(console)
    if log_file is not None:
        file_handler = logging.FileHandler(filename=log_file)
        file_handler.setLevel(log_level if rank == 0 else 'ERROR')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    return logger

network.py

设计MLP结构,包含训练函数train_model和评估函数eval_model

# network.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import provider

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28*28, 300)
        self.fc2 = nn.Linear(300, 100)
        self.fc3 = nn.Linear(100, 10)
        self.relu = nn.ReLU()
        self.softmax = nn.LogSoftmax(dim=1)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc3(x)
        x = self.softmax(x)
        return x

    def train_model(self, args):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.parameters(), lr=args.lr)
        scheduler = StepLR(optimizer, step_size=3, gamma=0.1)  # 学习率调度器
        train_loader = provider.GetLoader(batch_size=args.batch_size, loadType='train')
        test_loader = provider.GetLoader(batch_size=args.batch_size, loadType='test')

        best_accuracy = 0.0
        for epoch in range(args.epochs):
            self.train()
            running_loss = 0.0
            correct = 0
            total = 0

            for images, labels in train_loader:
                images = images.to(device)
                labels = labels.to(device)

                # 前向传播
                outputs = self(images)
                loss = criterion(outputs, labels)

                # 反向传播和优化
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # 统计准确率
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                running_loss += loss.item()

            train_loss = running_loss / len(train_loader)
            train_accuracy = correct / total

            # 在测试集上评估模型
            self.eval()
            test_loss = 0.0
            correct = 0
            total = 0

            with torch.no_grad():
                for images, labels in test_loader:
                    images = images.to(device)
                    labels = labels.to(device)

                    outputs = self(images)
                    loss = criterion(outputs, labels)

                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

                    test_loss += loss.item()

            test_loss = test_loss / len(test_loader)
            test_accuracy = correct / total

            # 更新学习率
            scheduler.step()

            # 保存在验证集上表现最好的模型
            if test_accuracy > best_accuracy:
                best_accuracy = test_accuracy
                torch.save({
                    'model_state_dict': self.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'epoch': epoch,
                    'best_accuracy': best_accuracy,
                }, 'best_model.pth')

            # 打印训练过程中的损失和准确率
            args.logger.info(f"Epoch [{epoch+1}/{args.epochs}] - Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Test Accuracy: {test_accuracy:.4f}")

        # 保存最后一个epoch的模型
        torch.save({
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch,
            'best_accuracy': best_accuracy,
        }, 'final_model.pth')

    def eval_model(self, dataloader):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(device)
        self.eval()
        total = 0
        correct = 0

        with torch.no_grad():
            for images, labels in dataloader:
                images = images.to(device)
                labels = labels.to(device)

                outputs = self(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = correct / total
        return accuracy

provider.py

包含数据读取函数GetLoader和数据可视化函数visualize_loader

#  provider.py
from sklearn.preprocessing import MinMaxScaler
import torch
import torchvision
import matplotlib.pyplot as plt




def visualize_loader(loader,model=None): 
    # batch=[32*1*28*28,32]
    for batch in loader:
        break
    fig, axes = plt.subplots(4, 8, figsize=(20, 10))
    imgs=batch[0]
    labels=batch[1].numpy()
    if model==None:
        imgName='train.png'
        predicted=labels
    else:
        imgName = 'test.png'
        outputs = model(imgs)
        _, predicted = torch.max(outputs.data, 1)
        predicted = predicted.numpy()
    imgs=imgs.squeeze().numpy()
    for i, ax in enumerate(axes.flat):
        ax.imshow(imgs[i])
        ax.set_title(predicted[i],color='black' if predicted[i]==labels[i] else 'red')
        ax.axis('off')
    plt.tight_layout()
    plt.show()
    plt.savefig(imgName)


# loader.shape=1875*[32*1*28*28,32]
def GetLoader(path='data',batch_size=32,loadType='train'):
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,))
    ])
    transfer=MinMaxScaler(feature_range=(0, 255)) 
    dataset = torchvision.datasets.MNIST(root=path, train=loadType=='train',transform=transform,download =False)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return loader

train.py

训练模型

# train.py
import argparse
import datetime
import common_utils
import os
import network
import provider
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

def parse_config():
    parser = argparse.ArgumentParser(description='arg parser')
    parser.add_argument('--batch_size', type=int, default=32, required=False, help='batch size for training')
    parser.add_argument('--epochs', type=int, default=7, required=False, help='number of epochs to train for')
    parser.add_argument('--lr', type=float, default=0.01, required=False, help='learning rate')
    
    log_file = 'output/'+ ('log_train_%s.txt' % datetime.datetime.now().strftime('%Y%m%d_%H%M%S'))
    logger = common_utils.create_logger(log_file)
    parser.add_argument('--logger', type=type(logger), default=logger, help='logger')
    
    args = parser.parse_args()

    return args

def main():
    args = parse_config()
    # log to file
    args.logger.info('**********************Start logging**********************')
    for key, val in vars(args).items():
        args.logger.info('{:16} {}'.format(key, val))

    args.logger.info('**********************Start training ********************')
    model = network.MLP()
    model.train_model(args)
    args.logger.info('**********************End training **********************')

    # Evaluate the trained model
    args.logger.info('**********************Start eval ************************')
    test_loader = provider.GetLoader(batch_size=args.batch_size, loadType='test')
    test_accuracy = model.eval_model(test_loader)
    args.logger.info(f'Test Accuracy: {test_accuracy:.4f}')
    args.logger.info('**********************End eval **************************')
    args.logger.info('**********************End *******************************\n')


if __name__ == '__main__':
    main()

test.py

测试模型

# test.py
import argparse
import datetime
import common_utils
import os
import network
import provider
import torch
import torch.nn as nn
import torch.optim as optim


def parse_config():
    parser = argparse.ArgumentParser(description='arg parser')
    parser.add_argument('--batch_size', type=int, default=32, required=False, help='batch size for training')
    parser.add_argument('--checkpoint', type=str, default='best_model.pth', help='checkpoint to start from')
    log_file = 'output/'+ ('log_test_%s.txt' % datetime.datetime.now().strftime('%Y%m%d_%H%M%S'))
    logger = common_utils.create_logger(log_file)
    parser.add_argument('--logger', type=type(logger), default=logger, help='checkpoint to start from')
    args = parser.parse_args()
    return args

def main():
    args= parse_config()
    args.logger.info('**********************Start logging**********************')
    for key, val in vars(args).items():
        args.logger.info('{:16} {}'.format(key, val))
    args.logger.info('**********************Start testing **********************')
    test(args)
    args.logger.info('**********************End testing ************************\n\n')

    
def test(args): 
    checkpoint = torch.load(args.checkpoint)
    model=network.MLP()
    model.load_state_dict(checkpoint['model_state_dict'])
    args.logger.info(model)
    test_loader=provider.GetLoader(batch_size=args.batch_size,loadType='test')  
    test_accuracy = model.eval_model(test_loader)
    args.logger.info(f'Test Accuracy: {test_accuracy:.4f}')


if __name__ == '__main__':
    main()


visual.py

可视化代码

# visual.py
import provider
import network
import torch


train_loader=provider.GetLoader(loadType='train')
provider.visualize_loader(train_loader)

test_loader=provider.GetLoader(loadType='test')
checkpoint = torch.load('best_model.pth')
model=network.MLP()
model.load_state_dict(checkpoint['model_state_dict'])
provider.visualize_loader(test_loader,model)


实验

训练结果

[2023-07-22 10:45:31,237  train.py 30  INFO]  **********************Start logging**********************
[2023-07-22 10:45:31,237  train.py 32  INFO]  batch_size       32
[2023-07-22 10:45:31,237  train.py 32  INFO]  epochs           7
[2023-07-22 10:45:31,237  train.py 32  INFO]  lr               0.01
[2023-07-22 10:45:31,237  train.py 32  INFO]  logger           <Logger common_utils (INFO)>
[2023-07-22 10:45:31,237  train.py 34  INFO]  **********************Start training ********************
[2023-07-22 10:45:46,963  network.py 106  INFO]  Epoch [1/7] - Train Loss: 0.5768, Train Accuracy: 0.8446, Test Accuracy: 0.9037
[2023-07-22 10:45:59,299  network.py 106  INFO]  Epoch [2/7] - Train Loss: 0.5059, Train Accuracy: 0.8759, Test Accuracy: 0.9299
[2023-07-22 10:46:11,687  network.py 106  INFO]  Epoch [3/7] - Train Loss: 0.4536, Train Accuracy: 0.8884, Test Accuracy: 0.9198
[2023-07-22 10:46:24,010  network.py 106  INFO]  Epoch [4/7] - Train Loss: 0.3161, Train Accuracy: 0.9196, Test Accuracy: 0.9502
[2023-07-22 10:46:36,307  network.py 106  INFO]  Epoch [5/7] - Train Loss: 0.2497, Train Accuracy: 0.9350, Test Accuracy: 0.9528
[2023-07-22 10:46:48,712  network.py 106  INFO]  Epoch [6/7] - Train Loss: 0.2280, Train Accuracy: 0.9395, Test Accuracy: 0.9549
[2023-07-22 10:47:01,138  network.py 106  INFO]  Epoch [7/7] - Train Loss: 0.2078, Train Accuracy: 0.9443, Test Accuracy: 0.9573
[2023-07-22 10:47:01,155  train.py 37  INFO]  **********************End training **********************
[2023-07-22 10:47:01,155  train.py 40  INFO]  **********************Start eval ************************
[2023-07-22 10:47:02,492  train.py 43  INFO]  Test Accuracy: 0.9573
[2023-07-22 10:47:02,493  train.py 44  INFO]  **********************End eval **************************
[2023-07-22 10:47:02,493  train.py 45  INFO]  **********************End *******************************

测试结果

[2023-07-22 10:50:46,173  test.py 24  INFO]  **********************Start logging**********************
[2023-07-22 10:50:46,173  test.py 26  INFO]  batch_size       32
[2023-07-22 10:50:46,173  test.py 26  INFO]  checkpoint       best_model.pth
[2023-07-22 10:50:46,173  test.py 26  INFO]  logger           <Logger common_utils (INFO)>
[2023-07-22 10:50:46,173  test.py 27  INFO]  **********************Start testing **********************
[2023-07-22 10:50:49,084  test.py 36  INFO]  MLP(
(flatten): Flatten(start_dim=1, end_dim=-1)
(fc1): Linear(in_features=784, out_features=300, bias=True)
(fc2): Linear(in_features=300, out_features=100, bias=True)
(fc3): Linear(in_features=100, out_features=10, bias=True)
(relu): ReLU()
(softmax): LogSoftmax(dim=1)
(dropout): Dropout(p=0.2, inplace=False)
)
[2023-07-22 10:50:50,970  test.py 39  INFO]  Test Accuracy: 0.9573
[2023-07-22 10:50:50,970  test.py 29  INFO]  **********************End testing ************************

可视化

测试结果
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/781209.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于扩展(EKF)和无迹卡尔曼滤波(UKF)的电力系统动态状态估计

1 主要内容 该程序对应文章《Power System Dynamic State Estimation Using Extended and Unscented Kalman Filters》&#xff0c;电力系统状态的准确估计对于提高电力系统的可靠性、弹性、安全性和稳定性具有重要意义&#xff0c;虽然近年来测量设备和传输技术的发展大大降低…

Linux常用嗅探工具(1):fping命令

fping的优点&#xff1a; 可以一次ping多个主机可以从主机列表文件ping结果清晰 便于脚本处理速度快 fping的安装&#xff1a; 前置安装cgg编译器 &#xff1a; yum -y install gcc 下载fping&#xff1a; wget http://fping.org/dist/fping-4.0.tar.gz 解压&#xff1a; …

力扣 -- 918. 环形子数组的最大和

一、题目&#xff1a; 题目链接&#xff1a;918. 环形子数组的最大和 - 力扣&#xff08;LeetCode&#xff09; 二、解题步骤&#xff1a; 下面是用动态规划的思想解决这道题的过程&#xff0c;相信各位小伙伴都能看懂并且掌握这道经典的动规题目滴。 三、参考代码&#xff1…

Redis 基础知识和核心概念解析:探索 Redis 的数据结构与存储方式

&#x1f337;&#x1f341; 博主 libin9iOak带您 Go to New World.✨&#x1f341; &#x1f984; 个人主页——libin9iOak的博客&#x1f390; &#x1f433; 《面试题大全》 文章图文并茂&#x1f995;生动形象&#x1f996;简单易学&#xff01;欢迎大家来踩踩~&#x1f33…

F---图像识别---河南省第十三届ICPC大学生程序设计竞赛

链接&#xff1a;登录—专业IT笔试面试备考平台_牛客网 来源&#xff1a;牛客网 输入 10 10 ........*. ........*. ........*. ........*. ....#...*. ........*. ........*. ********** ........*. ........*. 输出 -4 3 解析&#xff1a; 遍历整个二维数组&#xff0…

dubbo从基于注解方式转为基于xml配置方式的使用心得

过程中遇到的问题百分之九十九的问题都是因为版本不兼容问题&#xff0c;所以在引入依赖的时候要注意这点&#xff0c;可以从maven central repository官方仓库这里看所要引用版本与其可兼容的版本

畅想未来感汽车HMI设计的奇妙之旅!

当下智能电动汽车的发展势头越来越高涨&#xff0c;与智能电动汽车相关的汽车HMI设计也成为各个品牌重点发力的地方&#xff0c;汽车HMI设计正在前所未有的新高度&#xff0c;本篇文章就来聊聊HMI设计的那些事 ⬇⬇⬇点击获取更多设计资源 https://js.design/community?categ…

prometheus监控mysql8.x以及主从监控告警

mysql8.x主从部署请看下面文档 docker和yum安装的都有 Docker部署mysql8.x版本互为主从_争取不加班&#xff01;的博客-CSDN博客 Mysql8.x版本主从加读写分离&#xff08;一&#xff09; mysql8.x主从_myswl8双主一从读写分离_争取不加班&#xff01;的博客-CSDN博客 安装部署…

C++OpenCV(4):图像截取与掩膜操作

&#x1f506; 文章首发于我的个人博客&#xff1a;欢迎大佬们来逛逛 &#x1f506; OpenCV项目地址及源代码&#xff1a;点击这里 文章目录 图像截取图像掩膜操作 图像截取 ROI操作&#xff0c;指的是&#xff1a;region of interest&#xff0c;感兴趣区域。 我们可以对一张…

Vue 项目增加版本号输出, 可用于验证是否更新成功

webpack 1. vue.config.js 中增加以下配置, 此处以增加一个日期时间字符串为例, 具体内容可以根据自己需求自定义 // vue.config.js module.exports {chainWebpack(config) {config.plugin(define).tap(args > {args[0][process.env].APP_VERSION ${JSON.stringify(new …

行为型模式--模版方法模式(图文详解)

模版方法模式--图文详解 采摘机器人-场景体验模版方法模式-解决问题模版方法模式-定义优缺点优点缺点 采摘机器人-场景体验 今天看抖音上外国开始使用采摘苹果的机器人&#xff0c;我们模仿一下的他的大体流程&#xff1a; 主体采摘车进入苹果园进入苹果指定采摘地点&#xf…

通过自动化单元测试的形式守护系统架构

目录 0前言 1 背景 2 为什么选择 Archunit 3 Archunit 是什么 4 引入 Archunit 4.1 开始就是如此简单 4.2 如何组织架构规则 4.3 团队如何规范化 0前言 通过自动化单元测试的形式守护系统架构是一种有效的方式&#xff0c;可以确保系统在不断演进和修改的过程中保持稳…

Python实战之数据挖掘详解

一、Python数据挖掘 1.1 数据挖掘是什么&#xff1f; 数据挖掘是从大量的、不完全的、有噪声的、模糊的、随机的实际应用数据中&#xff0c;通过算法&#xff0c;找出其中的规律、知识、信息的过程。Python作为一门广泛应用的编程语言&#xff0c;拥有丰富的数据挖掘库&#…

数据分享|R语言逻辑回归、Naive Bayes贝叶斯、决策树、随机森林算法预测心脏病...

全文链接&#xff1a;http://tecdat.cn/?p23061 这个数据集&#xff08;查看文末了解数据免费获取方式&#xff09;可以追溯到1988年&#xff0c;由四个数据库组成。克利夫兰、匈牙利、瑞士和长滩。"目标 "字段是指病人是否有心脏病。它的数值为整数&#xff0c;0无…

盖子的c++小课堂——第二十讲:动态规划

前言 中间呢其实还有两讲&#xff0c;但是那两讲太easy了&#xff0c;根本难不倒你们&#xff0c;所以&#xff0c;我索性不放了~~那我们今天讲一个比较容易的知识点——动态规划&#xff08;终于没人给我催更了&#xff01;哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈…

IOC控制反转--.net framework

IOC控制反转–.net framework 分层架构&#xff1a; 一、传统依赖倒置实现 传统工艺&#xff1a;会有依赖&#xff0c;上端全部展示细节 BaseBll baseBll new BaseBll(); baseBll.DoSomething();依赖于抽象&#xff1a;左边依赖倒置&#xff0c;面向抽象 实现类继承接口&am…

React18和React16合成事件原理(附图)

&#x1f4a1; React18合成事件的处理原理 “绝对不是”给当前元素基于addEventListener做的事件绑定&#xff0c;React中的合成事件&#xff0c;都是基于“事件委托”处理的&#xff01; 在React17及以后版本&#xff0c;都是委托给#root这个容器&#xff08;捕获和冒泡都做了…

动态规划入门第1课

1、从计数到选择 ---- 递推与DP&#xff08;动态规划&#xff09; 2、从递归到记忆 ---- 子问题与去重复运算 3、动态规划的要点 第1题 网格路1(grid1) 小x住在左下角(0,0)处&#xff0c;小y在右上角(n,n)处。小x需要通过一段网格路才能到小y家。每次&#xff0c;小x可以选…

视频基础知识

1.视频比特率 视频的比特率是指传输过程中单位时间传输的数据量。可以理解为视频的编码采样率。单位是kbps&#xff0c;即每秒千比特。视频比特率是决定视频清晰度的一个重要指标。比特率越高&#xff0c;视频越清晰&#xff0c;但数据量也会越大。比如一部100分钟的电影&#…

5.4 Bootstrap 下拉菜单(Dropdown)插件

文章目录 Bootstrap 下拉菜单&#xff08;Dropdown&#xff09;插件用法在导航栏内在标签页内 选项方法 Bootstrap 下拉菜单&#xff08;Dropdown&#xff09;插件 Bootstrap 下拉菜单 这一章讲解了下拉菜单&#xff0c;但是没有涉及到交互部分&#xff0c;本章将具体讲解下拉菜…