深度学习每周学习总结P1(pytorch手写数字识别)

news2025/1/17 6:10:36
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

目录

    • 0. 总结
    • 1. 数据导入部分
    • 2. 模型构建部分
    • 3. 训练前的准备
    • 4. 定义训练函数
    • 5. 定义测试函数
    • 6. 训练过程

0. 总结

总结:

  • 数据导入部分:数据导入使用了torchvision自带的数据集,获取到数据后需要使用torch.utils.data中的DataLoader()加载数据

  • 模型构建部分:有两个部分一个初始化部分(init())列出了网络结构的所有层,比如卷积层池化层等。第二个部分是前向传播部分,定义了数据在各层的处理过程。

  • 训练前的准备:在这之前需要定义损失函数,学习率,以及根据学习率定义优化器(例如SGD随机梯度下降),用来在训练中更新参数,最小化损失函数。

  • 定义训练函数:函数的传入的参数有四个,分别是设置好的DataLoader(),定义好的模型,损失函数,优化器。函数内部初始化损失准确率为0,接着开始循环,使用DataLoader()获取一个批次的数据,对这个批次的数据带入模型得到预测值,然后使用损失函数计算得到损失值。接下来就是进行反向传播以及使用优化器优化参数,梯度清零放在反向传播之前或者是使用优化器优化之后都是可以的。将 optimizer.zero_grad() 放在了每个批次处理的开始,这是最标准和常见的做法。这样可以确保每次迭代处理一个新批次时,梯度是从零开始累加的。准确率是通过累计预测正确的数量得到的,处理每个批次的数据后都要不断累加正确的个数,最终的准确率是由预测正确的数量除以所有样本得数量得到的。损失值也是类似每次循环都累计损失值,最终的损失值是总的损失值除以训练批次得到的

  • 定义测试函数:函数传入的参数相比训练函数少了优化器,只需传入设置好的DataLoader(),定义好的模型,损失函数。此外除了处理批次数据时无需再设置梯度清零、返向传播以及优化器优化参数,其余部分均和训练函数保持一致。

  • 训练过程:定义训练次数,有几次就使用整个数据集进行几次训练,初始化四个空list分别存储每次训练及测试的准确率及损失。使用model.train()开启训练模式,调用训练函数得到准确率及损失。使用model.eval()将模型设置为评估模式,调用测试函数得到准确率及损失。接着就是将得到的训练及测试的准确率及损失存储到相应list中并合并打印出来,得到每一次整体训练后的准确率及损失。

  • 模型的保存,调取及使用,暂时没有看到这部分,但是训练好的模型肯定是会用到这步的,需要自己添加进去。在PyTorch中,通常使用 torch.save(model.state_dict(), ‘model.pth’) 保存模型的参数,使用 model.load_state_dict(torch.load(‘model.pth’)) 加载参数。

1. 数据导入部分

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

device
device(type='cpu')
print(torch.__version__) # 查看pytorch版本
1.9.0+cpu

# 导入数据
train_ds = torchvision.datasets.MNIST(
    'data', 
    train = True,
    transform = torchvision.transforms.ToTensor(),
    download = True
)
test_ds = torchvision.datasets.MNIST(
    'data',
    train = False,
    transform = torchvision.transforms.ToTensor(),
    download = True
)
batch_size = 32

# shuffle = True 意味着每次迭代数据集时,数据都会被随机打乱。
# 因此,当您从train_dl中获取一个批次的数据时,每次都可能得到不同的图像
train_dl = torch.utils.data.DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle = True
)
test_dl = torch.utils.data.DataLoader(
    test_ds,
    batch_size = batch_size
)
# 取一个批次查看数据格式
# 数据的shape为:[batch_size,channel,height,weight]
# 其中batch_size为自己设定,channel,height,weight分别对应图片的通道数,高度和宽度
imgs,labels = next(iter(train_dl)) # 由于数据加载器被设置为随机打乱数据(shuffle=True),因此每次调用next函数时,都会从数据集中随机选择一个批次的数据。
imgs.shape
torch.Size([32, 1, 28, 28])
import numpy as np

