深度学习笔记_4、CNN卷积神经网络+全连接神经网络解决MNIST数据

news2024/12/23 15:28:58

1、首先,导入所需的库和模块,包括NumPy、PyTorch、MNIST数据集、数据处理工具、模型层、优化器、损失函数、混淆矩阵、绘图工具以及数据处理工具。

import numpy as np
import torch
from torchvision.datasets import mnist
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import csv
import pandas as pd

2、设置超参数,包括训练批次大小、测试批次大小、学习率和训练周期数。

# 设置超参数
train_batch_size = 64
test_batch_size = 64
learning_rate = 0.001
num_epochs = 10

3、创建数据转换管道,将图像数据转换为张量并进行标准化。

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

4、下载和预处理MNIST数据集,分为训练集和测试集。

# 下载和预处理数据集
train_dataset = mnist.MNIST('data', train=True, transform=transform, download=True)
test_dataset = mnist.MNIST('data', train=False, transform=transform)

5、创建用于训练和测试的数据加载器,以便有效地加载数据。

# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

 6、定义了一个简单的CNN模型,包括两个卷积层和两个全连接层。

# 定义CNN模型
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.fc1 = nn.Linear(1024, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

7、初始化模型、优化器和损失函数。

# 初始化模型、优化器和损失函数
model = CNN()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

8、准备用于记录训练和测试过程中损失和准确率的列表。

# 记录训练和测试过程中的损失和准确率
train_losses = []
test_losses = []
train_accuracies = []
test_accuracies = []

9、进入训练循环,遍历每个训练周期。在每个训练周期内,进入训练模式,遍历训练数据批次,计算损失、反向传播并更新模型参数,同时记录训练损失和准确率。

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        # 计算训练准确率
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

    # 计算平均训练损失和训练准确率
    train_loss /= len(train_loader)
    train_accuracy = 100. * correct / total
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)  # 记录训练准确率

    # 测试模型
    model.eval()
    test_loss = 0.0
    correct = 0
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            all_labels.extend(target.numpy())
            all_preds.extend(pred.numpy())

10、在每个训练周期结束后,进入测试模式,遍历测试数据批次,计算测试损失和准确率,同时记录它们。打印每个周期的训练和测试损失以及准确率。

# 计算平均测试损失和测试准确率
    test_loss /= len(test_loader)
    test_accuracy = 100. * correct / len(test_loader.dataset)
    test_losses.append(test_loss)
    test_accuracies.append(test_accuracy)

    print(f'Epoch [{epoch + 1}/{num_epochs}] -> Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')

11、losses、acces、eval_losses、eval_acces保存到TXT文件

# 保存训练结果
data = np.column_stack((train_losses,test_losses,train_accuracies, test_accuracies))
np.savetxt("results.txt", data)

12、绘制Loss、ACC图像

# 绘制Loss曲线图
plt.figure(figsize=(10, 2))
plt.plot(train_losses, label='Train Loss', color='blue')
plt.plot(test_losses, label='Test Loss', color='red')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curve')
plt.grid(True)
plt.savefig('loss_curve.png')
plt.show()

# 绘制Accuracy曲线图
plt.figure(figsize=(10, 2))
plt.plot(train_accuracies, label='Train Accuracy', color='red')  # 绘制训练准确率曲线
plt.plot(test_accuracies, label='Test Accuracy', color='green')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Accuracy Curve')
plt.grid(True)
plt.savefig('accuracy_curve.png')
plt.show()

 

 13、绘制混淆矩阵图像

# 计算混淆矩阵
confusion_mat = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(10, 8))
sns.heatmap(confusion_mat, annot=True, fmt='d', cmap='Blues', cbar=False)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')
plt.savefig('confusion_matrix.png')
plt.show()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1060397.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

leetcode做题笔记160. 相交链表

给你两个单链表的头节点 headA 和 headB ,请你找出并返回两个单链表相交的起始节点。如果两个链表不存在相交节点,返回 null 。 图示两个链表在节点 c1 开始相交: 题目数据 保证 整个链式结构中不存在环。 注意,函数返回结果后&…

1300*C. Rumor(并查集贪心)

解析&#xff1a; 并查集&#xff0c;求每个集合的最小费用。 每次合并集合的时候&#xff0c;根节点保存当前集合最小的费用。 #include<bits/stdc.h> using namespace std; #define int long long const int N1e55; int n,m,a[N],p[N],cnt[N]; int find(int x){retur…

如何在springboot2中利用mybatis-plus进行分页查询操作。

1.创建配置mp的配置类 在mp的拦截器中加入分页拦截器 package com.example.config;import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor; import com.baomidou.mybatisplus.extension.plugins.inner.PaginationInnerInterceptor; import org.springfra…

微信小程序template界面模板导入

我们有些时候 会有一些比较大但并不复杂的界面结构 这个时候 你可以试试这种导入模板的形式 我们在根目录创建一个 template 目录 然后下面创建一个 text文件夹下面创建一个 test.wxml 参考代码如下 <template name"textIndex"><text class "testw&…

微服务技术栈-Nacos配置管理和Feign远程调用

文章目录 前言一、统一配置管理1.添加配置文件2.微服务拉取配置3.配置共享 三、Feign远程调用总结 前言 在上篇文章中介绍了微服务技术栈中Nacos这个组件的概念&#xff0c;Nacos除了可以做注册中心&#xff0c;同样可以做配置管理来使用。同时我们将学习一种新的远程调用方式…

阿里云免费服务器无法领取限制说明

阿里云提供免费服务器供用户申请&#xff0c;但是领取免费服务器是有条件的&#xff0c;并不是有所的阿里云用户均可领取免费云服务器&#xff0c;免费服务器领取条件为&#xff1a;账号从未使用过阿里云服务器的用户&#xff0c;阿里云百科来举例说明免费服务器领取说明&#…