#指定图片大小,图像大小为20宽,5高的绘图(单位为英寸inch)
plt.figure(figsize=(20,5))
for i,img in enumerate(imgs[:20]):
    # 维度缩减
    npimg = np.squeeze(img.numpy())
    plt.subplot(2,10,i+1) # 将整个figure分成2行10列,绘制第i+1个子图
    plt.imshow(npimg,cmap=plt.cm.binary)
    plt.axis('off') # 这行代码关闭了当前子图的坐标轴,使得图像没有任何坐标轴标签或刻度。

在这里插入图片描述

2. 模型构建部分

# 模型构建
import torch.nn.functional as F

num_classes = 10 # 图片的类别数

class Model(nn.Module):
    def __init__(self):
        super().__init__() # super(Model,self).__init__() 的简化写法
        # 特征提取网络
        self.conv1 = nn.Conv2d(1,32,kernel_size = 3) # 第一层卷积,卷积核大小为3*3
        self.pool1 = nn.MaxPool2d(2) # 设置池化层,池化核大小为2*2
        self.conv2 = nn.Conv2d(32,64,kernel_size = 3) # 第二层卷积,卷积核大小为3*3
        self.pool2 = nn.MaxPool2d(2)
        
        # 分类网络
        self.fc1 = nn.Linear(1600,64)
        self.fc2 = nn.Linear(64,num_classes)
    # 前向传播
    def forward(self,x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        
        x = torch.flatten(x,start_dim=1) # x.view(x.size(0), -1) 展平张量
        
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        
        return x
# !pip install torchinfo -i https://pypi.mirrors.ustc.edu.cn/simple/
# 查看模型结构
model = Model()
model
Model(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=1600, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=10, bias=True)
)
# 打印模型
from torchinfo import summary

model = Model().to(device) # 在指定的设备(device,可能是CPU或CUDA/GPU)上实例化原始Model。

summary(model) # 使用torchinfo库中的summary函数来打印模型的摘要。
=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
Model                                    --
├─Conv2d: 1-1                            320
├─MaxPool2d: 1-2                         --
├─Conv2d: 1-3                            18,496
├─MaxPool2d: 1-4                         --
├─Linear: 1-5                            102,464
├─Linear: 1-6                            650
=================================================================
Total params: 121,930
Trainable params: 121,930
Non-trainable params: 0
=================================================================

3. 训练前的准备

# 训练模型

# 设置超参数
loss_fn = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 1e-2 # 学习率
opt = torch.optim.SGD(model.parameters(),lr=learn_rate)

4. 定义训练函数

# 训练循环
def train(dataloader,model,loss_fn,optimizer):
    size = len(dataloader.dataset) # 训练集的大小,一共60000张图片
    num_batches = len(dataloader) # 批次数目,1875(60000/32)
    
    train_loss,train_acc = 0,0 # 初始化训练损失和正确率
    
    for X,y in dataloader: # 获取图片及其标签
        X,y = X.to(device),y.to(device)
        
        # 计算预测误差
        pred = model(X) # 网络输出
        loss = loss_fn(pred,y) # 计算网络输出值和真实值之间的差距,targets为真实值,计算二者差值即为损失
        
        # 反向传播
        optimizer.zero_grad() # grad属性归零
        loss.backward() # 反向传播
        optimizer.step() # 每一步自动更新
        
        # 记录acc与loss
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()
    
    train_acc /= size
    train_loss /= num_batches
    
    return train_acc,train_loss

5. 定义测试函数

# 编写测试函数
def test(dataloader,model,loss_fn):
    size = len(dataloader.dataset) # 测试集的大小,一共10000张图片
    num_batches = len(dataloader) # 批次数目,313(10000/32=312.5)
    test_loss,test_acc = 0,0
    
    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for imgs,target in dataloader:
            imgs,target = imgs.to(device),target.to(device)
            
            # 计算loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred,target)

            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()
        
    test_acc /= size
    test_loss /= num_batches
    
    return test_acc,test_loss

6. 训练过程

# 训练
epochs = 5
train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
    model.train()
    epoch_train_acc,epoch_train_loss = train(train_dl,model,loss_fn,opt)
    
    model.eval()
    epoch_test_acc,epoch_test_loss = test(test_dl,model,loss_fn)
    
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    template = ('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f}')
    print(template.format(epoch+1,epoch_train_acc*100,epoch_train_loss,epoch_test_acc*100,epoch_test_loss))
    
print('Done')
C:\Users\chengyuanting\.conda\envs\pytorch_cpu\lib\site-packages\torch\nn\functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  ..\c10/core/TensorImpl.h:1156.)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Epoch: 1,Train_acc:77.5%,Train_loss:0.753,Test_acc:0.2%,Test_loss:0.000
Epoch: 2,Train_acc:94.3%,Train_loss:0.190,Test_acc:0.1%,Test_loss:0.000
Epoch: 3,Train_acc:96.2%,Train_loss:0.125,Test_acc:0.2%,Test_loss:0.000
Epoch: 4,Train_acc:96.9%,Train_loss:0.098,Test_acc:0.2%,Test_loss:0.000
Epoch: 5,Train_acc:97.5%,Train_loss:0.081,Test_acc:0.2%,Test_loss:0.000
Done
# 结果可视化
import matplotlib.pyplot as plt
# 隐藏警告
import warnings
warnings.filterwarnings("ignore") # 忽略警告信息
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 100

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1518961.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Jmeter进行http接口测试

🍅 视频学习:文末有免费的配套视频可观看 🍅 关注公众号【互联网杂货铺】,回复 1 ,免费获取软件测试全套资料,资料在手,涨薪更快 本文主要针对http接口进行测试,使用 jmeter工具实现…

Jenkins自定义镜像推送到Harbor仓库

之前Jenkins需要推送jar包到远程到目标服务器再进行构建 现在Jenkins容器内部可以直接使用Docker了 直接在Jenkins构建好推送到Harbor仓库上,然后不管是哪台目标服务器,只需要去Harbor仓库拉取镜像即可。 修改Jenkins任务 将代码检测下面的远程服务器步…

UI 学习 三 可访问性 UX

设计、交流和实现不同领域内容的易访问性决策,涉及到一系列考虑因素,以达到更容易访问的产品体验。 Material使用的框架借鉴了WCAG标准和行业最佳实践,以帮助任何人预测、计划、记录和实现可访问体验。 下面描述的三个阶段有助于将可视化UI…

SIMATIC C7-635西门子触摸屏维修6ES7635-2EB02-0AE3

西门子工控机触摸屏维修C7-626/P西门子控制面板维修DP Panel 6ES7626-1DG03-0AE3;6ES7626-2DG04-0AE3 SIMATIC HMI移动面板维修(西门子面板)可轻松进行电源管理与操作,成为移动应用的可能之选。该面板支持线缆或 Wi-Fi 通信&…

Weblogic 常规渗透测试环境

测试环境 本环境模拟了一个真实的weblogic环境,其后台存在一个弱口令,并且前台存在任意文件读取漏洞。分别通过这两种漏洞,模拟对weblogic场景的渗透。 Weblogic版本:10.3.6(11g) Java版本:1.6 弱口令 环境启动后…

数字孪生与智慧城市:实现城市治理现代化的新路径

随着信息技术的迅猛发展,智慧城市已成为城市发展的必然趋势。数字孪生技术作为智慧城市建设的重要支撑,以其独特的优势为城市治理现代化提供了新的路径。本文将探讨数字孪生技术在智慧城市中的应用,以及如何实现城市治理的现代化。 一、数字…

OJ_还是畅通工程

题干 #include <iostream> #include <vector> #include <algorithm> using namespace std;//并查集的应用&#xff1a;判断图的连通性int set[10001]; //i下标是集合数据编号,set[i]是i的父亲的编号 //若i是根&#xff0c;可令set[i] i void InitDisjointSe…

Docker容器化技术(使用Dockerfile制作Nginx镜像)

编写Dockerfile制作Web应用系统nginx镜像&#xff0c;生成镜像名为nginx:v1.1&#xff0c;并推送其到私有仓库。 1、基于centos7基础镜像&#xff1b; 2、指定作者为Chinaskill&#xff1b; 3、安装nginx服务&#xff0c;将提供的dest目录传到镜像内&#xff0c;并将de…

案例分析篇00-【历年案例分析真题考点汇总】与【专栏文章案例分析高频考点目录】(2024年软考高级系统架构设计师冲刺知识点总结-案例分析篇-先导篇)

专栏系列文章&#xff1a; 2024高级系统架构设计师备考资料&#xff08;高频考点&真题&经验&#xff09;https://blog.csdn.net/seeker1994/category_12593400.html 案例分析篇01&#xff1a;软件架构设计考点架构风格及质量属性 案例分析篇11&#xff1a;UML设计考…

YOLO学习