强化学习环境 - robogym - 学习 - 1

强化学习环境 - robogym - 学习 - 1 项目地址 https://github.com/openai/robogym 为什么选择 robogym 自己的项目需要做一些机械臂 table-top 级的多任务操作 robogym 基于 mujoco 搭建&#xff0c;构建了一个仿真机械臂桌面物体操作&#xff08;pick-place、stack、rearr…

System Generator学习——将代码导入System Generator

文章目录 前言一、步骤 1&#xff1a;用 M-Code 建模控制1、引言2、目标3、步骤 二、步骤 2&#xff1a;用 HDL 建模模块1、引言2、目标3、步骤 三、用 C/C 代码建模块1、引言2、目标3、步骤4、第 1 部分&#xff1a;从 Vivado HLS 创建一个系统生成器包5、第 2 部分&#xff1…

《机器学习实战》学习记录-ch2

PS: 个人笔记&#xff0c;建议不看 原书资料&#xff1a;https://github.com/ageron/handson-ml2 2.1数据获取 import pandas as pd data pd.read_csv(r"C:\Users\cyan\Desktop\AI\ML\handson-ml2\datasets\housing\housing.csv")data.head() data.info()<clas…

优先级队列的模拟实现

目录 1. 优先级队列的概念 1.1堆的概念 1.2堆的性质 1.3堆的存储方式 2. 堆的创建 2.1堆的创建代码解析 2.2建堆的时间复杂度 2.3堆的插入 2.4 堆的删除 2.5常见习题 1. 优先级队列的概念 队列是一种先进先出 (FIFO) 的数据结构 &#xff0c;但有些情况下&#xff0c; 操作的数…

Allegro174版本如何关闭模块复用后铜皮自动从动态变成静态操作指导

Allegro174版本如何关闭模块复用后铜皮自动从动态变成静态操作指导 在用Allegro进行PCB设计的时候,模块复用是使用的十分频繁的操作,当Allegro升级到了174 S034版本的时候,当使用模块复用的功能的时候,模块内的铜皮会自动动静转换,大部分情况是不需要的。 如下图 如何关闭…

【再识C进阶4】详细介绍自定义类型——结构体、枚举和联合

学习目标&#xff1a; 在上一篇博客中&#xff0c;我们已经详细地学习了字符分类函数、字符转换函数和内存函数。那这一篇博客和上一篇博客的关系不是那么相连。 这一篇博客主要介绍一下自定义类型&#xff0c;因为在解决实际问题时&#xff0c;由于世界上的因素有很多&#xf…

01.爬虫基础

1、Python爬虫介绍 爬虫的实战性要求很强。爬虫经常需要爬取商业网站或政府网站的内容&#xff0c;而这些网站随时可能进行更新&#xff0c;另外网络原因和网站反爬虫机制也会对爬虫代码演示造成干扰。 1、1 爬虫的用处 网络爬虫&#xff1a;按照一定的规则&#xff0c;自动…

【Java 进阶篇】JDBC 管理事务详解

在数据库操作中&#xff0c;事务是一个非常重要的概念。事务可以确保一系列的数据库操作要么全部成功执行&#xff0c;要么全部失败回滚&#xff0c;以保持数据库的一致性和完整性。在 Java 中&#xff0c;我们可以使用 JDBC 来管理事务。本文将详细介绍 JDBC 管理事务的方法和…

【Java 进阶篇】JDBC 数据库连接池详解

数据库连接池是数据库连接的管理和复用工具&#xff0c;它可以有效地降低数据库连接和断开连接的开销&#xff0c;提高了数据库访问的性能和效率。在 Java 中&#xff0c;JDBC 数据库连接池是一个常见的实现方式&#xff0c;本文将详细介绍 JDBC 数据库连接池的使用和原理。 1…

vs2015 函数声明、定义与引用

10.VS-函数声明、定义和引用 - 简书 简言之&#xff0c;函数先在头文件中被声明&#xff0c;然后在对应cpp文件中实现&#xff08;定义&#xff09;&#xff0c;最后被不同文件的代码调用&#xff08;引用&#xff09;。

集合原理简记

HashMap 无论在构造函数是否指定数组长度&#xff0c;进行的都是延迟初始化 构造函数作用&#xff1a; 阈值&#xff1a;threshold&#xff0c;每次<<1 &#xff0c;数组长度 负载因子 无参构造&#xff1a;设置默认的负载因子 有参&#xff1a;可以指定初始容量或…

ES6中对象的扩展

1. 属性的简洁表示法 可以直接写入变量和函数作为对象的属性和方法。在对象中只写属性名&#xff0c;不写属性值&#xff0c;代表属性值等于和属性名相同的的变量的值。 属性的简写 let foo bar; let baz {foo}; // { foo: bar } // 等同于 let baz { foo: foo}方法的简写…

力扣 -- 377. 组合总和 Ⅳ

解题步骤&#xff1a; 参考代码&#xff1a; class Solution { public:int combinationSum4(vector<int>& nums, int target) {int nnums.size();vector<double> dp(target1);//初始化dp[0]1;//填表for(int i1;i<target;i){for(int j0;j<n;j){//填表if(…

Windows下启动freeRDP并自适应远端桌面大小

几个二进制文件 xfreerdp # Linux下的&#xff0c;an X11 Remote Desktop Protocol (RDP) client which is part of the FreeRDP project wfreerdp.exe # Windows下的&#xff0c;freerdp2.0 主程序&#xff0c;freerdp3.0将废弃 sdl-freerdp.exe # Windows下的&…