至少不比原来差 网格大小&#xff1a;13、26、52 不同感受野匹配 损失函数是对数 二分类

通过路由器监控,优化网络效率

路由器是网络的基本连接组件&#xff0c;路由器监控涉及将路由器网络作为一个整体进行管理&#xff0c;其中持续监控路由器的性能、运行状况、安全性和可用性&#xff0c;以确保更好的操作和最短的停机时间&#xff0c;因此监控路由器至关重要。 为什么路由器监控对组织很重要…

面试相关问题准备

一.MySql篇 1优化相关问题 1.1MySql中如何定位慢查询&#xff1f; 慢查询的概念&#xff1a;在MySQL中&#xff0c;慢查询是指执行时间超过一定阈值的SQL语句。这个阈值是由long_query_time参数设定的&#xff0c;它的默认值是10秒1。也就是说&#xff0c;如果一条SQL语句的执…

物联网技术在农药化肥行业的远程监控解决方案

物联网技术在农药化肥行业的远程监控解决方案 随着物联网技术的日益成熟&#xff0c;其在农药化肥行业的应用呈现出广阔的前景。通过物联网远程监控解决方案&#xff0c;可以实现生产、存储和施用环节的全程智能化管理&#xff0c;大大提高行业效率和环保水平。 通过物联网云…

CleanMyMac X2024永久免费的强大的Mac清理工具

作为产品功能介绍专员&#xff0c;很高兴向您详细介绍CleanMyMac X这款强大的Mac清理工具。CleanMyMac X具有广泛的清理能力&#xff0c;支持多种文件类型的清理&#xff0c;让您的Mac始终保持最佳状态。 系统垃圾 CleanMyMac X能够深入系统内部&#xff0c;智能识别并清理各种…

软件设计师:12 - 下午题历年真题

章节章节01-计算机组成原理与体系结构07 - 法律法规与标准化与多媒体基础02 - 操作系统基本原理08 - 设计模式03 - 数据库系统09 - 软件工程04 - 计算机网络10 - 面向对象05 - 数据结构与算法11 - 结构化开发与UML06 - 程序设计语言与语言处理程序基础12 - 下午题历年真题End -…

Java_12 杨辉三角 II

杨辉三角 II 给定一个非负索引 rowIndex&#xff0c;返回「杨辉三角」的第 rowIndex 行。 在「杨辉三角」中&#xff0c;每个数是它左上方和右上方的数的和。 示例 1: 输入: rowIndex 3 输出: [1,3,3,1] 示例 2: 输入: rowIndex 0 输出: [1] 示例 3: 输入: rowIndex 1 输…

llama2 代码实验记录

torchrun分布式启动&#xff0c;所以要想在云端的环境下在本地的IDE上debug&#xff0c;需要设置一下&#xff0c;具体可以参考这里&#xff0c;需要传入的路径参数全部使用绝对路径。 目录 1、传入的句子 2、tokenizer 的tokenization 3、model的主要组成部分 4、过程中自…

如何关闭 Visual Studio 双击异常高亮

[问题描述]&#xff1a; 最近 Visual Studio 更新后&#xff0c;双击选中关键字快要亮瞎我的眼睛了 &#x1f440;&#x1f440; [解决方法]&#xff1a; 摸索了一下&#xff0c;找到了关闭的方法&#xff1a;工具 → 选项 → 文本编辑器 → 常规&#xff0c;然后取消 勾选 sel…

【进阶五】Python实现SDVRP(需求拆分)常见求解算法——蚁群算法(ACO)

基于python语言&#xff0c;采用经典遗传算法&#xff08;ACO&#xff09;对 需求拆分车辆路径规划问题&#xff08;SDVRP&#xff09; 进行求解。 目录 往期优质资源1. 适用场景2. 代码调整3. 求解结果4. 代码片段参考 往期优质资源 经过一年多的创作&#xff0c;目前已经成熟…

掼蛋俗语40条

1、以牌会友&#xff0c;天长地久 ​2、智慧掼蛋&#xff0c;欢乐无限 ​3、尊重对手&#xff0c;信任队友 ​4、小掼蛋&#xff0c;大世界 ​5、掼蛋敢做主&#xff0c;不会太受苦 ​6、掼蛋不出手&#xff0c;平时定保守 ​7、精彩掼蛋&#xff0c;生活精彩 ​8、细心…